|
|
import gradio as gr |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import plotly.graph_objects as go |
|
|
from plotly.subplots import make_subplots |
|
|
|
|
|
|
|
|
def load_data(path="EEG_Eye_State.csv"):
    """Load the EEG Eye State dataset from a CSV file.

    Args:
        path: Location of the CSV file. Defaults to ``EEG_Eye_State.csv``;
            a relative path is resolved against the current working directory.

    Returns:
        pandas.DataFrame with the file's columns, as parsed by ``pd.read_csv``.
    """
    return pd.read_csv(path)
|
|
|
|
|
|
|
|
# Dataset loaded once at import time; the plotting and statistics functions
# below read this shared frame.
df = load_data()

# The 14 EEG electrode columns, listed left hemisphere first, then right
# hemisphere (the same grouping the About tab describes).
eeg_channels = [
    'AF3', 'F7', 'F3', 'FC5', 'T7', 'P7', 'O1',   # left hemisphere
    'O2', 'P8', 'T8', 'FC6', 'F4', 'F8', 'AF4',   # right hemisphere
]
|
|
|
|
|
def _contiguous_runs(positions):
    """Return ``(first, last)`` pairs, one per run of consecutive integers.

    ``positions`` must be sorted ascending; an empty input yields ``[]``.
    Example: ``[3, 4, 5, 9, 10]`` -> ``[(3, 5), (9, 10)]``.
    """
    positions = np.asarray(positions)
    if positions.size == 0:
        return []
    # A new run starts wherever the step from the previous value is not 1.
    breaks = np.flatnonzero(np.diff(positions) != 1) + 1
    return [(run[0], run[-1]) for run in np.split(positions, breaks)]


def plot_eeg_signals(start_time, duration, eye_state_filter, selected_channels):
    """Plot a time window of the selected EEG channels, one subplot per channel.

    Args:
        start_time: Window start, in seconds.
        duration: Window length, in seconds.
        eye_state_filter: "Both", "Open" or "Closed". "Closed" keeps only rows
            with ``eyeDetection == 1``, "Open" only rows with ``eyeDetection == 0``.
        selected_channels: Channel column names to draw, top to bottom.

    Returns:
        A plotly Figure, or None when no channel is selected or the window
        contains no samples after filtering.

    Samples with ``eyeDetection == 1`` are shaded red on every subplot.
    """
    # make_subplots cannot build a figure with zero rows.
    if not selected_channels:
        return None
    selected_channels = list(selected_channels)

    sampling_rate = 128
    start_idx = int(start_time * sampling_rate)
    end_idx = start_idx + int(duration * sampling_rate)

    df_segment = df.iloc[start_idx:end_idx]
    # Absolute sample number of every row, captured BEFORE filtering so that
    # timestamps stay correct once rows are dropped below (previously the
    # filtered rows were re-numbered 0..n-1, compressing the time axis).
    sample_numbers = np.arange(start_idx, start_idx + len(df_segment))

    if eye_state_filter != "Both":
        filter_value = 1 if eye_state_filter == "Closed" else 0
        keep = (df_segment['eyeDetection'] == filter_value).to_numpy()
        df_segment = df_segment[keep]
        sample_numbers = sample_numbers[keep]

    if len(df_segment) == 0:
        return None

    time_axis = sample_numbers / sampling_rate

    # Eye-closed spans are identical for every channel, so compute them once.
    # Grouping by absolute sample number keeps separate spans separate even
    # when the rows in between were filtered out.
    closed = (df_segment['eyeDetection'] == 1).to_numpy()
    closed_runs = _contiguous_runs(sample_numbers[closed])

    n_channels = len(selected_channels)
    fig = make_subplots(
        rows=n_channels,
        cols=1,
        shared_xaxes=True,
        vertical_spacing=0.02,
        subplot_titles=selected_channels,
    )

    for row, channel in enumerate(selected_channels, start=1):
        fig.add_trace(
            go.Scatter(
                x=time_axis,
                y=df_segment[channel],
                mode='lines',
                name=channel,
                line=dict(color='steelblue', width=1),
                showlegend=False,
            ),
            row=row, col=1,
        )
        for first, last in closed_runs:
            # Extend to the end of the last sample's period so that a
            # single-sample run is still visible instead of zero-width.
            fig.add_vrect(
                x0=first / sampling_rate,
                x1=(last + 1) / sampling_rate,
                fillcolor="red", opacity=0.1,
                layer="below", line_width=0,
                row=row, col=1,
            )

    fig.update_xaxes(title_text="Time (seconds)", row=n_channels, col=1)
    fig.update_yaxes(title_text="Amplitude (μV)")
    fig.update_layout(
        height=200 * n_channels,
        title_text=f"EEG Signals - {eye_state_filter} Eyes",
        showlegend=False,
        hovermode='x unified',
    )
    return fig
