# beta3's picture
# Create app.py
# feaf2ab verified
import gradio as gr
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# Load data
def load_data(path="EEG_Eye_State.csv"):
    """Load the EEG Eye State dataset from a CSV file.

    Args:
        path: Location of the CSV file. Defaults to ``"EEG_Eye_State.csv"``,
            which pandas resolves relative to the current working directory.

    Returns:
        A ``pandas.DataFrame`` with one row per CSV record. The rest of this
        app expects the EEG channel columns plus an ``eyeDetection`` column
        (0 = eyes open, 1 = eyes closed).

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    return pd.read_csv(path)
# Initialize data: loaded once at import time; every plotting and statistics
# function in this module reads this module-level DataFrame.
df = load_data()

# The 14 EEG electrode columns: left hemisphere first, then the right
# hemisphere listed in mirrored (reverse) order.
eeg_channels = [
    'AF3', 'F7', 'F3', 'FC5', 'T7', 'P7', 'O1',  # left hemisphere
    'O2', 'P8', 'T8', 'FC6', 'F4', 'F8', 'AF4',  # right hemisphere
]
def plot_eeg_signals(start_time, duration, eye_state_filter, selected_channels):
    """Plot a time window of the selected EEG channels, one stacked subplot each.

    Args:
        start_time: Window start, in seconds from the beginning of the recording.
        duration: Window length in seconds.
        eye_state_filter: "Both", "Open" or "Closed". With "Open"/"Closed",
            samples in the other eye state are blanked (drawn as gaps in the
            trace) instead of being dropped, so every remaining sample stays
            at its true time on the x axis.
        selected_channels: Channel column names of ``df`` to display.

    Returns:
        A plotly ``Figure``, or ``None`` when no channel is selected or no
        sample in the window matches ``eye_state_filter``.
    """
    sampling_rate = 128  # Hz; converts seconds to row positions in df
    if not selected_channels:
        # make_subplots cannot build a grid with zero rows.
        return None

    start_idx = int(start_time * sampling_rate)
    end_idx = start_idx + int(duration * sampling_rate)
    df_segment = df.iloc[start_idx:end_idx]

    # Which samples of the window are shown for the requested eye state
    # (eyeDetection: 1 = closed, 0 = open).
    eye_state = df_segment['eyeDetection'].to_numpy()
    if eye_state_filter == "Both":
        visible = np.ones(len(df_segment), dtype=bool)
    else:
        visible = eye_state == (1 if eye_state_filter == "Closed" else 0)
    if not visible.any():
        return None

    # Time of every sample from its absolute position in the recording, so
    # hiding samples never shifts the ones after them.
    time_axis = (start_idx + np.arange(len(df_segment))) / sampling_rate
    sample_period = 1 / sampling_rate

    # Contiguous runs of visible eyes-closed samples; computed once and
    # shaded on every channel's subplot.
    closed_positions = np.flatnonzero(visible & (eye_state == 1))
    if closed_positions.size:
        breaks = np.flatnonzero(np.diff(closed_positions) != 1) + 1
        closed_runs = np.split(closed_positions, breaks)
    else:
        closed_runs = []

    n_channels = len(selected_channels)
    fig = make_subplots(
        rows=n_channels,
        cols=1,
        shared_xaxes=True,
        vertical_spacing=0.02,
        subplot_titles=selected_channels,
    )

    for row, channel in enumerate(selected_channels, 1):
        # Hidden samples become NaN so the line breaks there instead of
        # bridging across the excluded eye state.
        values = np.where(visible, df_segment[channel].to_numpy(dtype=float), np.nan)
        fig.add_trace(
            go.Scatter(
                x=time_axis,
                y=values,
                mode='lines',
                name=channel,
                line=dict(color='steelblue', width=1),
                showlegend=False,
            ),
            row=row, col=1,
        )
        for run in closed_runs:
            # Extend to the end of the run's last sample so a one-sample run
            # is not a zero-width (invisible) rectangle.
            fig.add_vrect(
                x0=time_axis[run[0]],
                x1=time_axis[run[-1]] + sample_period,
                fillcolor="red", opacity=0.1,
                layer="below", line_width=0,
                row=row, col=1,
            )

    fig.update_xaxes(title_text="Time (seconds)", row=n_channels, col=1)
    fig.update_yaxes(title_text="Amplitude (μV)")
    fig.update_layout(
        height=200 * n_channels,
        title_text=f"EEG Signals - {eye_state_filter} Eyes",
        showlegend=False,
        hovermode='x unified',
    )
    return fig
def plot_channel_comparison(channels, eye_state_filter, remove_outliers):
    """Compare amplitude distributions of channels between eye states.

    One row per channel. Column 1 holds box plots (eyes open in blue, eyes
    closed in red, restricted by ``eye_state_filter``). When the filter is
    "Both", column 2 overlays the open and closed histograms of that channel.

    Args:
        channels: Channel column names of ``df`` to compare.
        eye_state_filter: "Both", "Open" or "Closed".
        remove_outliers: If true, drop values outside
            [Q1 - 1.5 * IQR, Q3 + 1.5 * IQR], computed separately for each
            channel and eye state.

    Returns:
        A plotly ``Figure``, or ``None`` if no channel is selected.
    """
    if not channels:
        return None

    def filter_outliers(data):
        """Keep the values inside Tukey's fences (1.5 * IQR beyond the quartiles)."""
        q1 = data.quantile(0.25)
        q3 = data.quantile(0.75)
        iqr = q3 - q1
        return data[data.between(q1 - 1.5 * iqr, q3 + 1.5 * iqr)]

    n_channels = len(channels)
    show_histograms = eye_state_filter == "Both"
    show_open = eye_state_filter in ("Both", "Open")
    show_closed = eye_state_filter in ("Both", "Closed")

    if show_histograms:
        # Column 1 holds both states' box plots and column 2 both states'
        # histograms, so the titles name the chart, not a single eye state.
        subplot_titles = [
            title
            for ch in channels
            for title in (f'{ch} - Box Plot', f'{ch} - Histogram')
        ]
        specs = [[{'type': 'box'}, {'type': 'histogram'}] for _ in range(n_channels)]
    else:
        state_label = "Eyes Open" if eye_state_filter == "Open" else "Eyes Closed"
        subplot_titles = [f'{ch} - {state_label}' for ch in channels]
        specs = [[{'type': 'box'}] for _ in range(n_channels)]

    # make_subplots rejects vertical_spacing > 1 / (rows - 1); a fixed 0.08
    # fails once all 14 channels are selected, so shrink it for tall grids
    # (unchanged at 0.08 for up to 4 channels).
    vertical_spacing = 0.08
    if n_channels > 1:
        vertical_spacing = min(vertical_spacing, 0.3 / (n_channels - 1))

    fig = make_subplots(
        rows=n_channels, cols=2 if show_histograms else 1,
        subplot_titles=subplot_titles,
        specs=specs,
        vertical_spacing=vertical_spacing,
    )

    # Row masks are the same for every channel; build them once.
    open_mask = df['eyeDetection'] == 0
    closed_mask = df['eyeDetection'] == 1

    for row, channel in enumerate(channels, 1):
        df_open = df.loc[open_mask, channel]
        df_closed = df.loc[closed_mask, channel]
        if remove_outliers:
            df_open = filter_outliers(df_open)
            df_closed = filter_outliers(df_closed)

        # Legend entries only for the first row to avoid duplicates.
        if show_open:
            fig.add_trace(
                go.Box(y=df_open, name=f'{channel} Open', marker_color='blue',
                       showlegend=(row == 1)),
                row=row, col=1,
            )
        if show_closed:
            fig.add_trace(
                go.Box(y=df_closed, name=f'{channel} Closed', marker_color='red',
                       showlegend=(row == 1)),
                row=row, col=1,
            )

        if show_histograms:
            fig.add_trace(
                go.Histogram(x=df_open, name=f'{channel} Open', marker_color='blue',
                             opacity=0.7, showlegend=False, nbinsx=30),
                row=row, col=2,
            )
            fig.add_trace(
                go.Histogram(x=df_closed, name=f'{channel} Closed', marker_color='red',
                             opacity=0.7, showlegend=False, nbinsx=30),
                row=row, col=2,
            )
            # Fit the histogram x axis to both states' data with a 10% margin.
            all_data = pd.concat([df_open, df_closed])
            if not all_data.empty:
                data_min = all_data.min()
                data_max = all_data.max()
                margin = 0.1 * (data_max - data_min)
                if margin == 0:
                    margin = 1.0  # constant data: avoid a zero-width axis range
                fig.update_xaxes(
                    range=[data_min - margin, data_max + margin],
                    row=row, col=2,
                )

    fig.update_layout(
        height=350 * n_channels,
        title_text=f"Channel Distribution Comparison - {eye_state_filter} Eyes",
        showlegend=True,
        # The two semi-transparent histograms per row are meant to overlap.
        barmode='overlay',
    )
    # Box plots show amplitude; histograms (column 2) show sample counts.
    fig.update_yaxes(title_text="Amplitude (μV)", col=1)
    if show_histograms:
        fig.update_xaxes(title_text="Amplitude (μV)", row=n_channels, col=2)
        fig.update_yaxes(title_text="Count", col=2)
    return fig
