sofieff commited on
Commit
f8bae12
·
1 Parent(s): 66947ed

Stop tracking source directory

Browse files
Files changed (1) hide show
  1. source/eeg_motor_imagery.py +0 -142
source/eeg_motor_imagery.py DELETED
@@ -1,142 +0,0 @@
1
- """
2
- EEG Motor Imagery Classification with Shallow ConvNet
3
- -----------------------------------------------------
4
- This script trains and evaluates a ShallowFBCSPNet model
5
- on motor imagery EEG data stored in .mat files.
6
- """
7
-
8
- # === 1. Imports ===
9
- import scipy.io
10
- import numpy as np
11
- import mne
12
- import torch
13
- from torch.utils.data import TensorDataset, DataLoader
14
- from sklearn.model_selection import train_test_split
15
- from braindecode.models import ShallowFBCSPNet
16
- import torch.nn as nn
17
- import pandas as pd
18
-
19
- # === 2. Data Loading and Epoching ===
20
- # Load .mat EEG files, create Raw objects, extract events, and epoch the data
21
- files = [
22
- "../data/raw_mat/HaLTSubjectA1602236StLRHandLegTongue.mat",
23
- "../data/raw_mat/HaLTSubjectA1603086StLRHandLegTongue.mat",
24
-
25
- ]
26
-
27
- all_epochs = []
28
- for f in files:
29
- mat = scipy.io.loadmat(f)
30
- content = mat['o'][0, 0]
31
-
32
- labels = content[4].flatten()
33
- signals = content[5]
34
- chan_names_raw = content[6]
35
- channels = [ch[0][0] for ch in chan_names_raw]
36
- fs = int(content[2][0, 0])
37
-
38
- df = pd.DataFrame(signals, columns=channels).drop(columns=["X5"], errors="ignore")
39
- eeg = df.values.T
40
- ch_names = df.columns.tolist()
41
-
42
- info = mne.create_info(ch_names=ch_names, sfreq=fs, ch_types="eeg")
43
- raw = mne.io.RawArray(eeg, info)
44
-
45
- # Create events
46
- onsets = np.where((labels[1:] != 0) & (labels[:-1] == 0))[0] + 1
47
- event_codes = labels[onsets].astype(int)
48
- events = np.c_[onsets, np.zeros_like(onsets), event_codes]
49
-
50
- # Keep only relevant events
51
- mask = np.isin(events[:, 2], np.arange(1, 7))
52
- events = events[mask]
53
-
54
- event_id = {
55
- "left_hand": 1,
56
- "right_hand": 2,
57
- "neutral": 3,
58
- "left_leg": 4,
59
- "tongue": 5,
60
- "right_leg": 6,
61
- }
62
-
63
- # Epoching
64
- epochs = mne.Epochs(
65
- raw,
66
- events=events,
67
- event_id=event_id,
68
- tmin=0,
69
- tmax=1.5,
70
- baseline=None,
71
- preload=True,
72
- )
73
-
74
- all_epochs.append(epochs)
75
-
76
- epochs_all = mne.concatenate_epochs(all_epochs)
77
-
78
- # === 3. Minimal Preprocessing + Train/Validation Split ===
79
- # Convert epochs to numpy arrays (N, C, T) and split into train/val sets
80
- X = epochs_all.get_data().astype("float32")
81
- y = (epochs_all.events[:, -1] - 1).astype("int64") # classes 0..5
82
-
83
- X_train, X_val, y_train, y_val = train_test_split(
84
- X, y, test_size=0.2, random_state=42, stratify=y
85
- )
86
-
87
- # === 4. Torch DataLoaders ===
88
- train_ds = TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
89
- val_ds = TensorDataset(torch.from_numpy(X_val), torch.from_numpy(y_val))
90
-
91
- train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
92
- val_loader = DataLoader(val_ds, batch_size=32)
93
-
94
- # === 5. Model – Shallow ConvNet ===
95
- # Reference: Schirrmeister et al. (2017)
96
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
97
-
98
- model = ShallowFBCSPNet(
99
- n_chans=X.shape[1],
100
- n_outputs=len(np.unique(y)),
101
- n_times=X.shape[2],
102
- final_conv_length="auto"
103
- ).to(device)
104
-
105
- # Load pretrained weights
106
- state_dict = torch.load("model.pth", map_location=device)
107
- model.load_state_dict(state_dict)
108
-
109
- # === 6. Training ===
110
- criterion = nn.CrossEntropyLoss()
111
- optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
112
-
113
- for epoch in range(1, 21):
114
- # Training
115
- model.train()
116
- correct, total = 0, 0
117
- for xb, yb in train_loader:
118
- xb, yb = xb.to(device), yb.to(device)
119
- optimizer.zero_grad()
120
- out = model(xb)
121
- loss = criterion(out, yb)
122
- loss.backward()
123
- optimizer.step()
124
-
125
- pred = out.argmax(dim=1)
126
- correct += (pred == yb).sum().item()
127
- total += yb.size(0)
128
- train_acc = correct / total
129
-
130
- # Validation
131
- model.eval()
132
- correct, total = 0, 0
133
- with torch.no_grad():
134
- for xb, yb in val_loader:
135
- xb, yb = xb.to(device), yb.to(device)
136
- out = model(xb)
137
- pred = out.argmax(dim=1)
138
- correct += (pred == yb).sum().item()
139
- total += yb.size(0)
140
- val_acc = correct / total
141
-
142
- print(f"Epoch {epoch:02d} | Train acc: {train_acc:.3f} | Val acc: {val_acc:.3f}")