|
|
|
|
|
def _filter_outliers_iqr(data, k=1.5):
    """Return ``data`` without values outside ``[Q1 - k*IQR, Q3 + k*IQR]``."""
    q1 = data.quantile(0.25)
    q3 = data.quantile(0.75)
    iqr = q3 - q1
    return data[(data >= q1 - k * iqr) & (data <= q3 + k * iqr)]


def plot_channel_comparison(channels, eye_state_filter, remove_outliers):
    """Compare channel amplitude distributions between open and closed eyes.

    Args:
        channels: Channel column names; one subplot row per channel.
        eye_state_filter: "Both", "Open" or "Closed". "Both" draws open and
            closed box plots together in column 1 and overlaid histograms of
            both states in column 2; "Open"/"Closed" draws only that state's
            box plot.
        remove_outliers: When true, drop values outside 1.5 * IQR, computed
            separately per channel and per eye state, before plotting.

    Returns:
        A plotly Figure, or None if no channel is selected.
    """
    if not channels:
        return None
    channels = list(channels)
    n_channels = len(channels)

    both = eye_state_filter == "Both"
    show_open = eye_state_filter in ("Both", "Open")
    show_closed = eye_state_filter in ("Both", "Closed")

    if both:
        # Column 1 holds both states' box plots and column 2 both states'
        # histograms, so the titles name the chart, not a single eye state.
        subplot_titles = [
            title
            for ch in channels
            for title in (f'{ch} - Open vs Closed', f'{ch} - Amplitude Distribution')
        ]
        specs = [[{'type': 'box'}, {'type': 'histogram'}] for _ in range(n_channels)]
    else:
        state_label = "Eyes Open" if eye_state_filter == "Open" else "Eyes Closed"
        subplot_titles = [f'{ch} - {state_label}' for ch in channels]
        specs = [[{'type': 'box'}] for _ in range(n_channels)]

    fig = make_subplots(
        rows=n_channels,
        cols=2 if both else 1,
        subplot_titles=subplot_titles,
        specs=specs,
        # Plotly rejects vertical_spacing > 1 / (rows - 1); a fixed 0.08
        # fails as soon as all 14 channels are selected.
        vertical_spacing=min(0.08, 0.9 / max(n_channels - 1, 1)),
    )

    # Split by eye state once rather than re-filtering the frame per channel.
    open_rows = df[df['eyeDetection'] == 0]
    closed_rows = df[df['eyeDetection'] == 1]

    for row, channel in enumerate(channels, start=1):
        open_values = open_rows[channel]
        closed_values = closed_rows[channel]
        if remove_outliers:
            open_values = _filter_outliers_iqr(open_values)
            closed_values = _filter_outliers_iqr(closed_values)

        if show_open:
            fig.add_trace(
                go.Box(y=open_values, name=f'{channel} Open', marker_color='blue',
                       showlegend=(row == 1)),
                row=row, col=1,
            )
        if show_closed:
            fig.add_trace(
                go.Box(y=closed_values, name=f'{channel} Closed', marker_color='red',
                       showlegend=(row == 1)),
                row=row, col=1,
            )

        if not both:
            continue

        fig.add_trace(
            go.Histogram(x=open_values, name=f'{channel} Open', marker_color='blue',
                         opacity=0.7, showlegend=False, nbinsx=30),
            row=row, col=2,
        )
        fig.add_trace(
            go.Histogram(x=closed_values, name=f'{channel} Closed', marker_color='red',
                         opacity=0.7, showlegend=False, nbinsx=30),
            row=row, col=2,
        )

        # Pad the histogram x-range by 10% of the data span on each side;
        # skip it when nothing is left (min/max of an empty series are NaN).
        all_values = pd.concat([open_values, closed_values])
        if not all_values.empty:
            low, high = all_values.min(), all_values.max()
            margin = (high - low) * 0.1
            fig.update_xaxes(range=[low - margin, high + margin], row=row, col=2)

    fig.update_layout(
        height=350 * n_channels,
        title_text=f"Channel Distribution Comparison - {eye_state_filter} Eyes",
        showlegend=True,
        # Overlay the semi-transparent open/closed histograms instead of
        # plotly's default of placing their bars side by side.
        barmode='overlay',
    )

    fig.update_yaxes(title_text="Amplitude (μV)", col=1)
    if both:
        fig.update_xaxes(title_text="Amplitude (μV)", row=n_channels, col=2)
        # Histogram y-axes show bin counts, not amplitudes.
        fig.update_yaxes(title_text="Count", col=2)

    return fig