def get_statistics(data=None, sampling_rate=128):
    """Summarize the dataset as a Markdown bullet list.

    Args:
        data: DataFrame with an ``eyeDetection`` column (0 = eyes open,
            1 = eyes closed). Defaults to the module-level ``df``.
        sampling_rate: Samples per second, used to convert the sample count
            into a duration. Defaults to 128 Hz, the rate used throughout
            this app.

    Returns:
        Markdown text with the sample count, duration, sampling rate and the
        open/closed split (percentages are 0.0 for an empty dataset).
    """
    if data is None:
        data = df
    total_samples = len(data)
    eye_state = data['eyeDetection']
    eyes_open = int((eye_state == 0).sum())
    eyes_closed = int((eye_state == 1).sum())

    def percent(count):
        # Avoid ZeroDivisionError when the dataset is empty.
        return count / total_samples * 100 if total_samples else 0.0

    stats = [
        "**Dataset Statistics**",
        f"- Total samples: {total_samples:,}",
        f"- Duration: {total_samples / sampling_rate:.2f} seconds",
        f"- Sampling rate: {sampling_rate} Hz",
        f"- Eyes Open samples: {eyes_open:,} ({percent(eyes_open):.1f}%)",
        f"- Eyes Closed samples: {eyes_closed:,} ({percent(eyes_closed):.1f}%)",
    ]
    return "\n".join(stats)
