File size: 7,373 Bytes
fa96cf5
 
 
 
 
 
 
 
 
 
b906dc7
 
66947ed
fa96cf5
 
 
 
 
 
 
 
 
b906dc7
fa96cf5
 
 
 
 
 
 
 
 
 
 
 
 
 
66947ed
 
 
fa96cf5
 
 
 
 
 
 
 
 
 
b906dc7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa96cf5
 
c6beb41
fa96cf5
 
 
 
c6beb41
fa96cf5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
"""

EEG Motor Imagery Classifier Module

----------------------------------

Handles model loading, inference, and real-time prediction for motor imagery classification.

Based on the ShallowFBCSPNet architecture from the original eeg_motor_imagery.py script.

"""

import contextlib
import logging
import os
from typing import Dict, Tuple

import numpy as np
import torch
import torch.nn as nn
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from braindecode.modules.layers import Ensure4d  # necessary for loading

from data_processor import EEGDataProcessor
from config import DEMO_DATA_PATHS

class MotorImageryClassifier:
    """Motor imagery classifier built on braindecode's ShallowFBCSPNet.

    Attributes:
        device: ``cuda`` when available, otherwise ``cpu``.
        model: Network used for inference; ``None`` until :meth:`load_model`
            or the LOSO fallback creates one.
        model_path: Checkpoint path read by :meth:`load_model`.
        class_names: Mapping from network output index to class label.
        is_loaded: True once ``model`` holds usable weights (either restored
            from ``model_path`` or trained by the LOSO fallback).
    """

    # Failures while loading are reported here instead of being swallowed.
    _logger = logging.getLogger(__name__)

    def __init__(self, model_path: str = "model.pth"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.model_path = model_path
        # Output index -> label. The LOSO fallback sizes its network from this map.
        self.class_names = {
            0: "left_hand",
            1: "right_hand",
            2: "neutral",
            3: "left_leg",
            4: "tongue",
            5: "right_leg",
        }
        self.is_loaded = False

    def load_model(self, n_chans: int, n_times: int, n_outputs: int = 6) -> None:
        """Build a ShallowFBCSPNet and try to restore pre-trained weights.

        On success ``is_loaded`` is True and the model is in eval mode. On any
        failure (network construction error, missing file, unreadable or
        incompatible checkpoint) the error is logged and ``is_loaded`` is left
        False, so :meth:`predict` falls back to LOSO training.

        Args:
            n_chans: Number of EEG channels the network expects.
            n_times: Number of time samples per trial.
            n_outputs: Number of output classes.
        """
        self.is_loaded = False

        try:
            self.model = ShallowFBCSPNet(
                n_chans=n_chans,
                n_outputs=n_outputs,
                n_times=n_times,
                final_conv_length="auto",
            ).to(self.device)
        except Exception:
            self._logger.exception("Could not construct ShallowFBCSPNet")
            return

        if not os.path.exists(self.model_path):
            self._logger.info(
                "No model file at %s; LOSO fallback will be used", self.model_path
            )
            return

        try:
            # SECURITY: weights_only=False unpickles arbitrary Python objects
            # (required to restore a fully pickled model). Only load model
            # files from a trusted source.
            # The safe_globals allowlist only matters for weights_only=True
            # loads; it is applied when this PyTorch version provides it so the
            # load does not fail on releases that lack it.
            safe_globals = getattr(torch.serialization, "safe_globals", None)
            allowlist = (
                safe_globals([Ensure4d, ShallowFBCSPNet])
                if safe_globals is not None
                else contextlib.nullcontext()
            )
            with allowlist:
                checkpoint = torch.load(
                    self.model_path,
                    map_location=self.device,
                    weights_only=False,
                )

            if isinstance(checkpoint, dict) and all(isinstance(k, str) for k in checkpoint):
                # A plain state_dict (parameter name -> tensor).
                self.model.load_state_dict(checkpoint)
            elif isinstance(checkpoint, ShallowFBCSPNet):
                # A fully pickled model object.
                self.model = checkpoint.to(self.device)
            else:
                raise ValueError(
                    f"Unknown checkpoint format: {type(checkpoint).__name__}"
                )

            self.model.eval()
            self.is_loaded = True
        except Exception:
            self._logger.exception(
                "Failed to load model from %s; LOSO fallback will be used",
                self.model_path,
            )
            self.is_loaded = False

    def get_model_status(self) -> str:
        """Return a human-readable model status string for the UI.

        The LOSO fallback also sets ``is_loaded``, so after it has trained a
        model the first ("pre-trained") message is reported.
        """
        if self.is_loaded:
            return "✅ Pre-trained model loaded and ready"
        return "🔄 Using LOSO training (training new model from EEG data)"

    def predict(self, eeg_data: np.ndarray) -> Tuple[int, float, Dict[str, float]]:
        """Predict the motor imagery class of one EEG trial.

        If no model is loaded yet, one is first trained via
        :meth:`_fallback_loso_classification`.

        Args:
            eeg_data: Array of shape ``(n_channels, n_times)`` or
                ``(batch, n_channels, n_times)``. With a batch, only the first
                trial is reported.

        Returns:
            ``(predicted_class, confidence, probabilities)`` as plain Python
            types: the argmax class index, that class's softmax probability,
            and a ``{class_name: probability}`` dict.
        """
        if not self.is_loaded:
            return self._fallback_loso_classification(eeg_data)

        # The network expects (batch, channels, time).
        if eeg_data.ndim == 2:
            eeg_data = eeg_data[np.newaxis, ...]
        x = torch.from_numpy(eeg_data.astype(np.float32)).to(self.device)

        with torch.no_grad():
            logits = self.model(x)
            # Softmax over the class axis; keep only the first trial and move
            # it to the host once rather than once per class.
            probs = torch.softmax(logits, dim=1)[0].cpu()

        predicted_class = int(torch.argmax(probs).item())
        values = probs.tolist()
        # Confidence is taken from the same trial as predicted_class.
        confidence = values[predicted_class]
        prob_dict = {
            name: values[idx]
            for idx, name in self.class_names.items()
            if idx < len(values)
        }
        return predicted_class, confidence, prob_dict

    def _fallback_loso_classification(self, eeg_data: np.ndarray) -> Tuple[int, float, Dict[str, float]]:
        """Train a model from the demo EEG sessions, then classify ``eeg_data``.

        Used by :meth:`predict` when no pre-trained model is loaded.

        Raises:
            RuntimeError: If training fails (e.g. fewer than two demo data
                files exist), or if prediction with the trained model fails.
        """
        try:
            self._train_loso_model()
        except Exception as e:
            raise RuntimeError(
                "Failed to initialize classifier. Neither pre-trained model nor "
                f"LOSO training succeeded: {e}"
            ) from e

        try:
            return self.predict(eeg_data)
        except Exception as e:
            raise RuntimeError(
                f"LOSO training succeeded but prediction failed: {e}"
            ) from e

    def _train_loso_model(self, n_epochs: int = 50, lr: float = 0.001) -> None:
        """Fit a fresh ShallowFBCSPNet on the training side of a LOSO split.

        The split holds out ``test_session_idx=0``. The held-out data is not
        evaluated here. On success ``model`` is replaced and ``is_loaded`` is
        set; if anything raises, both are left untouched.

        Raises:
            ValueError: If fewer than two demo data files exist.
        """
        processor = EEGDataProcessor()

        available_files = [f for f in DEMO_DATA_PATHS if os.path.exists(f)]
        if len(available_files) < 2:
            raise ValueError(
                f"Not enough data files for LOSO training. Need at least 2 files, found {len(available_files)}. "
                f"Available files: {available_files}"
            )

        X_train, y_train, _X_test, _y_test, _session_info = processor.prepare_loso_split(
            available_files, test_session_idx=0
        )

        model = ShallowFBCSPNet(
            n_chans=X_train.shape[1],
            n_outputs=len(self.class_names),
            n_times=X_train.shape[2],
            final_conv_length="auto",
        ).to(self.device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        inputs = torch.from_numpy(X_train).float().to(self.device)
        targets = torch.from_numpy(y_train).long().to(self.device)

        # Quick full-batch training for a demo model. enable_grad() keeps this
        # working even when the caller invoked predict() under torch.no_grad().
        model.train()
        with torch.enable_grad():
            for _ in range(n_epochs):
                optimizer.zero_grad()
                loss = criterion(model(inputs), targets)
                loss.backward()
                optimizer.step()

        model.eval()
        self.model = model
        self.is_loaded = True