|
|
|
|
|
def get_statistics(data=None, sampling_rate=128):
    """Summarise the dataset as a Markdown bullet list.

    Args:
        data: DataFrame with an ``eyeDetection`` column (0 = eyes open,
            1 = eyes closed). Defaults to the module-level ``df``.
        sampling_rate: Samples per second, used to derive the duration.
            Defaults to 128 Hz.

    Returns:
        A Markdown string: a bold heading followed by one bullet per statistic.
        An empty frame reports zero samples and 0.0% for both eye states.
    """
    if data is None:
        data = df

    total_samples = len(data)
    eyes_open = int((data['eyeDetection'] == 0).sum())
    eyes_closed = int((data['eyeDetection'] == 1).sum())
    duration = total_samples / sampling_rate

    def share(count):
        # Guard the empty-frame case, which previously divided by zero.
        return count / total_samples * 100 if total_samples else 0.0

    stats = [
        "**Dataset Statistics**",
        f"- Total samples: {total_samples:,}",
        f"- Duration: {duration:.2f} seconds",
        f"- Sampling rate: {sampling_rate} Hz",
        f"- Eyes Open samples: {eyes_open:,} ({share(eyes_open):.1f}%)",
        f"- Eyes Closed samples: {eyes_closed:,} ({share(eyes_closed):.1f}%)",
    ]
    return "\n".join(stats)
|
|
|
|
|
def get_statistics_table(data=None, channels=None):
    """Build a per-channel summary table for display.

    Args:
        data: DataFrame with one column per channel plus an ``eyeDetection``
            column (0 = eyes open, 1 = eyes closed). Defaults to the
            module-level ``df``.
        channels: Column names to summarise, in row order. Defaults to the
            module-level ``eeg_channels``.

    Returns:
        pandas.DataFrame with one row per channel. Numeric values are strings
        formatted to two decimals; ``Std (All)`` uses pandas' default sample
        standard deviation. The header is present even when no channel is given.
    """
    if data is None:
        data = df
    if channels is None:
        channels = eeg_channels

    columns = ['Channel', 'Mean (All)', 'Std (All)', 'Mean (Open)',
               'Mean (Closed)', 'Min', 'Max']

    # Split by eye state once instead of re-filtering the whole frame per channel.
    open_rows = data[data['eyeDetection'] == 0]
    closed_rows = data[data['eyeDetection'] == 1]

    rows = []
    for channel in channels:
        values = data[channel]
        rows.append({
            'Channel': channel,
            'Mean (All)': f"{values.mean():.2f}",
            'Std (All)': f"{values.std():.2f}",
            'Mean (Open)': f"{open_rows[channel].mean():.2f}",
            'Mean (Closed)': f"{closed_rows[channel].mean():.2f}",
            'Min': f"{values.min():.2f}",
            'Max': f"{values.max():.2f}",
        })

    return pd.DataFrame(rows, columns=columns)
|
|
|
|
|
def plot_correlation_matrix():
    """Return an annotated heatmap of pairwise correlations between EEG channels.

    Correlations come from ``DataFrame.corr()`` over the ``eeg_channels``
    columns of ``df``. Every cell is labelled with its coefficient to two
    decimals, and the diverging 'RdBu' colour scale is centred on zero.
    """
    correlations = df[eeg_channels].corr()
    coefficients = correlations.values

    heatmap = go.Heatmap(
        z=coefficients,
        x=eeg_channels,
        y=eeg_channels,
        colorscale='RdBu',
        zmid=0,
        text=coefficients,
        texttemplate='%{text:.2f}',
        textfont={"size": 9},
        colorbar=dict(title="Correlation"),
    )

    fig = go.Figure(data=heatmap)
    fig.update_layout(
        title=dict(text="EEG Channels Correlation Matrix", x=0.5, xanchor='center'),
        height=600,
        width=1215,
        xaxis=dict(side='bottom'),
    )
    return fig
|
|
|
|
|
|
|
|
# Gradio UI. Inside each `with` context, components are placed in the order
# they are created, so the statement order below defines the page layout.
demo = gr.Blocks(title="EEG Eye State Visualizer")