def get_statistics_table():
    """Build the per-channel summary table shown on the Statistics tab.

    Returns:
        pandas.DataFrame with one row per entry of ``eeg_channels`` and the
        columns Channel, Mean (All), Std (All), Mean (Open), Mean (Closed),
        Min and Max. Numeric cells are strings rounded to two decimals.
    """
    # Row subsets per eye state (eyeDetection: 0 = open, 1 = closed).
    eyes_open_rows = df[df['eyeDetection'] == 0]
    eyes_closed_rows = df[df['eyeDetection'] == 1]

    def two_decimals(value):
        return f"{value:.2f}"

    def summarize(name):
        column = df[name]
        return {
            'Channel': name,
            'Mean (All)': two_decimals(column.mean()),
            'Std (All)': two_decimals(column.std()),
            'Mean (Open)': two_decimals(eyes_open_rows[name].mean()),
            'Mean (Closed)': two_decimals(eyes_closed_rows[name].mean()),
            'Min': two_decimals(column.min()),
            'Max': two_decimals(column.max()),
        }

    return pd.DataFrame([summarize(name) for name in eeg_channels])
def plot_correlation_matrix():
    """Plot the correlation matrix of the EEG channels as an annotated heatmap.

    Uses ``DataFrame.corr`` over the ``eeg_channels`` columns of ``df``.

    Returns:
        A plotly ``Figure``: a red/blue diverging heatmap centred on zero, with
        each cell labelled with its coefficient to two decimals.
    """
    corr_matrix = df[eeg_channels].corr()
    fig = go.Figure(data=go.Heatmap(
        z=corr_matrix.values,
        x=eeg_channels,
        y=eeg_channels,
        colorscale='RdBu',
        zmid=0,
        # Label each cell straight from its z value (no separate text copy).
        texttemplate='%{z:.2f}',
        textfont={"size": 9},
        colorbar=dict(title="Correlation"),
    ))
    fig.update_layout(
        title={
            'text': "EEG Channels Correlation Matrix",
            'x': 0.5,
            'xanchor': 'center'
        },
        height=600,
        width=1215,
        xaxis={'side': 'bottom'},
        # Heatmap rows are drawn bottom-up by default; reverse the y axis so
        # the first channel is the top row and the diagonal runs from top-left
        # to bottom-right, as a matrix is conventionally laid out.
        yaxis={'autorange': 'reversed'},
    )
    return fig
# Create Gradio interface
# Latest selectable start time: the recording length in seconds at 128 Hz,
# rounded down to the slider's 0.5 s step (117.0 s for 14,980 samples).
max_start_time = float(np.floor(len(df) / 128 * 2) / 2)

