"""
EEG Motor Imagery Classifier Module
-----------------------------------
Handles model loading, inference, and real-time prediction for motor imagery
classification. Based on the ShallowFBCSPNet architecture from the original
eeg_motor_imagery.py script.
"""

import logging
import os
import pickle
from typing import Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from braindecode.modules.layers import Ensure4d  # necessary for loading

from data_processor import EEGDataProcessor
from config import DEMO_DATA_PATHS

logger = logging.getLogger(__name__)


class MotorImageryClassifier:
    """
    Motor imagery classifier using ShallowFBCSPNet model.

    Use ``load_model`` to try a checkpoint from ``model_path``. If no usable
    checkpoint was loaded, the first ``predict`` call trains a model on the
    demo data (LOSO split) and then predicts with it.
    """

    # Hyper-parameters of the quick LOSO fallback training (demo quality).
    LOSO_EPOCHS = 50
    LOSO_LEARNING_RATE = 0.001

    def __init__(self, model_path: str = "model.pth"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model: Optional[nn.Module] = None
        self.model_path = model_path
        # Output index -> human-readable label.
        self.class_names = {
            0: "left_hand",
            1: "right_hand",
            2: "neutral",
            3: "left_leg",
            4: "tongue",
            5: "right_leg",
        }
        # True once self.model holds trained weights (from file or LOSO fallback).
        self.is_loaded = False

    def load_model(
        self,
        n_chans: int,
        n_times: int,
        n_outputs: int = 6,
        *,
        allow_full_pickle: bool = True,
    ) -> None:
        """Build a ShallowFBCSPNet and try to load weights from ``self.model_path``.

        This method never raises. On any failure (missing file, incompatible
        checkpoint, construction error) it logs the reason and leaves
        ``is_loaded`` False, so a later ``predict`` falls back to LOSO training.

        Args:
            n_chans: Number of EEG channels the network expects.
            n_times: Number of time samples per trial.
            n_outputs: Number of output classes.
            allow_full_pickle: If False, only checkpoints readable by the safe
                ``weights_only=True`` loader (plain state_dicts) are accepted.
        """
        self.is_loaded = False
        try:
            self.model = ShallowFBCSPNet(
                n_chans=n_chans,
                n_outputs=n_outputs,
                n_times=n_times,
                final_conv_length="auto",
            ).to(self.device)
        except Exception:
            logger.warning(
                "Could not build ShallowFBCSPNet(n_chans=%s, n_times=%s, n_outputs=%s)",
                n_chans, n_times, n_outputs, exc_info=True,
            )
            return

        if not os.path.exists(self.model_path):
            logger.info("No model file at %s; predictions will use LOSO training", self.model_path)
            return

        try:
            checkpoint = self._load_checkpoint(allow_full_pickle)
            self._apply_checkpoint(checkpoint)
            self.model.eval()
        except Exception:
            logger.warning("Could not load model from %s", self.model_path, exc_info=True)
            return
        self.is_loaded = True

    def _load_checkpoint(self, allow_full_pickle: bool) -> object:
        """Read ``self.model_path`` with ``torch.load``, preferring the safe loader.

        A plain state_dict loads under ``weights_only=True``, which does not run
        arbitrary pickle code. A pickled full model is rejected by that loader
        and is then re-read with ``weights_only=False`` if ``allow_full_pickle``.
        """
        try:
            with torch.serialization.safe_globals([Ensure4d, ShallowFBCSPNet]):
                return torch.load(self.model_path, map_location=self.device, weights_only=True)
        except pickle.UnpicklingError:
            if not allow_full_pickle:
                raise
            # SECURITY: weights_only=False unpickles arbitrary objects and can
            # execute code embedded in the file. Only use trusted checkpoints.
            logger.warning(
                "%s is not a plain state_dict; loading it with full unpickling",
                self.model_path,
            )
            return torch.load(self.model_path, map_location=self.device, weights_only=False)

    def _apply_checkpoint(self, checkpoint: object) -> None:
        """Install ``checkpoint`` (full model or state_dict) into ``self.model``.

        Raises:
            ValueError: if the checkpoint is neither a ShallowFBCSPNet nor a
                string-keyed dict.
        """
        if isinstance(checkpoint, ShallowFBCSPNet):
            self.model = checkpoint.to(self.device)
            return
        if isinstance(checkpoint, dict):
            # Unwrap common training-checkpoint layouts such as
            # {"model_state_dict": ..., "optimizer_state_dict": ...}.
            for wrapper_key in ("model_state_dict", "state_dict"):
                inner = checkpoint.get(wrapper_key)
                if isinstance(inner, dict):
                    checkpoint = inner
                    break
            if all(isinstance(k, str) for k in checkpoint):
                self.model.load_state_dict(checkpoint)
                return
        raise ValueError("Unknown checkpoint format")

    def get_model_status(self) -> str:
        """Get current model status for user interface."""
        if self.is_loaded:
            return "✅ Pre-trained model loaded and ready"
        return "🔄 Using LOSO training (training new model from EEG data)"

    def predict(self, eeg_data: np.ndarray) -> Tuple[int, float, Dict[str, float]]:
        """
        Predict motor imagery class from EEG data.

        If no trained model is available yet, a model is first trained via
        ``_fallback_loso_classification``.

        Args:
            eeg_data: EEG data array of shape (n_channels, n_times), or a batch
                of shape (batch, n_channels, n_times). Only the first sample of
                a batch is reported.

        Returns:
            predicted_class: Predicted class index
            confidence: Softmax probability of the predicted class
            probabilities: Dictionary of class probabilities, keyed by class name
                (``"class_<i>"`` for indices missing from ``class_names``)
        """
        if not self.is_loaded:
            return self._fallback_loso_classification(eeg_data)

        # Ensure input is the right shape: (batch, channels, time)
        if eeg_data.ndim == 2:
            eeg_data = eeg_data[np.newaxis, ...]

        x = torch.from_numpy(eeg_data.astype(np.float32)).to(self.device)
        with torch.no_grad():
            output = self.model(x)

        # Probabilities of the first sample only, so class and confidence agree.
        probs = torch.softmax(output, dim=1)[0].cpu().numpy()
        predicted_class = int(np.argmax(probs))
        confidence = float(probs[predicted_class])

        # Follow the model's actual output width rather than assuming 6 classes.
        prob_dict = {
            self.class_names.get(i, f"class_{i}"): float(p)
            for i, p in enumerate(probs)
        }
        return predicted_class, confidence, prob_dict

    def _fallback_loso_classification(self, eeg_data: np.ndarray) -> Tuple[int, float, Dict[str, float]]:
        """
        Fallback classification using LOSO (Leave-One-Session-Out) training.

        Trains a model on the available demo data, marks it loaded, and then
        predicts ``eeg_data`` with it.

        Raises:
            RuntimeError: if training or the subsequent prediction fails; the
                original exception is chained as ``__cause__``.
        """
        try:
            self._train_loso_model()
            # Now make prediction with the trained model
            return self.predict(eeg_data)
        except Exception as e:
            raise RuntimeError(
                "Failed to initialize classifier. Neither pre-trained model nor "
                f"LOSO training succeeded: {e}"
            ) from e

    def _train_loso_model(self) -> None:
        """Train ``self.model`` on the LOSO training split and set ``is_loaded``.

        Either sets ``is_loaded`` True or raises. This guarantees that
        ``predict`` does not recurse back into the fallback.
        """
        processor = EEGDataProcessor()

        # Check if demo data files exist
        available_files = [f for f in DEMO_DATA_PATHS if os.path.exists(f)]
        if len(available_files) < 2:
            raise ValueError(
                f"Not enough data files for LOSO training. Need at least 2 files, found {len(available_files)}. "
                f"Available files: {available_files}"
            )

        # Perform LOSO split (using first session as test); only the training
        # part is used here.
        X_train, y_train, _X_test, _y_test, _session_info = processor.prepare_loso_split(
            available_files, test_session_idx=0
        )

        self.model = ShallowFBCSPNet(
            n_chans=X_train.shape[1],
            n_outputs=6,
            n_times=X_train.shape[2],
            final_conv_length="auto",
        ).to(self.device)

        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.LOSO_LEARNING_RATE)
        criterion = nn.CrossEntropyLoss()

        X_train_tensor = torch.from_numpy(X_train).float().to(self.device)
        y_train_tensor = torch.from_numpy(y_train).long().to(self.device)

        # Quick full-batch training (demo quality; the whole split is one batch).
        self.model.train()
        for _ in range(self.LOSO_EPOCHS):
            optimizer.zero_grad()
            loss = criterion(self.model(X_train_tensor), y_train_tensor)
            loss.backward()
            optimizer.step()

        self.model.eval()
        self.is_loaded = True