with demo:
    # Header shown above all tabs.
    gr.Markdown("""
    # 🧠 EEG Eye State Visualizer

    Explore and visualize the EEG Eye State Classification Dataset. This interactive tool allows you to:
    - View EEG signals from 14 channels
    - Compare patterns between open and closed eyes
    - Analyze statistical distributions
    - Examine channel correlations

    **Dataset Info**: 14,980 samples | 128 Hz sampling rate | 14 EEG channels
    """)

    # Tab 1: time-domain view of selected channels over a chosen window.
    with gr.Tab("Signal Viewer"):
        gr.Markdown("### Visualize EEG Signals")

        with gr.Row():
            # Left column: controls.
            with gr.Column(scale=1):
                # Upper bound of 117 s matches the ~117 s dataset duration
                # stated in the About tab.
                start_time = gr.Slider(
                    minimum=0,
                    maximum=117,
                    value=0,
                    step=0.5,
                    label="Start Time (seconds)"
                )
                duration = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=0.5,
                    label="Duration (seconds)"
                )
                eye_state = gr.Radio(
                    choices=["Both", "Open", "Closed"],
                    value="Both",
                    label="Eye State Filter"
                )
                channels = gr.CheckboxGroup(
                    choices=eeg_channels,
                    value=['AF3', 'F7', 'O1', 'O2'],
                    label="Select Channels to Display"
                )
                plot_btn = gr.Button("Generate Plot", variant="primary")

            # Right column: plot output (three times wider than the controls).
            with gr.Column(scale=3):
                signal_plot = gr.Plot(label="EEG Signals")

        # Input values are passed positionally, matching
        # plot_eeg_signals(start_time, duration, eye_state_filter, selected_channels).
        plot_btn.click(
            fn=plot_eeg_signals,
            inputs=[start_time, duration, eye_state, channels],
            outputs=signal_plot
        )

    # Tab 2: per-channel distribution comparison between eye states.
    with gr.Tab("Channel Analysis"):
        gr.Markdown("### Compare Multiple Channels")

        with gr.Row():
            with gr.Column(scale=1):
                channels_select = gr.CheckboxGroup(
                    choices=eeg_channels,
                    value=['AF3', 'O1'],
                    label="Select Channels to Compare"
                )
                eye_state_compare = gr.Radio(
                    choices=["Both", "Open", "Closed"],
                    value="Both",
                    label="Eye State Filter"
                )
                remove_outliers_check = gr.Checkbox(
                    label="Remove Outliers (IQR method)",
                    value=False
                )
                compare_btn = gr.Button("Analyze Channels", variant="primary")

            with gr.Column(scale=3):
                comparison_plot = gr.Plot(label="Channel Comparison")

        # Positional mapping onto
        # plot_channel_comparison(channels, eye_state_filter, remove_outliers).
        compare_btn.click(
            fn=plot_channel_comparison,
            inputs=[channels_select, eye_state_compare, remove_outliers_check],
            outputs=comparison_plot
        )

    # Tab 3: static statistics. The helper functions are called here, once,
    # while the UI is being built; their results are the components' values.
    with gr.Tab("Statistics"):
        gr.Markdown("### Dataset Statistics")

        stats_text = gr.Markdown(value=get_statistics())

        gr.Markdown("### Channel Statistics Table (μV)")
        stats_table = gr.Dataframe(
            value=get_statistics_table(),
            interactive=False,
            wrap=True
        )

        gr.Markdown("### Correlation Matrix")
        with gr.Row():
            corr_plot = gr.Plot(
                value=plot_correlation_matrix(),
                container=True,
                scale=1
            )

    # Tab 4: dataset description, citation and links (static Markdown).
    with gr.Tab("About"):
        gr.Markdown("""
        ## About this Dataset

        The EEG Eye State Classification Dataset contains continuous EEG measurements from 14 electrodes
        collected during different eye states (open/closed).

        ### Key Features:
        - **Total Instances**: 14,980 observations
        - **Features**: 14 EEG channel measurements
        - **Sampling Rate**: 128 Hz
        - **Duration**: ~117 seconds
        - **Device**: Emotiv EEG Neuroheadset

        ### Electrode Placement:
        The 14 channels follow the international 10-20 system:
        - Left hemisphere: AF3, F7, F3, FC5, T7, P7, O1
        - Right hemisphere: O2, P8, T8, FC6, F4, F8, AF4

        ### Citation:
        ```
        Rösler, O. (2013). EEG Eye State.
        UCI Machine Learning Repository.
        https://doi.org/10.24432/C57G7J
        ```

        ### Links:
        - [Dataset on Hugging Face](https://huggingface.co/datasets/BrainSpectralAnalytics/eeg-eye-state-classification)
        - [Original UCI Repository](https://archive.ics.uci.edu/dataset/264/eeg+eye+state)
        - [Kaggle Example](https://www.kaggle.com/code/beta3logic/eye-state-eeg-classification-model-using-automl)
        """)
|
|
|
|
|
|
|
|
def main():
    """Start the Gradio server for ``demo`` with server-side rendering disabled."""
    demo.launch(ssr_mode=False)


if __name__ == "__main__":
    main()