demo = gr.Blocks(title="EEG Eye State Visualizer")
with demo:
    # The sample count is read from the loaded data rather than hard-coded.
    # The blank line keeps "Dataset Info" out of the last bullet item.
    gr.Markdown(f"""
# 🧠 EEG Eye State Visualizer
Explore and visualize the EEG Eye State Classification Dataset. This interactive tool allows you to:
- View EEG signals from 14 channels
- Compare patterns between open and closed eyes
- Analyze statistical distributions
- Examine channel correlations

**Dataset Info**: {len(df):,} samples | 128 Hz sampling rate | 14 EEG channels
""")

    with gr.Tab("Signal Viewer"):
        gr.Markdown("### Visualize EEG Signals")
        with gr.Row():
            with gr.Column(scale=1):
                start_time = gr.Slider(
                    minimum=0,
                    maximum=max_start_time,
                    value=0,
                    step=0.5,
                    label="Start Time (seconds)"
                )
                duration = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=0.5,
                    label="Duration (seconds)"
                )
                eye_state = gr.Radio(
                    choices=["Both", "Open", "Closed"],
                    value="Both",
                    label="Eye State Filter"
                )
                channels = gr.CheckboxGroup(
                    choices=eeg_channels,
                    value=['AF3', 'F7', 'O1', 'O2'],
                    label="Select Channels to Display"
                )
                plot_btn = gr.Button("Generate Plot", variant="primary")
            with gr.Column(scale=3):
                signal_plot = gr.Plot(label="EEG Signals")
        plot_btn.click(
            fn=plot_eeg_signals,
            inputs=[start_time, duration, eye_state, channels],
            outputs=signal_plot
        )

    with gr.Tab("Channel Analysis"):
        gr.Markdown("### Compare Multiple Channels")
        with gr.Row():
            with gr.Column(scale=1):
                channels_select = gr.CheckboxGroup(
                    choices=eeg_channels,
                    value=['AF3', 'O1'],
                    label="Select Channels to Compare"
                )
                eye_state_compare = gr.Radio(
                    choices=["Both", "Open", "Closed"],
                    value="Both",
                    label="Eye State Filter"
                )
                remove_outliers_check = gr.Checkbox(
                    label="Remove Outliers (IQR method)",
                    value=False
                )
                compare_btn = gr.Button("Analyze Channels", variant="primary")
            with gr.Column(scale=3):
                comparison_plot = gr.Plot(label="Channel Comparison")
        compare_btn.click(
            fn=plot_channel_comparison,
            inputs=[channels_select, eye_state_compare, remove_outliers_check],
            outputs=comparison_plot
        )

    # Statistics tab content is computed once, when the interface is built.
    with gr.Tab("Statistics"):
        gr.Markdown("### Dataset Statistics")
        stats_text = gr.Markdown(value=get_statistics())
        gr.Markdown("### Channel Statistics Table (μV)")
        stats_table = gr.Dataframe(
            value=get_statistics_table(),
            interactive=False,
            wrap=True
        )
        gr.Markdown("### Correlation Matrix")
        with gr.Row():
            corr_plot = gr.Plot(
                value=plot_correlation_matrix(),
                container=True,
                scale=1
            )

    with gr.Tab("About"):
        gr.Markdown("""
## About this Dataset
The EEG Eye State Classification Dataset contains continuous EEG measurements from 14 electrodes
collected during different eye states (open/closed).
### Key Features:
- **Total Instances**: 14,980 observations
- **Features**: 14 EEG channel measurements
- **Sampling Rate**: 128 Hz
- **Duration**: ~117 seconds
- **Device**: Emotiv EEG Neuroheadset
### Electrode Placement:
The 14 channels follow the international 10-20 system:
- Left hemisphere: AF3, F7, F3, FC5, T7, P7, O1
- Right hemisphere: O2, P8, T8, FC6, F4, F8, AF4
### Citation:
```
Rösler, O. (2013). EEG Eye State.
UCI Machine Learning Repository.
https://doi.org/10.24432/C57G7J
```
### Links:
- [Dataset on Hugging Face](https://huggingface.co/datasets/BrainSpectralAnalytics/eeg-eye-state-classification)
- [Original UCI Repository](https://archive.ics.uci.edu/dataset/264/eeg+eye+state)
- [Kaggle Example](https://www.kaggle.com/code/beta3logic/eye-state-eeg-classification-model-using-automl)
""")

# Launch application (only when run as a script, not when imported)
if __name__ == "__main__":
    demo.launch(ssr_mode=False)