Stop tracking source directory

source/eeg_motor_imagery.py
DELETED  +0 -142

@@ -1,142 +0,0 @@
"""
EEG Motor Imagery Classification with Shallow ConvNet
-----------------------------------------------------
This script trains and evaluates a ShallowFBCSPNet model
on motor imagery EEG data stored in .mat files.
"""

# === 1. Imports ===
from pathlib import Path

import scipy.io
import numpy as np
import pandas as pd
import mne
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
from braindecode.models import ShallowFBCSPNet

# === 2. Data Loading and Epoching ===
# Load the .mat EEG files, build MNE Raw objects, extract events, and epoch the data.
files = [
    "../data/raw_mat/HaLTSubjectA1602236StLRHandLegTongue.mat",
    "../data/raw_mat/HaLTSubjectA1603086StLRHandLegTongue.mat",
]

all_epochs = []
for f in files:
    mat = scipy.io.loadmat(f)
    content = mat['o'][0, 0]

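    # NOTE (assumption): positional indexing follows the HaLT .mat layout,
    # where content[2] holds the sampling rate, content[4] the per-sample
    # marker channel, content[5] the signal matrix, and content[6] the
    # channel names. If the struct fields keep their MATLAB names, access by
    # name (e.g. content['marker']) is equivalent and less brittle.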
    labels = content[4].flatten()
    signals = content[5]
    chan_names_raw = content[6]
    channels = [ch[0][0] for ch in chan_names_raw]
    fs = int(content[2][0, 0])

    # Drop the non-EEG "X5" channel if present; keep EEG channels only.
    df = pd.DataFrame(signals, columns=channels).drop(columns=["X5"], errors="ignore")
    eeg = df.values.T
    ch_names = df.columns.tolist()

    info = mne.create_info(ch_names=ch_names, sfreq=fs, ch_types="eeg")
    raw = mne.io.RawArray(eeg, info)

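    # NOTE (assumption): the recordings are presumably in microvolts, while
    # MNE's convention is volts. The scale only matters for MNE plotting and
    # amplitude-based rejection, not for the classifier below, so the original
    # units are kept as-is.
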
    # Create events: a trial begins where the marker channel rises from 0 to a
    # non-zero code (rising-edge detection), so compare each sample with its
    # predecessor.
    onsets = np.where((labels[1:] != 0) & (labels[:-1] == 0))[0] + 1
    event_codes = labels[onsets].astype(int)
    events = np.c_[onsets, np.zeros_like(onsets), event_codes]

    # Keep only the six motor-imagery event codes (1..6).
    mask = np.isin(events[:, 2], np.arange(1, 7))
    events = events[mask]

    event_id = {
        "left_hand": 1,
        "right_hand": 2,
        "neutral": 3,
        "left_leg": 4,
        "tongue": 5,
        "right_leg": 6,
    }

    # Epoch from cue onset to 1.5 s after it, without baseline correction.
    epochs = mne.Epochs(
        raw,
        events=events,
        event_id=event_id,
        tmin=0,
        tmax=1.5,
        baseline=None,
        preload=True,
    )

    all_epochs.append(epochs)

epochs_all = mne.concatenate_epochs(all_epochs)

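# Note: mne.concatenate_epochs requires compatible recordings (same channels
# and sampling rate), which holds here as long as every file shares one montage.
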
# === 3. Minimal Preprocessing + Train/Validation Split ===
# Convert epochs to a (N, C, T) numpy array and split into train/val sets.
X = epochs_all.get_data().astype("float32")
y = (epochs_all.events[:, -1] - 1).astype("int64")  # map event codes 1..6 to classes 0..5

X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

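# Optional (assumption, not part of the original pipeline): standardizing each
# channel with statistics from the training split alone often helps ConvNets
# on raw EEG. A minimal sketch:
#
#   mu = X_train.mean(axis=(0, 2), keepdims=True)
#   sd = X_train.std(axis=(0, 2), keepdims=True) + 1e-8
#   X_train = (X_train - mu) / sd
#   X_val = (X_val - mu) / sd
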
# === 4. Torch DataLoaders ===
train_ds = TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
val_ds = TensorDataset(torch.from_numpy(X_val), torch.from_numpy(y_val))

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=32)

# === 5. Model – Shallow ConvNet ===
# Reference: Schirrmeister et al. (2017), "Deep learning with convolutional
# neural networks for EEG decoding and visualization", Human Brain Mapping.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ShallowFBCSPNet(
    n_chans=X.shape[1],
    n_outputs=len(np.unique(y)),
    n_times=X.shape[2],
    final_conv_length="auto",
).to(device)

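# Braindecode models take float32 input of shape (batch, n_chans, n_times),
# matching the (N, C, T) arrays built above.
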
# Warm-start from an existing checkpoint if one is present; otherwise the
# model trains from its random initialization.
ckpt_path = Path("model.pth")
if ckpt_path.exists():
    state_dict = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state_dict)

# === 6. Training ===
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(1, 21):
    # Training pass: update weights and track accuracy on the fly.
    model.train()
    correct, total = 0, 0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        out = model(xb)
        loss = criterion(out, yb)
        loss.backward()
        optimizer.step()

        pred = out.argmax(dim=1)
        correct += (pred == yb).sum().item()
        total += yb.size(0)
    train_acc = correct / total

    # Validation pass: no gradients, model in eval mode.
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for xb, yb in val_loader:
            xb, yb = xb.to(device), yb.to(device)
            out = model(xb)
            pred = out.argmax(dim=1)
            correct += (pred == yb).sum().item()
            total += yb.size(0)
    val_acc = correct / total

    print(f"Epoch {epoch:02d} | Train acc: {train_acc:.3f} | Val acc: {val_acc:.3f}")
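
# Optional (assumption): persist the trained weights so the warm-start branch
# above can reuse them on the next run. A minimal sketch:
#
#   torch.save(model.state_dict(), "model.pth")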