sofieff committed on
Commit
fa96cf5
·
0 Parent(s):

Initial commit: EEG Motor Imagery Music Composer


- Gradio-based web application for EEG motor imagery music composition
- Real-time EEG signal processing and classification
- Interactive music building with 5 movement types (left/right hand, left/right leg, tongue)
- DJ effects phase with audio processing
- File saving currently disabled for testing
- Includes pre-trained shallow neural network model
- Audio state management to prevent Gradio component restarts
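
The "audio state management" item above refers to a caching pattern in `app.py`: the app remembers the last file path sent to each `gr.Audio` output and re-sends the same value unless a new layer was added, because Gradio restarts playback whenever a component receives a new value. A minimal sketch of that idea (helper name and paths are illustrative, not part of the committed code):

```python
# Sketch of the audio-state caching idea used in app.py (illustrative names/paths).
# Gradio restarts an Audio component whenever its value changes, so we only hand
# it a new path when the underlying sound layer actually changed.
from typing import Dict, Optional

def resolve_audio_output(slot: str,
                         new_path: Optional[str],
                         last_audio_state: Dict[str, Optional[str]]) -> Optional[str]:
    """Return the value to pass to the gr.Audio output for `slot`."""
    if new_path is None:
        return last_audio_state.get(slot)       # nothing new: keep the cached path
    if last_audio_state.get(slot) != new_path:  # a new layer was added
        last_audio_state[slot] = new_path       # update cache; the player restarts once
    return last_audio_state[slot]

state = {"left_hand_audio": None}
print(resolve_audio_output("left_hand_audio", "sounds/bass.wav", state))  # new -> plays
print(resolve_audio_output("left_hand_audio", "sounds/bass.wav", state))  # same -> no restart
```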

Files changed (13)
  1. .gitattributes +3 -0
  2. .gitignore +59 -0
  3. README.md +78 -0
  4. app.py +1168 -0
  5. classifier.py +179 -0
  6. config.py +93 -0
  7. data_processor.py +257 -0
  8. demo.py +96 -0
  9. enhanced_utils.py +236 -0
  10. requirements.txt +14 -0
  11. sound_library.py +783 -0
  12. source/eeg_motor_imagery.py +142 -0
  13. utils.py +226 -0
.gitattributes ADDED
@@ -0,0 +1,3 @@
1
+ data/*.mat filter=lfs diff=lfs merge=lfs -text
2
+ sounds/*.wav filter=lfs diff=lfs merge=lfs -text
3
+ model.pth filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,59 @@
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ env/
8
+ build/
9
+ develop-eggs/
10
+ dist/
11
+ downloads/
12
+ eggs/
13
+ .eggs/
14
+ lib/
15
+ lib64/
16
+ parts/
17
+ sdist/
18
+ var/
19
+ wheels/
20
+ *.egg-info/
21
+ .installed.cfg
22
+ *.egg
23
+
24
+ # Virtual Environment
25
+ .venv/
26
+ venv/
27
+ ENV/
28
+ env/
29
+
30
+ # IDE
31
+ .vscode/
32
+ .idea/
33
+ *.swp
34
+ *.swo
35
+ *~
36
+
37
+ # OS
38
+ .DS_Store
39
+ .DS_Store?
40
+ ._*
41
+ .Spotlight-V100
42
+ .Trashes
43
+ ehthumbs.db
44
+ Thumbs.db
45
+
46
+ # Project specific
47
+ *.wav
48
+ *.mp3
49
+ *.pth
50
+ mixed_composition_*.wav
51
+ *_fx_*.wav
52
+ app_old.py
53
+ app_new.py
54
+ test_mixed.wav
55
+ silent.wav
56
+
57
+ # Data files
58
+ data/
59
+ *.mat
README.md ADDED
@@ -0,0 +1,78 @@
1
+ # 🧠 EEG Motor Imagery Music Composer
2
+
3
+ A machine learning application that turns brain signals into music using motor imagery classification. A trained ShallowFBCSPNet model classifies motor imagery tasks from EEG data, and each successful classification adds a layer to a musical composition.
4
+
5
+ ## 🎯 Features
6
+
7
+ - **Real-time EEG Classification**: Uses ShallowFBCSPNet architecture for motor imagery classification
8
+ - **Music Composition**: Automatically creates layered music compositions from classification results
9
+ - **Interactive Gradio Interface**: User-friendly web interface for real-time interaction
10
+ - **Six Motor Imagery Classes**: Left/right hand, left/right leg, tongue, and neutral states
11
+ - **Sound Mapping**: Each motor imagery class is mapped to different musical instruments
12
+ - **Composition Management**: Save, clear, and manage your musical creations
13
+
14
+ ## πŸ—οΈ Architecture
15
+
16
+ ### Project Structure
17
+ ```
18
+ ├── app.py              # Main Gradio application
19
+ ├── classifier.py       # Motor imagery classifier with ShallowFBCSPNet
20
+ ├── data_processor.py   # EEG data loading and preprocessing
21
+ ├── sound_library.py    # Sound management and composition system
22
+ ├── config.py           # Configuration settings
23
+ ├── requirements.txt    # Python dependencies
24
+ ├── SoundHelix-Song-1/  # Audio files for different instruments
25
+ │   ├── bass.wav
26
+ │   ├── drums.wav
27
+ │   ├── other.wav
28
+ │   └── vocals.wav
29
+ └── src/                # Additional source files
30
+     ├── model.py
31
+     ├── preprocessing.py
32
+     ├── train.py
33
+     └── visualize.py
34
+ ```
35
+
36
+ ### System Components
37
+
38
+ 1. **EEGDataProcessor** (`data_processor.py`)
39
+ - Loads and processes .mat EEG files
40
+ - Handles epoching and preprocessing
41
+ - Simulates real-time data for demo purposes
42
+
43
+ 2. **MotorImageryClassifier** (`classifier.py`)
44
+ - Implements ShallowFBCSPNet model
45
+ - Performs real-time classification
46
+ - Provides confidence scores and probability distributions
47
+
48
+ 3. **SoundManager** (`sound_library.py`)
49
+ - Maps classifications to audio files
50
+ - Manages composition layers
51
+ - Handles audio file loading and playback
52
+
53
+ 4. **Gradio Interface** (`app.py`)
54
+ - Web-based user interface
55
+ - Real-time visualization
56
+ - Composition management tools
57
+
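
The components above come together in `app.py` roughly as in the following sketch (simplified: error handling, the Gradio outputs, and the DJ-effects phase are omitted; the `.mat` path is only an example):

```python
# Simplified single-trial sketch of the pipeline wired up in app.py.
from data_processor import EEGDataProcessor
from classifier import MotorImageryClassifier
from sound_library import SoundManager
from config import CONFIDENCE_THRESHOLD

processor = EEGDataProcessor()
classifier = MotorImageryClassifier()
sound_manager = SoundManager()

# Load and epoch the EEG recordings (path below is illustrative)
data, labels = processor.process_files(["data/subject1.mat"])
classifier.load_model(n_chans=data.shape[1], n_times=data.shape[2])

# Simulate one "real-time" epoch, classify it, and map it to a sound layer
epoch, true_label = processor.simulate_real_time_data(data, labels, mode="class_balanced")
class_idx, confidence, probabilities = classifier.predict(epoch)
predicted_name = classifier.class_names[class_idx]

result = sound_manager.process_classification(predicted_name, confidence, CONFIDENCE_THRESHOLD)
if result["sound_added"]:
    print(f"Added a {predicted_name} layer (confidence {confidence:.2f})")
```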
58
+ ## 🚀 Quick Start
59
+
60
+ ### Requirements
61
+
62
+ Python 3.9–3.11 recommended. Install dependencies:
63
+
64
+ ```bash
65
+ python -m pip install -r requirements.txt
66
+ ```
67
+
68
+ ### How to run (Gradio)
69
+
70
+ Local launch:
71
+
72
+ ```bash
73
+ python app.py
74
+ ```
75
+
76
+ This starts a Gradio server. Gradio's default address is `http://127.0.0.1:7860`, but `app.py` in this commit calls `demo.launch(server_name="0.0.0.0", server_port=7867)`, so the app is served on port 7867 on all interfaces.
77
+
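
If you need a different port or a temporary public link, the usual `launch()` options apply. A minimal sketch (the interface below is only a placeholder and the values are examples):

```python
import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("placeholder UI")  # app.py builds the real interface in create_interface()

demo.launch(
    server_name="0.0.0.0",  # listen on all interfaces
    server_port=7860,       # any free port; app.py itself uses 7867
    share=True,             # optional temporary public gradio.live link
)
```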
78
+ #
app.py ADDED
@@ -0,0 +1,1168 @@
1
+ """
2
+ EEG Motor Imagery Music Composer - Redesigned Interface
3
+ =======================================================
4
+ Brain-Computer Interface that creates music compositions based on imagined movements.
5
+ """
6
+
7
+ import gradio as gr
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ import time
11
+ import threading
12
+ from typing import Dict, Tuple, Any, List
13
+
14
+ # Import our custom modules
15
+ from sound_library import SoundManager
16
+ from data_processor import EEGDataProcessor
17
+ from classifier import MotorImageryClassifier
18
+ from config import DEMO_DATA_PATHS, CLASS_NAMES, CONFIDENCE_THRESHOLD
19
+
20
+ def validate_data_setup() -> str:
21
+ """Validate that required data files are available."""
22
+ missing_files = []
23
+
24
+ for subject_id, path in DEMO_DATA_PATHS.items():
25
+ try:
26
+ import os
27
+ if not os.path.exists(path):
28
+ missing_files.append(f"Subject {subject_id}: {path}")
29
+ except Exception as e:
30
+ missing_files.append(f"Subject {subject_id}: Error checking {path}")
31
+
32
+ if missing_files:
33
+ return f"❌ Missing data files:\n" + "\n".join(missing_files)
34
+ return "βœ… All data files found"
35
+
36
+ # Global app state
37
+ app_state = {
38
+ 'is_running': False,
39
+ 'demo_data': None,
40
+ 'demo_labels': None,
41
+ 'classification_history': [],
42
+ 'composition_active': False,
43
+ 'auto_mode': False,
44
+ 'last_audio_state': {
45
+ 'left_hand_audio': None,
46
+ 'right_hand_audio': None,
47
+ 'left_leg_audio': None,
48
+ 'right_leg_audio': None,
49
+ 'tongue_audio': None
50
+ }
51
+ }
52
+
53
+ # Initialize components
54
+ print("🧠 EEG Motor Imagery Music Composer")
55
+ print("=" * 50)
56
+ print("Starting Gradio application...")
57
+
58
+ try:
59
+ sound_manager = SoundManager()
60
+ data_processor = EEGDataProcessor()
61
+ classifier = MotorImageryClassifier()
62
+
63
+ # Load demo data
64
+ import os
65
+ existing_files = [p for p in DEMO_DATA_PATHS.values() if os.path.exists(p)]
66
+ if existing_files:
67
+ app_state['demo_data'], app_state['demo_labels'] = data_processor.process_files(existing_files)
68
+ else:
69
+ app_state['demo_data'], app_state['demo_labels'] = None, None
70
+
71
+ if app_state['demo_data'] is not None:
72
+ # Initialize classifier with proper dimensions
73
+ classifier.load_model(n_chans=app_state['demo_data'].shape[1], n_times=app_state['demo_data'].shape[2])
74
+ print(f"βœ… Pre-trained model loaded successfully from {classifier.model_path}")
75
+ print(f"Pre-trained Demo: {len(app_state['demo_data'])} samples from {len(existing_files)} subjects")
76
+ else:
77
+ print("⚠️ No demo data loaded - check your .mat files")
78
+
79
+ print(f"Available sound classes: {list(sound_manager.current_sound_mapping.keys())}")
80
+
81
+ except Exception as e:
82
+ print(f"❌ Error during initialization: {e}")
83
+ raise RuntimeError(
84
+ "Cannot initialize app without real EEG data. "
85
+ "Please check your data files and paths."
86
+ )
87
+
88
+ def get_movement_sounds() -> Dict[str, str]:
89
+ """Get the current sound files for each movement."""
90
+ sounds = {}
91
+ for movement, sound_file in sound_manager.current_sound_mapping.items():
92
+ if movement in ['left_hand', 'right_hand', 'left_leg', "right_leg", 'tongue']: # Only show main movements
93
+ if sound_file is not None: # Check if sound_file is not None
94
+ sound_path = sound_manager.sound_dir / sound_file
95
+ if sound_path.exists():
96
+ # Convert to absolute path for Gradio audio components
97
+ sounds[movement] = str(sound_path.resolve())
98
+ return sounds
99
+
100
+ def start_composition():
101
+ """Start the composition process and perform initial classification."""
102
+ global app_state
103
+
104
+ # Only start new cycle if not already active
105
+ if not app_state['composition_active']:
106
+ app_state['composition_active'] = True
107
+ sound_manager.start_new_cycle() # Reset composition only when starting fresh
108
+
109
+ if app_state['demo_data'] is None:
110
+ return "❌ No data", "❌ No data", "❌ No data", None, None, None, None, None, None, "No EEG data available"
111
+
112
+ # Get current target
113
+ target_movement = sound_manager.get_current_target_movement()
114
+ print(f"DEBUG start_composition: current target = {target_movement}")
115
+
116
+ # Check if cycle is complete
117
+ if target_movement == "cycle_complete":
118
+ return "🎡 Cycle Complete!", "🎡 Complete", "Remap sounds to continue", None, None, None, None, None, None, "Cycle complete - remap sounds to continue"
119
+
120
+ # Perform initial EEG classification
121
+ epoch_data, true_label = data_processor.simulate_real_time_data(
122
+ app_state['demo_data'], app_state['demo_labels'], mode="class_balanced"
123
+ )
124
+
125
+ # Classify the epoch
126
+ predicted_class, confidence, probabilities = classifier.predict(epoch_data)
127
+ predicted_name = classifier.class_names[predicted_class]
128
+
129
+ # Process classification
130
+ result = sound_manager.process_classification(predicted_name, confidence, CONFIDENCE_THRESHOLD)
131
+
132
+ # Create visualization
133
+ fig = create_eeg_plot(epoch_data, target_movement, predicted_name, confidence, result['sound_added'])
134
+
135
+ # Initialize all audio components to silent (to clear any previous sounds)
136
+ silent_file = "silent.wav"
137
+ left_hand_audio = silent_file
138
+ right_hand_audio = silent_file
139
+ left_leg_audio = silent_file
140
+ right_leg_audio = silent_file
141
+ tongue_audio = silent_file
142
+
143
+ # Debug: Print classification result
144
+ print(f"DEBUG start_composition: predicted={predicted_name}, confidence={confidence:.3f}, sound_added={result['sound_added']}")
145
+
146
+ # Only play the sound if it was just added and matches the prediction
147
+ if result['sound_added']:
148
+ sounds = get_movement_sounds()
149
+ print(f"DEBUG: Available sounds: {list(sounds.keys())}")
150
+ if predicted_name == 'left_hand' and 'left_hand' in sounds:
151
+ left_hand_audio = sounds['left_hand']
152
+ print(f"DEBUG: Setting left_hand_audio to {sounds['left_hand']}")
153
+ elif predicted_name == 'right_hand' and 'right_hand' in sounds:
154
+ right_hand_audio = sounds['right_hand']
155
+ print(f"DEBUG: Setting right_hand_audio to {sounds['right_hand']}")
156
+ elif predicted_name == 'left_leg' and 'left_leg' in sounds:
157
+ left_leg_audio = sounds['left_leg']
158
+ print(f"DEBUG: Setting left_leg_audio to {sounds['left_leg']}")
159
+ elif predicted_name == 'right_leg' and 'right_leg' in sounds:
160
+ right_leg_audio = sounds['right_leg']
161
+ print(f"DEBUG: Setting right_leg_audio to {sounds['right_leg']}")
162
+ elif predicted_name == 'tongue' and 'tongue' in sounds:
163
+ tongue_audio = sounds['tongue']
164
+ print(f"DEBUG: Setting tongue_audio to {sounds['tongue']}")
165
+ else:
166
+ print("DEBUG: No sound added - confidence too low or other issue")
167
+
168
+ # Format next target with progress information
169
+ next_target = sound_manager.get_current_target_movement()
170
+ completed_count = len(sound_manager.movements_completed)
171
+ total_count = len(sound_manager.current_movement_sequence)
172
+
173
+ if next_target == "cycle_complete":
174
+ target_text = "🎡 Cycle Complete!"
175
+ else:
176
+ target_text = f"🎯 Any Movement ({completed_count}/{total_count} complete) - Use 'Classify Epoch' button to continue"
177
+
178
+ predicted_text = f"🧠 Predicted: {predicted_name.replace('_', ' ').title()} ({confidence:.2f})"
179
+
180
+ # Get composition info
181
+ composition_info = sound_manager.get_composition_info()
182
+ status_text = format_composition_summary(composition_info)
183
+
184
+ return (
185
+ target_text,
186
+ predicted_text,
187
+ "2-3 seconds",
188
+ fig,
189
+ left_hand_audio,
190
+ right_hand_audio,
191
+ left_leg_audio,
192
+ right_leg_audio,
193
+ tongue_audio,
194
+ status_text
195
+ )
196
+
197
+ def stop_composition():
198
+ """Stop the composition process."""
199
+ global app_state
200
+ app_state['composition_active'] = False
201
+ app_state['auto_mode'] = False
202
+ return (
203
+ "Composition stopped. Click 'Start Composing' to begin again",
204
+ "--",
205
+ "--",
206
+ "Stopped - click Start to resume"
207
+ )
208
+
209
+ def start_automatic_composition():
210
+ """Start automatic composition with continuous classification."""
211
+ global app_state
212
+
213
+ # Only start new cycle if not already active
214
+ if not app_state['composition_active']:
215
+ app_state['composition_active'] = True
216
+ app_state['auto_mode'] = True
217
+ sound_manager.start_new_cycle() # Reset composition only when starting fresh
218
+
219
+ if app_state['demo_data'] is None:
220
+ return "❌ No data", "❌ No data", "❌ No data", "❌ No data", None, None, None, None, None, None, "No EEG data available"
221
+
222
+ # Get current target
223
+ target_movement = sound_manager.get_current_target_movement()
224
+ print(f"DEBUG start_automatic_composition: current target = {target_movement}")
225
+
226
+ # Check if cycle is complete
227
+ if target_movement == "cycle_complete":
228
+ # Mark current cycle as complete
229
+ sound_manager.complete_current_cycle()
230
+
231
+ # Check if rehabilitation session should end
232
+ if sound_manager.should_end_session():
233
+ app_state['auto_mode'] = False # Stop automatic mode
234
+ return (
235
+ "πŸŽ‰ Session Complete!",
236
+ "πŸ† Amazing Progress!",
237
+ "Rehabilitation session finished!",
238
+ "🌟 Congratulations! You've created 2 unique brain-music compositions!\n\n" +
239
+ "πŸ’ͺ Your motor imagery skills are improving!\n\n" +
240
+ "🎡 You can review your compositions above, or start a new session anytime.\n\n" +
241
+ "Would you like to continue with more cycles, or take a well-deserved break?",
242
+ None, None, None, None, None, None,
243
+ f"βœ… Session Complete: {sound_manager.completed_cycles}/{sound_manager.max_cycles} compositions finished!"
244
+ )
245
+ else:
246
+ # Start next cycle automatically
247
+ sound_manager.start_new_cycle()
248
+ print("πŸ”„ Cycle completed! Starting new cycle automatically...")
249
+ target_movement = sound_manager.get_current_target_movement() # Get new target
250
+
251
+ # Show user prompt - encouraging start message
252
+ cycle_num = sound_manager.current_cycle
253
+ if cycle_num == 1:
254
+ prompt_text = "🌟 Welcome to your rehabilitation session! Let's start with any movement you can imagine..."
255
+ elif cycle_num == 2:
256
+ prompt_text = "πŸ’ͺ Excellent work on your first composition! Ready for composition #2?"
257
+ else:
258
+ prompt_text = "🧠 Let's continue - imagine any movement now..."
259
+
260
+ # Perform initial EEG classification
261
+ epoch_data, true_label = data_processor.simulate_real_time_data(
262
+ app_state['demo_data'], app_state['demo_labels'], mode="class_balanced"
263
+ )
264
+
265
+ # Classify the epoch
266
+ predicted_class, confidence, probabilities = classifier.predict(epoch_data)
267
+ predicted_name = classifier.class_names[predicted_class]
268
+
269
+ # Handle DJ effects or building phase
270
+ if sound_manager.current_phase == "dj_effects" and confidence > CONFIDENCE_THRESHOLD:
271
+ # DJ Effects Mode - toggle effects instead of adding sounds
272
+ dj_result = sound_manager.toggle_dj_effect(predicted_name)
273
+ result = {
274
+ 'sound_added': dj_result['effect_applied'],
275
+ 'mixed_composition': dj_result.get('mixed_composition'),
276
+ 'effect_name': dj_result.get('effect_name', ''),
277
+ 'effect_status': dj_result.get('effect_status', '')
278
+ }
279
+ else:
280
+ # Building Mode - process classification normally
281
+ result = sound_manager.process_classification(predicted_name, confidence, CONFIDENCE_THRESHOLD)
282
+
283
+ # Create visualization
284
+ fig = create_eeg_plot(epoch_data, target_movement, predicted_name, confidence, result['sound_added'])
285
+
286
+ # Initialize all audio components to silent by default
287
+ silent_file = "silent.wav"
288
+ left_hand_audio = silent_file
289
+ right_hand_audio = silent_file
290
+ left_leg_audio = silent_file
291
+ right_leg_audio = silent_file
292
+ tongue_audio = silent_file
293
+
294
+ # Debug: Print classification result
295
+ print(f"DEBUG start_automatic_composition: predicted={predicted_name}, confidence={confidence:.3f}, sound_added={result['sound_added']}")
296
+
297
+ # Handle audio display based on current phase
298
+ if sound_manager.current_phase == "dj_effects":
299
+ # DJ Effects Phase - show mixed composition with effects
300
+ if result.get('mixed_composition'):
301
+ # Show the mixed composition (all sounds combined) in one player
302
+ left_hand_audio = result['mixed_composition'] # Use first player for mixed audio
303
+ print(f"DEBUG DJ: Playing mixed composition with effects: {result['mixed_composition']}")
304
+ else:
305
+ # Fallback to showing accumulated sounds
306
+ sounds = get_movement_sounds()
307
+ completed_movements = sound_manager.movements_completed
308
+ if 'left_hand' in completed_movements and 'left_hand' in sounds:
309
+ left_hand_audio = sounds['left_hand']
310
+ else:
311
+ # Building Phase - show ALL accumulated sounds (layered composition)
312
+ sounds = get_movement_sounds()
313
+ completed_movements = sound_manager.movements_completed
314
+ print(f"DEBUG: Available sounds: {list(sounds.keys())}")
315
+ print(f"DEBUG: Completed movements: {completed_movements}")
316
+
317
+ # Display all completed movement sounds (cumulative layering)
318
+ if 'left_hand' in completed_movements and 'left_hand' in sounds:
319
+ left_hand_audio = sounds['left_hand']
320
+ print(f"DEBUG: Showing accumulated left_hand_audio: {sounds['left_hand']}")
321
+ if 'right_hand' in completed_movements and 'right_hand' in sounds:
322
+ right_hand_audio = sounds['right_hand']
323
+ print(f"DEBUG: Showing accumulated right_hand_audio: {sounds['right_hand']}")
324
+ if 'left_leg' in completed_movements and 'left_leg' in sounds:
325
+ left_leg_audio = sounds['left_leg']
326
+ print(f"DEBUG: Showing accumulated left_leg_audio: {sounds['left_leg']}")
327
+ if 'right_leg' in completed_movements and 'right_leg' in sounds:
328
+ right_leg_audio = sounds['right_leg']
329
+ print(f"DEBUG: Showing accumulated right_leg_audio: {sounds['right_leg']}")
330
+ if 'tongue' in completed_movements and 'tongue' in sounds:
331
+ tongue_audio = sounds['tongue']
332
+ print(f"DEBUG: Showing accumulated tongue_audio: {sounds['tongue']}")
333
+
334
+ # If a sound was just added, make sure it's included immediately
335
+ if result['sound_added'] and predicted_name in sounds:
336
+ if predicted_name == 'left_hand':
337
+ left_hand_audio = sounds['left_hand']
338
+ elif predicted_name == 'right_hand':
339
+ right_hand_audio = sounds['right_hand']
340
+ elif predicted_name == 'left_leg':
341
+ left_leg_audio = sounds['left_leg']
342
+ elif predicted_name == 'right_leg':
343
+ right_leg_audio = sounds['right_leg']
344
+ elif predicted_name == 'tongue':
345
+ tongue_audio = sounds['tongue']
346
+ print(f"DEBUG: Just added {predicted_name} sound: {sounds[predicted_name]}")
347
+
348
+ # Check for phase transition to DJ effects
349
+ completed_count = len(sound_manager.movements_completed)
350
+ total_count = len(sound_manager.current_movement_sequence)
351
+
352
+ # Transition to DJ effects if all movements completed but still in building phase
353
+ if completed_count >= 5 and sound_manager.current_phase == "building":
354
+ sound_manager.transition_to_dj_phase()
355
+
356
+ # Format display based on current phase
357
+ if sound_manager.current_phase == "dj_effects":
358
+ target_text = "🎧 DJ Mode Active - Use movements to control effects!"
359
+ else:
360
+ next_target = sound_manager.get_current_target_movement()
361
+ if next_target == "cycle_complete":
362
+ target_text = "🎡 Composition Complete!"
363
+ else:
364
+ target_text = f"🎯 Building Composition ({completed_count}/{total_count} layers)"
365
+
366
+ # Update display text based on phase
367
+ if sound_manager.current_phase == "dj_effects":
368
+ if result.get('effect_name') and result.get('effect_status'):
369
+ predicted_text = f"πŸŽ›οΈ {result['effect_name']}: {result['effect_status']}"
370
+ else:
371
+ predicted_text = f"🧠 Detected: {predicted_name.replace('_', ' ').title()} ({confidence:.2f})"
372
+ timer_text = "🎧 DJ Mode - Effects updating every 3 seconds..."
373
+ else:
374
+ predicted_text = f"🧠 Predicted: {predicted_name.replace('_', ' ').title()} ({confidence:.2f})"
375
+ timer_text = "⏱️ Next trial in 2-3 seconds..."
376
+
377
+ # Get composition info
378
+ composition_info = sound_manager.get_composition_info()
379
+ status_text = format_composition_summary(composition_info)
380
+
381
+ # Phase-based instruction visibility
382
+ building_visible = sound_manager.current_phase == "building"
383
+ dj_visible = sound_manager.current_phase == "dj_effects"
384
+
385
+ return (
386
+ target_text,
387
+ predicted_text,
388
+ timer_text,
389
+ prompt_text,
390
+ fig,
391
+ left_hand_audio,
392
+ right_hand_audio,
393
+ left_leg_audio,
394
+ right_leg_audio,
395
+ tongue_audio,
396
+ status_text,
397
+ gr.update(visible=building_visible), # building_instructions
398
+ gr.update(visible=dj_visible) # dj_instructions
399
+ )
400
+
401
+ def manual_classify():
402
+ """Manual classification for testing purposes."""
403
+ global app_state
404
+
405
+ if app_state['demo_data'] is None:
406
+ return "❌ No data", "❌ No data", "Manual mode", None, "No EEG data available", None, None, None, None, None
407
+
408
+ # Get EEG data sample
409
+ epoch_data, true_label = data_processor.simulate_real_time_data(
410
+ app_state['demo_data'], app_state['demo_labels'], mode="class_balanced"
411
+ )
412
+
413
+ # Classify the epoch
414
+ predicted_class, confidence, probabilities = classifier.predict(epoch_data)
415
+ predicted_name = classifier.class_names[predicted_class]
416
+
417
+ # Create visualization (without composition context)
418
+ fig = create_eeg_plot(epoch_data, "manual_test", predicted_name, confidence, False)
419
+
420
+ # Format results
421
+ target_text = "🎯 Manual Test Mode"
422
+ predicted_text = f"🧠 {predicted_name.replace('_', ' ').title()} ({confidence:.2f})"
423
+
424
+ # Update results log
425
+ import time
426
+ timestamp = time.strftime("%H:%M:%S")
427
+ result_entry = f"[{timestamp}] Predicted: {predicted_name.replace('_', ' ').title()} (confidence: {confidence:.3f})"
428
+
429
+ # Get sound files for preview (no autoplay)
430
+ sounds = get_movement_sounds()
431
+ left_hand_audio = sounds.get('left_hand', None)
432
+ right_hand_audio = sounds.get('right_hand', None)
433
+ left_leg_audio = sounds.get('left_leg', None)
434
+ right_leg_audio = sounds.get('right_leg', None)
435
+ tongue_audio = sounds.get('tongue', None)
436
+
437
+ return (
438
+ target_text,
439
+ predicted_text,
440
+ "Manual mode - click button to classify",
441
+ fig,
442
+ result_entry,
443
+ left_hand_audio,
444
+ right_hand_audio,
445
+ left_leg_audio,
446
+ right_leg_audio,
447
+ tongue_audio
448
+ )
449
+
450
+ def clear_manual():
451
+ """Clear manual testing results."""
452
+ return (
453
+ "🎯 Manual Test Mode",
454
+ "--",
455
+ "Manual mode",
456
+ None,
457
+ "Manual classification results cleared...",
458
+ None, None, None, None, None
459
+ )
460
+
461
+ def continue_automatic_composition():
462
+ """Continue automatic composition - called for subsequent trials."""
463
+ global app_state
464
+
465
+ if not app_state['composition_active'] or not app_state['auto_mode']:
466
+ return "πŸ›‘ Stopped", "--", "--", "Automatic composition stopped", None, None, None, None, None, None, "Stopped", gr.update(visible=True), gr.update(visible=False)
467
+
468
+ if app_state['demo_data'] is None:
469
+ return "❌ No data", "❌ No data", "❌ No data", "❌ No data", None, None, None, None, None, None, "No EEG data available", gr.update(visible=True), gr.update(visible=False)
470
+
471
+ # Get current target
472
+ target_movement = sound_manager.get_current_target_movement()
473
+ print(f"DEBUG continue_automatic_composition: current target = {target_movement}")
474
+
475
+ # Check if cycle is complete
476
+ if target_movement == "cycle_complete":
477
+ # Mark current cycle as complete
478
+ sound_manager.complete_current_cycle()
479
+
480
+ # Check if rehabilitation session should end
481
+ if sound_manager.should_end_session():
482
+ app_state['auto_mode'] = False # Stop automatic mode
483
+ return (
484
+ "πŸŽ‰ Session Complete!",
485
+ "πŸ† Amazing Progress!",
486
+ "Rehabilitation session finished!",
487
+ "🌟 Congratulations! You've created 2 unique brain-music compositions!\n\n" +
488
+ "πŸ’ͺ Your motor imagery skills are improving!\n\n" +
489
+ "🎡 You can review your compositions above, or start a new session anytime.\n\n" +
490
+ "Would you like to continue with more cycles, or take a well-deserved break?",
491
+ None, None, None, None, None, None,
492
+ f"βœ… Session Complete: {sound_manager.completed_cycles}/{sound_manager.max_cycles} compositions finished!",
493
+ gr.update(visible=True), gr.update(visible=False)
494
+ )
495
+ else:
496
+ # Start next cycle automatically
497
+ sound_manager.start_new_cycle()
498
+ print("πŸ”„ Cycle completed! Starting new cycle automatically...")
499
+ target_movement = sound_manager.get_current_target_movement() # Get new target
500
+
501
+ # Show next user prompt - rehabilitation-focused messaging
502
+ prompts = [
503
+ "πŸ’ͺ Great work! Imagine your next movement...",
504
+ "🎯 You're doing amazing! Focus and imagine any movement...",
505
+ "✨ Excellent progress! Ready for the next movement?",
506
+ "🌟 Keep it up! Concentrate and imagine now...",
507
+ "πŸ† Fantastic! Next trial - imagine any movement..."
508
+ ]
509
+ import random
510
+ prompt_text = random.choice(prompts)
511
+
512
+ # Add progress encouragement
513
+ completed_count = len(sound_manager.movements_completed)
514
+ total_count = len(sound_manager.current_movement_sequence)
515
+ if completed_count > 0:
516
+ prompt_text += f" ({completed_count}/{total_count} movements completed this cycle)"
517
+
518
+ # Perform EEG classification
519
+ epoch_data, true_label = data_processor.simulate_real_time_data(
520
+ app_state['demo_data'], app_state['demo_labels'], mode="class_balanced"
521
+ )
522
+
523
+ # Classify the epoch
524
+ predicted_class, confidence, probabilities = classifier.predict(epoch_data)
525
+ predicted_name = classifier.class_names[predicted_class]
526
+
527
+ # Handle DJ effects or building phase
528
+ if sound_manager.current_phase == "dj_effects" and confidence > CONFIDENCE_THRESHOLD:
529
+ # DJ Effects Mode - toggle effects instead of adding sounds
530
+ dj_result = sound_manager.toggle_dj_effect(predicted_name)
531
+ result = {
532
+ 'sound_added': dj_result['effect_applied'],
533
+ 'mixed_composition': dj_result.get('mixed_composition'),
534
+ 'effect_name': dj_result.get('effect_name', ''),
535
+ 'effect_status': dj_result.get('effect_status', '')
536
+ }
537
+ else:
538
+ # Building Mode - process classification normally
539
+ result = sound_manager.process_classification(predicted_name, confidence, CONFIDENCE_THRESHOLD)
540
+
541
+ # Check if we should transition to DJ phase
542
+ completed_count = len(sound_manager.movements_completed)
543
+ if completed_count >= 5 and sound_manager.current_phase == "building":
544
+ if sound_manager.transition_to_dj_phase():
545
+ print(f"DEBUG: Successfully transitioned to DJ phase with {completed_count} completed movements")
546
+
547
+ # Create visualization
548
+ fig = create_eeg_plot(epoch_data, target_movement, predicted_name, confidence, result['sound_added'])
549
+
550
+ # Initialize all audio components to silent by default
551
+ silent_file = "silent.wav"
552
+ left_hand_audio = silent_file
553
+ right_hand_audio = silent_file
554
+ left_leg_audio = silent_file
555
+ right_leg_audio = silent_file
556
+ tongue_audio = silent_file
557
+
558
+ # Handle audio differently based on phase
559
+ if sound_manager.current_phase == "dj_effects":
560
+ # DJ Mode: Play mixed composition with effects in the center position (tongue)
561
+ # Keep individual tracks silent to avoid overlapping audio
562
+ # Only update audio if the mixed composition file has actually changed
563
+ if sound_manager.mixed_composition_file and os.path.exists(sound_manager.mixed_composition_file):
564
+ # Only update if the file path has changed from last time
565
+ if app_state['last_audio_state']['tongue_audio'] != sound_manager.mixed_composition_file:
566
+ tongue_audio = sound_manager.mixed_composition_file
567
+ app_state['last_audio_state']['tongue_audio'] = sound_manager.mixed_composition_file
568
+ print(f"DEBUG continue: DJ mode - NEW mixed composition loaded: {sound_manager.mixed_composition_file}")
569
+ elif app_state['last_audio_state']['tongue_audio'] is not None:
570
+ # Use the same file as before to prevent Gradio from restarting audio
571
+ tongue_audio = app_state['last_audio_state']['tongue_audio']
572
+ # No debug print to reduce spam
573
+ else:
574
+ # Handle case where cached state is None - use the current file
575
+ tongue_audio = sound_manager.mixed_composition_file
576
+ app_state['last_audio_state']['tongue_audio'] = sound_manager.mixed_composition_file
577
+ print(f"DEBUG continue: DJ mode - Loading mixed composition (state was None): {sound_manager.mixed_composition_file}")
578
+ # Note: We don't update with processed effects files to prevent audio restarts
579
+ else:
580
+ # Building Mode: Show ALL accumulated sounds (layered composition)
581
+ sounds = get_movement_sounds()
582
+ completed_movements = sound_manager.movements_completed
583
+ print(f"DEBUG continue: Available sounds: {list(sounds.keys())}")
584
+ print(f"DEBUG continue: Completed movements: {completed_movements}")
585
+
586
+ # Display all completed movement sounds (cumulative layering) - only update when changed
587
+ if 'left_hand' in completed_movements and 'left_hand' in sounds:
588
+ new_left_hand = sounds['left_hand']
589
+ if app_state['last_audio_state']['left_hand_audio'] != new_left_hand:
590
+ left_hand_audio = new_left_hand
591
+ app_state['last_audio_state']['left_hand_audio'] = new_left_hand
592
+ print(f"DEBUG continue: NEW left_hand_audio: {new_left_hand}")
593
+ elif app_state['last_audio_state']['left_hand_audio'] is not None:
594
+ left_hand_audio = app_state['last_audio_state']['left_hand_audio']
595
+ else:
596
+ # Handle case where cached state is None - use the current file
597
+ left_hand_audio = new_left_hand
598
+ app_state['last_audio_state']['left_hand_audio'] = new_left_hand
599
+
600
+ if 'right_hand' in completed_movements and 'right_hand' in sounds:
601
+ new_right_hand = sounds['right_hand']
602
+ # TEMP FIX: Always update right_hand to test audio issue
603
+ right_hand_audio = new_right_hand
604
+ app_state['last_audio_state']['right_hand_audio'] = new_right_hand
605
+ print(f"DEBUG continue: ALWAYS UPDATE right_hand_audio: {new_right_hand}")
606
+
607
+ if 'left_leg' in completed_movements and 'left_leg' in sounds:
608
+ new_left_leg = sounds['left_leg']
609
+ if app_state['last_audio_state']['left_leg_audio'] != new_left_leg:
610
+ left_leg_audio = new_left_leg
611
+ app_state['last_audio_state']['left_leg_audio'] = new_left_leg
612
+ print(f"DEBUG continue: NEW left_leg_audio: {new_left_leg}")
613
+ elif app_state['last_audio_state']['left_leg_audio'] is not None:
614
+ left_leg_audio = app_state['last_audio_state']['left_leg_audio']
615
+ else:
616
+ # Handle case where cached state is None - use the current file
617
+ left_leg_audio = new_left_leg
618
+ app_state['last_audio_state']['left_leg_audio'] = new_left_leg
619
+
620
+ if 'right_leg' in completed_movements and 'right_leg' in sounds:
621
+ new_right_leg = sounds['right_leg']
622
+ if app_state['last_audio_state']['right_leg_audio'] != new_right_leg:
623
+ right_leg_audio = new_right_leg
624
+ app_state['last_audio_state']['right_leg_audio'] = new_right_leg
625
+ print(f"DEBUG continue: NEW right_leg_audio: {new_right_leg}")
626
+ elif app_state['last_audio_state']['right_leg_audio'] is not None:
627
+ right_leg_audio = app_state['last_audio_state']['right_leg_audio']
628
+ else:
629
+ # Handle case where cached state is None - use the current file
630
+ right_leg_audio = new_right_leg
631
+ app_state['last_audio_state']['right_leg_audio'] = new_right_leg
632
+
633
+ if 'tongue' in completed_movements and 'tongue' in sounds:
634
+ new_tongue = sounds['tongue']
635
+ # Note: Don't update tongue audio in building mode if we're about to transition to DJ mode
636
+ if sound_manager.current_phase == "building":
637
+ if app_state['last_audio_state']['tongue_audio'] != new_tongue:
638
+ tongue_audio = new_tongue
639
+ app_state['last_audio_state']['tongue_audio'] = new_tongue
640
+ print(f"DEBUG continue: NEW tongue_audio: {new_tongue}")
641
+ elif app_state['last_audio_state']['tongue_audio'] is not None:
642
+ tongue_audio = app_state['last_audio_state']['tongue_audio']
643
+ else:
644
+ # Handle case where cached state is None - use the current file
645
+ tongue_audio = new_tongue
646
+ app_state['last_audio_state']['tongue_audio'] = new_tongue
647
+
648
+ # If a sound was just added, make sure it gets updated immediately (override state check)
649
+ if result['sound_added'] and predicted_name in sounds:
650
+ new_sound = sounds[predicted_name]
651
+ if predicted_name == 'left_hand':
652
+ left_hand_audio = new_sound
653
+ app_state['last_audio_state']['left_hand_audio'] = new_sound
654
+ elif predicted_name == 'right_hand':
655
+ right_hand_audio = new_sound
656
+ app_state['last_audio_state']['right_hand_audio'] = new_sound
657
+ elif predicted_name == 'left_leg':
658
+ left_leg_audio = new_sound
659
+ app_state['last_audio_state']['left_leg_audio'] = new_sound
660
+ elif predicted_name == 'right_leg':
661
+ right_leg_audio = new_sound
662
+ app_state['last_audio_state']['right_leg_audio'] = new_sound
663
+ elif predicted_name == 'tongue':
664
+ tongue_audio = new_sound
665
+ app_state['last_audio_state']['tongue_audio'] = new_sound
666
+ print(f"DEBUG continue: Force update - just added {predicted_name} sound: {new_sound}")
667
+
668
+ # Format display with progress information
669
+ completed_count = len(sound_manager.movements_completed)
670
+ total_count = len(sound_manager.current_movement_sequence)
671
+
672
+ if sound_manager.current_phase == "dj_effects":
673
+ target_text = f"🎧 DJ Mode - Control Effects with Movements"
674
+ predicted_text = f"🧠 Predicted: {predicted_name.replace('_', ' ').title()} ({confidence:.2f})"
675
+ if result.get('effect_applied'):
676
+ effect_name = result.get('effect_name', '')
677
+ effect_status = result.get('effect_status', '')
678
+ timer_text = f"�️ {effect_name}: {effect_status}"
679
+ else:
680
+ timer_text = "🎡 Move to control effects..."
681
+ else:
682
+ target_text = f"�🎯 Any Movement ({completed_count}/{total_count} complete)"
683
+ predicted_text = f"🧠 Predicted: {predicted_name.replace('_', ' ').title()} ({confidence:.2f})"
684
+ timer_text = "⏱️ Next trial in 2-3 seconds..." if app_state['auto_mode'] else "Stopped"
685
+
686
+ # Get composition info
687
+ composition_info = sound_manager.get_composition_info()
688
+ status_text = format_composition_summary(composition_info)
689
+
690
+ # Phase-based instruction visibility
691
+ building_visible = sound_manager.current_phase == "building"
692
+ dj_visible = sound_manager.current_phase == "dj_effects"
693
+
694
+ return (
695
+ target_text,
696
+ predicted_text,
697
+ timer_text,
698
+ prompt_text,
699
+ fig,
700
+ left_hand_audio,
701
+ right_hand_audio,
702
+ left_leg_audio,
703
+ right_leg_audio,
704
+ tongue_audio,
705
+ status_text,
706
+ gr.update(visible=building_visible), # building_instructions
707
+ gr.update(visible=dj_visible) # dj_instructions
708
+ )
709
+
710
+ def classify_epoch():
711
+ """Classify a single EEG epoch and update composition."""
712
+ global app_state
713
+
714
+ if not app_state['composition_active']:
715
+ return "❌ Not active", "❌ Not active", "❌ Not active", None, None, None, None, None, None, "Click 'Start Composing' first"
716
+
717
+ if app_state['demo_data'] is None:
718
+ return "❌ No data", "❌ No data", "❌ No data", None, None, None, None, None, None, "No EEG data available"
719
+
720
+ # Get current target
721
+ target_movement = sound_manager.get_current_target_movement()
722
+ print(f"DEBUG classify_epoch: current target = {target_movement}")
723
+
724
+ if target_movement == "cycle_complete":
725
+ return "🎡 Cycle Complete!", "🎡 Complete", "Remap sounds to continue", None, None, None, None, None, None, "Cycle complete - remap sounds to continue"
726
+
727
+ # Get EEG data sample
728
+ epoch_data, true_label = data_processor.simulate_real_time_data(
729
+ app_state['demo_data'], app_state['demo_labels'], mode="class_balanced"
730
+ )
731
+
732
+ # Classify the epoch
733
+ predicted_class, confidence, probabilities = classifier.predict(epoch_data)
734
+ predicted_name = classifier.class_names[predicted_class]
735
+
736
+ # Process classification
737
+ result = sound_manager.process_classification(predicted_name, confidence, CONFIDENCE_THRESHOLD)
738
+
739
+ # Check if we should transition to DJ phase
740
+ completed_count = len(sound_manager.movements_completed)
741
+ if completed_count >= 5 and sound_manager.current_phase == "building":
742
+ if sound_manager.transition_to_dj_phase():
743
+ print(f"DEBUG: Successfully transitioned to DJ phase with {completed_count} completed movements")
744
+
745
+ # Create visualization
746
+ fig = create_eeg_plot(epoch_data, target_movement, predicted_name, confidence, result['sound_added'])
747
+
748
+ # Initialize all audio components to silent (to clear any previous sounds)
749
+ silent_file = "silent.wav"
750
+ left_hand_audio = silent_file
751
+ right_hand_audio = silent_file
752
+ left_leg_audio = silent_file
753
+ right_leg_audio = silent_file
754
+ tongue_audio = silent_file
755
+
756
+ # Only play the sound if it was just added and matches the prediction
757
+ if result['sound_added']:
758
+ sounds = get_movement_sounds()
759
+ if predicted_name == 'left_hand' and 'left_hand' in sounds:
760
+ left_hand_audio = sounds['left_hand']
761
+ elif predicted_name == 'right_hand' and 'right_hand' in sounds:
762
+ right_hand_audio = sounds['right_hand']
763
+ elif predicted_name == 'left_leg' and 'left_leg' in sounds:
764
+ left_leg_audio = sounds['left_leg']
765
+ elif predicted_name == 'right_leg' and 'right_leg' in sounds:
766
+ right_leg_audio = sounds['right_leg']
767
+ elif predicted_name == 'tongue' and 'tongue' in sounds:
768
+ tongue_audio = sounds['tongue']
769
+
770
+ # Format next target
771
+ next_target = sound_manager.get_current_target_movement()
772
+ target_text = f"🎯 Target: {next_target.replace('_', ' ').title()}" if next_target != "cycle_complete" else "🎡 Cycle Complete!"
773
+
774
+ predicted_text = f"🧠 Predicted: {predicted_name.replace('_', ' ').title()} ({confidence:.2f})"
775
+
776
+ # Get composition info
777
+ composition_info = sound_manager.get_composition_info()
778
+ status_text = format_composition_summary(composition_info)
779
+
780
+ return (
781
+ target_text,
782
+ predicted_text,
783
+ "2-3 seconds",
784
+ fig,
785
+ left_hand_audio,
786
+ right_hand_audio,
787
+ left_leg_audio,
788
+ right_leg_audio,
789
+ tongue_audio,
790
+ status_text
791
+ )
792
+
793
+ def create_eeg_plot(eeg_data: np.ndarray, target_movement: str, predicted_name: str, confidence: float, sound_added: bool) -> plt.Figure:
794
+ """Create EEG plot with target movement and classification result."""
795
+ fig, axes = plt.subplots(2, 2, figsize=(12, 8))
796
+ axes = axes.flatten()
797
+
798
+ # Plot 4 channels
799
+ time_points = np.arange(eeg_data.shape[1]) / 200 # 200 Hz sampling rate
800
+ channel_names = ['C3', 'C4', 'T3', 'T4'] # Motor cortex channels
801
+
802
+ for i in range(min(4, eeg_data.shape[0])):
803
+ color = 'green' if sound_added else 'blue'
804
+ axes[i].plot(time_points, eeg_data[i], color=color, linewidth=1)
805
+
806
+ if i < len(channel_names):
807
+ axes[i].set_title(f'{channel_names[i]} (Ch {i+1})')
808
+ else:
809
+ axes[i].set_title(f'Channel {i+1}')
810
+
811
+ axes[i].set_xlabel('Time (s)')
812
+ axes[i].set_ylabel('Amplitude (µV)')
813
+ axes[i].grid(True, alpha=0.3)
814
+
815
+ # Add overall title with status
816
+ status = "βœ“ SOUND ADDED" if sound_added else "β—‹ No sound"
817
+ title = f"Target: {target_movement.replace('_', ' ').title()} | Predicted: {predicted_name.replace('_', ' ').title()} ({confidence:.2f}) | {status}"
818
+ fig.suptitle(title, fontsize=12, fontweight='bold')
819
+ fig.tight_layout()
820
+ return fig
821
+
822
+ def format_composition_summary(composition_info: Dict) -> str:
823
+ """Format composition information for display."""
824
+ if not composition_info.get('layers_by_cycle'):
825
+ return "No composition layers yet"
826
+
827
+ summary = []
828
+ for cycle, layers in composition_info['layers_by_cycle'].items():
829
+ summary.append(f"Cycle {cycle + 1}: {len(layers)} layers")
830
+ for layer in layers:
831
+ movement = layer.get('movement', 'unknown')
832
+ confidence = layer.get('confidence', 0)
833
+ summary.append(f" β€’ {movement.replace('_', ' ').title()} ({confidence:.2f})")
834
+
835
+ return "\n".join(summary) if summary else "No composition layers"
836
+
837
+ # Create Gradio interface
838
+ def create_interface():
839
+ with gr.Blocks(title="EEG Motor Imagery Music Composer", theme=gr.themes.Soft()) as demo:
840
+ gr.Markdown("# 🧠🎡 EEG Motor Imagery Rehabilitation Composer")
841
+ gr.Markdown("**Therapeutic Brain-Computer Interface for Motor Recovery**\n\nCreate beautiful music compositions using your brain signals! This rehabilitation tool helps strengthen motor imagery skills while creating personalized musical pieces.")
842
+
843
+ with gr.Tabs() as tabs:
844
+ # Main Composition Tab
845
+ with gr.TabItem("🎡 Automatic Composition"):
846
+ with gr.Row():
847
+ # Left side - Task and EEG information
848
+ with gr.Column(scale=2):
849
+ # Task instructions - Building Phase
850
+ with gr.Group() as building_instructions:
851
+ gr.Markdown("### 🎯 Rehabilitation Session Instructions")
852
+ gr.Markdown("""
853
+ **Motor Imagery Training:**
854
+ - **Imagine** opening or closing your **right or left hand**
855
+ - **Visualize** briefly moving your **right or left leg or foot**
856
+ - **Think about** pronouncing **"L"** with your tongue
857
+ - **Rest state** (no movement imagination)
858
+
859
+ *🌟 Each successful imagination creates a musical layer!*
860
+
861
+ **Session Structure:** Build composition, then control DJ effects
862
+ *Press Start to begin your personalized rehabilitation session*
863
+ """)
864
+
865
+ # DJ Instructions - Effects Phase (initially hidden)
866
+ with gr.Group(visible=False) as dj_instructions:
867
+ gr.Markdown("### 🎧 DJ Controller Mode")
868
+ gr.Markdown("""
869
+ **🎉 Composition Complete! You are now the DJ!**
870
+
871
+ **Use the same movements to control audio effects:**
872
+ - 👈 **Left Hand**: Volume Fade On/Off
873
+ - 👉 **Right Hand**: High Pass Filter On/Off
874
+ - 🦡 **Left Leg**: Reverb Effect On/Off
875
+ - 🦡 **Right Leg**: Low Pass Filter On/Off
876
+ - 👅 **Tongue**: Bass Boost On/Off
877
+
878
+ *πŸŽ›οΈ Each movement toggles an effect - Mix your creation!*
879
+ """)
880
+
881
+ # Start button
882
+ with gr.Row():
883
+ start_btn = gr.Button("🎡 Start Composing", variant="primary", size="lg")
884
+ continue_btn = gr.Button("⏭️ Continue", variant="primary", size="lg", visible=False)
885
+ stop_btn = gr.Button("πŸ›‘ Stop", variant="secondary", size="lg")
886
+
887
+ # Session completion options (shown after 2 cycles)
888
+ with gr.Row(visible=False) as session_complete_row:
889
+ new_session_btn = gr.Button("πŸ”„ Start New Session", variant="primary", size="lg")
890
+ extend_session_btn = gr.Button("βž• Continue Session", variant="secondary", size="lg")
891
+
892
+ # Timer for automatic progression (hidden from user)
893
+ timer = gr.Timer(value=3.0, active=False) # 3 second intervals
894
+
895
+ # User prompt display
896
+ user_prompt = gr.Textbox(label="πŸ’­ User Prompt", interactive=False, value="Click 'Start Composing' to begin",
897
+ elem_classes=["prompt-display"])
898
+
899
+ # Current status
900
+ with gr.Row():
901
+ target_display = gr.Textbox(label="🎯 Current Target", interactive=False, value="Ready to start")
902
+ predicted_display = gr.Textbox(label="🧠 Predicted", interactive=False, value="--")
903
+
904
+ timer_display = gr.Textbox(label="⏱️ Next Trial In", interactive=False, value="--")
905
+
906
+ eeg_plot = gr.Plot(label="EEG Data Visualization")
907
+
908
+ # Right side - Compositional layers
909
+ with gr.Column(scale=1):
910
+ gr.Markdown("### 🎡 Compositional Layers")
911
+
912
+ # Show 5 movement sounds
913
+ left_hand_sound = gr.Audio(label="πŸ‘ˆ Left Hand", interactive=False, autoplay=True, visible=True)
914
+ right_hand_sound = gr.Audio(label="πŸ‘‰ Right Hand", interactive=False, autoplay=True, visible=True)
915
+ left_leg_sound = gr.Audio(label="🦡 Left Leg", interactive=False, autoplay=True, visible=True)
916
+ right_leg_sound = gr.Audio(label="🦡 Right Leg", interactive=False, autoplay=True, visible=True)
917
+ tongue_sound = gr.Audio(label="πŸ‘… Tongue", interactive=False, autoplay=True, visible=True)
918
+
919
+ # Composition status
920
+ composition_status = gr.Textbox(label="Composition Status", interactive=False, lines=5)
921
+
922
+ # Manual Testing Tab
923
+ with gr.TabItem("🧠 Manual Testing"):
924
+ with gr.Row():
925
+ with gr.Column(scale=2):
926
+ gr.Markdown("### πŸ”¬ Manual EEG Classification Testing")
927
+ gr.Markdown("Use this tab to manually test the EEG classifier without the composition system.")
928
+
929
+ with gr.Row():
930
+ classify_btn = gr.Button("🧠 Classify Single Epoch", variant="primary")
931
+ clear_btn = gr.Button("�️ Clear", variant="secondary")
932
+
933
+ # Manual status displays
934
+ manual_target_display = gr.Textbox(label="🎯 Current Target", interactive=False, value="Ready")
935
+ manual_predicted_display = gr.Textbox(label="🧠 Predicted", interactive=False, value="--")
936
+ manual_timer_display = gr.Textbox(label="⏱️ Status", interactive=False, value="Manual mode")
937
+
938
+ manual_eeg_plot = gr.Plot(label="EEG Data Visualization")
939
+
940
+ with gr.Column(scale=1):
941
+ gr.Markdown("### πŸ“Š Classification Results")
942
+ manual_results = gr.Textbox(label="Results Log", interactive=False, lines=10, value="Manual classification results will appear here...")
943
+
944
+ # Individual sound previews (no autoplay in manual mode)
945
+ gr.Markdown("### πŸ”Š Sound Preview")
946
+ manual_left_hand_sound = gr.Audio(label="πŸ‘ˆ Left Hand", interactive=False, autoplay=False, visible=True)
947
+ manual_right_hand_sound = gr.Audio(label="πŸ‘‰ Right Hand", interactive=False, autoplay=False, visible=True)
948
+ manual_left_leg_sound = gr.Audio(label="🦡 Left Leg", interactive=False, autoplay=False, visible=True)
949
+ manual_right_leg_sound = gr.Audio(label="🦡 Right Leg", interactive=False, autoplay=False, visible=True)
950
+ manual_tongue_sound = gr.Audio(label="πŸ‘… Tongue", interactive=False, autoplay=False, visible=True)
951
+
952
+ # Session management functions
953
+ def start_new_session():
954
+ """Reset everything and start a completely new rehabilitation session"""
955
+ global sound_manager
956
+ sound_manager.completed_cycles = 0
957
+ sound_manager.current_cycle = 0
958
+ sound_manager.movements_completed = set()
959
+ sound_manager.composition_layers = []
960
+
961
+ # Start fresh session
962
+ result = start_automatic_composition()
963
+ return (
964
+ result[0], # target_display
965
+ result[1], # predicted_display
966
+ result[2], # timer_display
967
+ result[3], # user_prompt
968
+ result[4], # eeg_plot
969
+ result[5], # left_hand_sound
970
+ result[6], # right_hand_sound
971
+ result[7], # left_leg_sound
972
+ result[8], # right_leg_sound
973
+ result[9], # tongue_sound
974
+ result[10], # composition_status
975
+ result[11], # building_instructions
976
+ result[12], # dj_instructions
977
+ gr.update(visible=True), # continue_btn - show it
978
+ gr.update(active=True), # timer - activate it
979
+ gr.update(visible=False) # session_complete_row - hide it
980
+ )
981
+
982
+ def extend_current_session():
983
+ """Continue current session beyond the 2-cycle limit"""
984
+ sound_manager.max_cycles += 2 # Add 2 more cycles
985
+
986
+ # Continue with current session
987
+ result = continue_automatic_composition()
988
+ return (
989
+ result[0], # target_display
990
+ result[1], # predicted_display
991
+ result[2], # timer_display
992
+ result[3], # user_prompt
993
+ result[4], # eeg_plot
994
+ result[5], # left_hand_sound
995
+ result[6], # right_hand_sound
996
+ result[7], # left_leg_sound
997
+ result[8], # right_leg_sound
998
+ result[9], # tongue_sound
999
+ result[10], # composition_status
1000
+ result[11], # building_instructions
1001
+ result[12], # dj_instructions
1002
+ gr.update(visible=True), # continue_btn - show it
1003
+ gr.update(active=True), # timer - activate it
1004
+ gr.update(visible=False) # session_complete_row - hide it
1005
+ )
1006
+
1007
+ # Wrapper functions for timer control
1008
+ def start_with_timer():
1009
+ """Start composition and activate automatic timer"""
1010
+ result = start_automatic_composition()
1011
+ # Show continue button and activate timer
1012
+ return (
1013
+ result[0], # target_display
1014
+ result[1], # predicted_display
1015
+ result[2], # timer_display
1016
+ result[3], # user_prompt
1017
+ result[4], # eeg_plot
1018
+ result[5], # left_hand_sound
1019
+ result[6], # right_hand_sound
1020
+ result[7], # left_leg_sound
1021
+ result[8], # right_leg_sound
1022
+ result[9], # tongue_sound
1023
+ result[10], # composition_status
1024
+ result[11], # building_instructions
1025
+ result[12], # dj_instructions
1026
+ gr.update(visible=True), # continue_btn - show it
1027
+ gr.update(active=True) # timer - activate it
1028
+ )
1029
+
1030
+ def continue_with_timer():
1031
+ """Continue composition and manage timer state"""
1032
+ result = continue_automatic_composition()
1033
+
1034
+ # Check if session is complete (rehabilitation session finished)
1035
+ if "πŸŽ‰ Session Complete!" in result[0]:
1036
+ # Show session completion options
1037
+ return (
1038
+ result[0], # target_display
1039
+ result[1], # predicted_display
1040
+ result[2], # timer_display
1041
+ result[3], # user_prompt
1042
+ result[4], # eeg_plot
1043
+ result[5], # left_hand_sound
1044
+ result[6], # right_hand_sound
1045
+ result[7], # left_leg_sound
1046
+ result[8], # right_leg_sound
1047
+ result[9], # tongue_sound
1048
+ result[10], # composition_status
1049
+ result[11], # building_instructions
1050
+ result[12], # dj_instructions
1051
+ gr.update(active=False), # timer - deactivate it
1052
+ gr.update(visible=True) # session_complete_row - show options
1053
+ )
1054
+ # Check if composition is complete (old logic for other cases)
1055
+ elif "🎡 Cycle Complete!" in result[0]:
1056
+ # Stop the timer when composition is complete
1057
+ return (
1058
+ result[0], # target_display
1059
+ result[1], # predicted_display
1060
+ result[2], # timer_display
1061
+ result[3], # user_prompt
1062
+ result[4], # eeg_plot
1063
+ result[5], # left_hand_sound
1064
+ result[6], # right_hand_sound
1065
+ result[7], # left_leg_sound
1066
+ result[8], # right_leg_sound
1067
+ result[9], # tongue_sound
1068
+ result[10], # composition_status
1069
+ result[11], # building_instructions
1070
+ result[12], # dj_instructions
1071
+ gr.update(active=False), # timer - deactivate it
1072
+ gr.update(visible=False) # session_complete_row - keep hidden
1073
+ )
1074
+ else:
1075
+ # Keep timer active for next iteration
1076
+ return (
1077
+ result[0], # target_display
1078
+ result[1], # predicted_display
1079
+ result[2], # timer_display
1080
+ result[3], # user_prompt
1081
+ result[4], # eeg_plot
1082
+ result[5], # left_hand_sound
1083
+ result[6], # right_hand_sound
1084
+ result[7], # left_leg_sound
1085
+ result[8], # right_leg_sound
1086
+ result[9], # tongue_sound
1087
+ result[10], # composition_status
1088
+ result[11], # building_instructions
1089
+ result[12], # dj_instructions
1090
+ gr.update(active=True), # timer - keep active
1091
+ gr.update(visible=False) # session_complete_row - keep hidden
1092
+ )
1093
+
1094
+ # Event handlers for automatic composition tab
1095
+ start_btn.click(
1096
+ fn=start_with_timer,
1097
+ outputs=[target_display, predicted_display, timer_display, user_prompt, eeg_plot,
1098
+ left_hand_sound, right_hand_sound, left_leg_sound, right_leg_sound, tongue_sound, composition_status,
1099
+ building_instructions, dj_instructions, continue_btn, timer]
1100
+ )
1101
+
1102
+ continue_btn.click(
1103
+ fn=continue_with_timer,
1104
+ outputs=[target_display, predicted_display, timer_display, user_prompt, eeg_plot,
1105
+ left_hand_sound, right_hand_sound, left_leg_sound, right_leg_sound, tongue_sound, composition_status,
1106
+ building_instructions, dj_instructions, timer, session_complete_row]
1107
+ )
1108
+
1109
+ # Timer automatically triggers continuation
1110
+ timer.tick(
1111
+ fn=continue_with_timer,
1112
+ outputs=[target_display, predicted_display, timer_display, user_prompt, eeg_plot,
1113
+ left_hand_sound, right_hand_sound, left_leg_sound, right_leg_sound, tongue_sound, composition_status,
1114
+ building_instructions, dj_instructions, timer, session_complete_row]
1115
+ )
1116
+
1117
+ # Session completion event handlers
1118
+ new_session_btn.click(
1119
+ fn=start_new_session,
1120
+ outputs=[target_display, predicted_display, timer_display, user_prompt, eeg_plot,
1121
+ left_hand_sound, right_hand_sound, left_leg_sound, right_leg_sound, tongue_sound, composition_status,
1122
+ building_instructions, dj_instructions, continue_btn, timer, session_complete_row]
1123
+ )
1124
+
1125
+ extend_session_btn.click(
1126
+ fn=extend_current_session,
1127
+ outputs=[target_display, predicted_display, timer_display, user_prompt, eeg_plot,
1128
+ left_hand_sound, right_hand_sound, left_leg_sound, right_leg_sound, tongue_sound, composition_status,
1129
+ building_instructions, dj_instructions, continue_btn, timer, session_complete_row]
1130
+ )
1131
+
1132
+ def stop_with_timer():
1133
+ """Stop composition and deactivate timer"""
1134
+ result = stop_composition()
1135
+ return (
1136
+ result[0], # target_display
1137
+ result[1], # predicted_display
1138
+ result[2], # timer_display
1139
+ result[3], # user_prompt
1140
+ gr.update(visible=False), # continue_btn - hide it
1141
+ gr.update(active=False) # timer - deactivate it
1142
+ )
1143
+
1144
+ stop_btn.click(
1145
+ fn=stop_with_timer,
1146
+ outputs=[target_display, predicted_display, timer_display, user_prompt, continue_btn, timer]
1147
+ )
1148
+
1149
+ # Event handlers for manual testing tab
1150
+ classify_btn.click(
1151
+ fn=manual_classify,
1152
+ outputs=[manual_target_display, manual_predicted_display, manual_timer_display, manual_eeg_plot, manual_results,
1153
+ manual_left_hand_sound, manual_right_hand_sound, manual_left_leg_sound, manual_right_leg_sound, manual_tongue_sound]
1154
+ )
1155
+
1156
+ clear_btn.click(
1157
+ fn=clear_manual,
1158
+ outputs=[manual_target_display, manual_predicted_display, manual_timer_display, manual_eeg_plot, manual_results,
1159
+ manual_left_hand_sound, manual_right_hand_sound, manual_left_leg_sound, manual_right_leg_sound, manual_tongue_sound]
1160
+ )
1161
+
1162
+ # Note: No auto-loading of sounds to prevent playing all sounds on startup
1163
+
1164
+ return demo
1165
+
1166
+ if __name__ == "__main__":
1167
+ demo = create_interface()
1168
+ demo.launch(server_name="0.0.0.0", server_port=7867)
classifier.py ADDED
@@ -0,0 +1,179 @@
1
+ """
2
+ EEG Motor Imagery Classifier Module
3
+ ----------------------------------
4
+ Handles model loading, inference, and real-time prediction for motor imagery classification.
5
+ Based on the ShallowFBCSPNet architecture from the original eeg_motor_imagery.py script.
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import numpy as np
11
+ from braindecode.models import ShallowFBCSPNet
12
+ from typing import Dict, Tuple, Optional
13
+ import os
14
+ from sklearn.metrics import accuracy_score
15
+ from data_processor import EEGDataProcessor
16
+ from config import DEMO_DATA_PATHS
17
+
18
+ class MotorImageryClassifier:
19
+ """
20
+ Motor imagery classifier using ShallowFBCSPNet model.
21
+ """
22
+
23
+ def __init__(self, model_path: str = "shallow_weights_all.pth"):
24
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
+ self.model = None
26
+ self.model_path = model_path
27
+ self.class_names = {
28
+ 0: "left_hand",
29
+ 1: "right_hand",
30
+ 2: "neutral",
31
+ 3: "left_leg",
32
+ 4: "tongue",
33
+ 5: "right_leg"
34
+ }
35
+ self.is_loaded = False
36
+
37
+ def load_model(self, n_chans: int, n_times: int, n_outputs: int = 6):
38
+ """Load the pre-trained ShallowFBCSPNet model."""
39
+ try:
40
+ self.model = ShallowFBCSPNet(
41
+ n_chans=n_chans,
42
+ n_outputs=n_outputs,
43
+ n_times=n_times,
44
+ final_conv_length="auto"
45
+ ).to(self.device)
46
+
47
+ if os.path.exists(self.model_path):
48
+ try:
49
+ state_dict = torch.load(self.model_path, map_location=self.device)
50
+ self.model.load_state_dict(state_dict)
51
+ self.model.eval()
52
+ self.is_loaded = True
53
+ print(f"βœ… Pre-trained model loaded successfully from {self.model_path}")
54
+ except Exception as model_error:
55
+ print(f"⚠️ Pre-trained model found but incompatible: {model_error}")
56
+ print("πŸ”„ Starting LOSO training with available EEG data...")
57
+ self.is_loaded = False
58
+ else:
59
+ print(f"❌ Pre-trained model weights not found at {self.model_path}")
60
+ print("πŸ”„ Starting LOSO training with available EEG data...")
61
+ self.is_loaded = False
62
+
63
+ except Exception as e:
64
+ print(f"❌ Error loading model: {e}")
65
+ print("πŸ”„ Starting LOSO training with available EEG data...")
66
+ self.is_loaded = False
67
+
68
+ def get_model_status(self) -> str:
69
+ """Get current model status for user interface."""
70
+ if self.is_loaded:
71
+ return "βœ… Pre-trained model loaded and ready"
72
+ else:
73
+ return "πŸ”„ Using LOSO training (training new model from EEG data)"
74
+
75
+ def predict(self, eeg_data: np.ndarray) -> Tuple[int, float, Dict[str, float]]:
76
+ """
77
+ Predict motor imagery class from EEG data.
78
+
79
+ Args:
80
+ eeg_data: EEG data array of shape (n_channels, n_times)
81
+
82
+ Returns:
83
+ predicted_class: Predicted class index
84
+ confidence: Confidence score
85
+ probabilities: Dictionary of class probabilities
86
+ """
87
+ if not self.is_loaded:
88
+ return self._fallback_loso_classification(eeg_data)
89
+
90
+ # Ensure input is the right shape: (batch, channels, time)
91
+ if eeg_data.ndim == 2:
92
+ eeg_data = eeg_data[np.newaxis, ...]
93
+
94
+ # Convert to tensor
95
+ x = torch.from_numpy(eeg_data.astype(np.float32)).to(self.device)
96
+
97
+ with torch.no_grad():
98
+ output = self.model(x)
99
+ probabilities = torch.softmax(output, dim=1)
100
+ predicted_class = torch.argmax(probabilities, dim=1).cpu().numpy()[0]
101
+ confidence = probabilities.max().cpu().numpy()
102
+
103
+ # Convert to dictionary
104
+ prob_dict = {
105
+ self.class_names[i]: probabilities[0, i].cpu().numpy()
106
+ for i in range(len(self.class_names))
107
+ }
108
+
109
+ return predicted_class, confidence, prob_dict
110
+
111
+ def _fallback_loso_classification(self, eeg_data: np.ndarray) -> Tuple[int, float, Dict[str, float]]:
112
+ """
113
+ Fallback classification using LOSO (Leave-One-Subject-Out) training via data_processor.prepare_loso_split.
114
+ Trains a model on available data when pre-trained model isn't available.
115
+ """
116
+ try:
117
+ print("πŸ”„ No pre-trained model available. Training new model using LOSO method...")
118
+ print("⏳ This may take a moment - training on real EEG data...")
119
+
120
+ # Initialize data processor
121
+ processor = EEGDataProcessor()
122
+
123
+ # Check if demo data files exist
124
+ available_files = [f for f in DEMO_DATA_PATHS if os.path.exists(f)]
125
+ if len(available_files) < 2:
126
+ raise ValueError(f"Not enough data files for LOSO training. Need at least 2 files, found {len(available_files)}. "
127
+ f"Available files: {available_files}")
128
+
129
+ # Perform LOSO split (using first session as test)
130
+ X_train, y_train, X_test, y_test, session_info = processor.prepare_loso_split(
131
+ available_files, test_subject_idx=0
132
+ )
133
+
134
+ # Get data dimensions
135
+ n_chans = X_train.shape[1]
136
+ n_times = X_train.shape[2]
137
+
138
+ # Create and train model
139
+ self.model = ShallowFBCSPNet(
140
+ n_chans=n_chans,
141
+ n_outputs=6,
142
+ n_times=n_times,
143
+ final_conv_length="auto"
144
+ ).to(self.device)
145
+
146
+ # Simple training loop
147
+ optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
148
+ criterion = nn.CrossEntropyLoss()
149
+
150
+ # Convert training data to tensors
151
+ X_train_tensor = torch.from_numpy(X_train).float().to(self.device)
152
+ y_train_tensor = torch.from_numpy(y_train).long().to(self.device)
153
+
154
+ # Quick training loop (50 epochs; enough for a demo-quality model)
155
+ self.model.train()
156
+ for epoch in range(50):
157
+ optimizer.zero_grad()
158
+ outputs = self.model(X_train_tensor)
159
+ loss = criterion(outputs, y_train_tensor)
160
+ loss.backward()
161
+ optimizer.step()
162
+
163
+ if epoch % 5 == 0:
164
+ print(f"LOSO Training - Epoch {epoch}, Loss: {loss.item():.4f}")
165
+
166
+ # Switch to evaluation mode
167
+ self.model.eval()
168
+ self.is_loaded = True
169
+
170
+ print("βœ… LOSO model trained successfully! Ready for classification.")
171
+
172
+ # Now make prediction with the trained model
173
+ return self.predict(eeg_data)
174
+
175
+ except Exception as e:
176
+ print(f"Error in LOSO training: {e}")
177
+ raise RuntimeError(f"Failed to initialize classifier. Neither pre-trained model nor LOSO training succeeded: {e}")
178
+
179
+
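For reference, a minimal smoke-test sketch of how `classifier.py` is meant to be driven. The shapes are assumptions taken from `config.py` (19 channels, 1.5 s epochs at 200 Hz), and `predict()` falls back to LOSO training on the demo `.mat` files if the bundled weights are missing or incompatible, so this only runs end to end when one of those is available:

```python
# Hypothetical smoke test for MotorImageryClassifier (not part of the commit).
import numpy as np
from classifier import MotorImageryClassifier

clf = MotorImageryClassifier()            # looks for shallow_weights_all.pth
clf.load_model(n_chans=19, n_times=301)   # assumed: 1.5 s epochs at 200 Hz (tmin=0, tmax=1.5)
print(clf.get_model_status())

fake_epoch = np.random.randn(19, 301).astype(np.float32)   # (n_channels, n_times)
pred_idx, confidence, probabilities = clf.predict(fake_epoch)
print(clf.class_names[int(pred_idx)], float(confidence))
```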
config.py ADDED
@@ -0,0 +1,93 @@
1
+ """
2
+ Configuration settings for the EEG Motor Imagery Music Composer
3
+ """
4
+
5
+ import os
6
+ from pathlib import Path
7
+
8
+ # Application settings
9
+ APP_NAME = "EEG Motor Imagery Music Composer"
10
+ VERSION = "1.0.0"
11
+
12
+ # Data paths
13
+ BASE_DIR = Path(__file__).parent
14
+ DATA_DIR = BASE_DIR / "data"
15
+ SOUND_DIR = BASE_DIR / "sounds"
16
+ MODEL_DIR = BASE_DIR
17
+
18
+ # Model settings
19
+ MODEL_PATH = MODEL_DIR / "shallow_weights_all.pth"
20
+ # Model architecture: Always uses ShallowFBCSPNet from braindecode
21
+ # If pre-trained weights not found, will train using LOSO on available data
22
+
23
+ # EEG Data settings
24
+ SAMPLING_RATE = 200 # Hz
25
+ EPOCH_DURATION = 1.5 # seconds
26
+ N_CHANNELS = 19 # 19 EEG electrodes (ground and reference channels excluded)
27
+ N_CLASSES = 6 # 6 motor imagery classes (see CLASS_NAMES); some paradigms use only 4
28
+
29
+ # Classification settings
30
+ CONFIDENCE_THRESHOLD = 0.3 # Minimum confidence to add sound layer (lowered for testing)
31
+ MAX_COMPOSITION_LAYERS = 6 # Maximum layers in composition
32
+
33
+ # Sound settings
34
+ SOUND_MAPPING = {
35
+ "left_hand": "1_SoundHelix-Song-6_(Bass).wav",
36
+ "right_hand": "1_SoundHelix-Song-6_(Drums).wav",
37
+ "neutral": None, # No sound for neutral/rest state
38
+ "left_leg": "1_SoundHelix-Song-6_(Other).wav",
39
+ "tongue": "1_SoundHelix-Song-6_(Vocals).wav",
40
+ "right_leg": "1_SoundHelix-Song-6_(Bass).wav" # Can be remapped by user
41
+ }
42
+
43
+ # Motor imagery class names
44
+ CLASS_NAMES = {
45
+ 0: "left_hand",
46
+ 1: "right_hand",
47
+ 2: "neutral",
48
+ 3: "left_leg",
49
+ 4: "tongue",
50
+ 5: "right_leg"
51
+ }
52
+
53
+ CLASS_DESCRIPTIONS = {
54
+ "left_hand": "🀚 Left Hand Movement",
55
+ "right_hand": "🀚 Right Hand Movement",
56
+ "neutral": "😐 Neutral/Rest State",
57
+ "left_leg": "🦡 Left Leg Movement",
58
+ "tongue": "πŸ‘… Tongue Movement",
59
+ "right_leg": "🦡 Right Leg Movement"
60
+ }
61
+
62
+ # Demo data paths (optional) - updated with available files
63
+ DEMO_DATA_PATHS = [
64
+ "data/raw_mat/HaLTSubjectA1602236StLRHandLegTongue.mat",
65
+ "data/raw_mat/HaLTSubjectA1603086StLRHandLegTongue.mat",
66
+ "data/raw_mat/HaLTSubjectA1603106StLRHandLegTongue.mat",
67
+ ]
68
+
69
+ # Gradio settings
70
+ GRADIO_PORT = 7860
71
+ GRADIO_HOST = "0.0.0.0"
72
+ GRADIO_SHARE = False # Set to True to create public links
73
+
74
+ # Logging settings
75
+ LOG_LEVEL = "INFO"
76
+ LOG_FILE = BASE_DIR / "logs" / "app.log"
77
+
78
+ # Create necessary directories
79
+ def create_directories():
80
+ """Create necessary directories if they don't exist."""
81
+ directories = [
82
+ DATA_DIR,
83
+ SOUND_DIR,
84
+ MODEL_DIR,
85
+ LOG_FILE.parent
86
+ ]
87
+
88
+ for directory in directories:
89
+ directory.mkdir(parents=True, exist_ok=True)
90
+
91
+ if __name__ == "__main__":
92
+ create_directories()
93
+ print("Configuration directories created successfully!")
data_processor.py ADDED
@@ -0,0 +1,257 @@
1
+ """
2
+ EEG Data Processing Module
3
+ -------------------------
4
+ Handles EEG data loading, preprocessing, and epoching for real-time classification.
5
+ Adapted from the original eeg_motor_imagery.py script.
6
+ """
7
+
8
+ import scipy.io
9
+ import numpy as np
10
+ import mne
11
+ import torch
12
+ import pandas as pd
13
+ from typing import List, Tuple, Dict, Optional
14
+ from pathlib import Path
15
+ from scipy.signal import butter, lfilter
16
+
17
+ class EEGDataProcessor:
18
+ """
19
+ Processes EEG data from .mat files for motor imagery classification.
20
+ """
21
+
22
+ def __init__(self):
23
+ self.fs = None
24
+ self.ch_names = None
25
+ self.event_id = {
26
+ "left_hand": 1,
27
+ "right_hand": 2,
28
+ "neutral": 3,
29
+ "left_leg": 4,
30
+ "tongue": 5,
31
+ "right_leg": 6,
32
+ }
33
+
34
+ def load_mat_file(self, file_path: str) -> Tuple[np.ndarray, np.ndarray, List[str], int]:
35
+ """Load and parse a single .mat EEG file."""
36
+ mat = scipy.io.loadmat(file_path)
37
+ content = mat['o'][0, 0]
38
+
39
+ labels = content[4].flatten()
40
+ signals = content[5]
41
+ chan_names_raw = content[6]
42
+ channels = [ch[0][0] for ch in chan_names_raw]
43
+ fs = int(content[2][0, 0])
44
+
45
+ return signals, labels, channels, fs
46
+
47
+ def create_raw_object(self, signals: np.ndarray, channels: List[str], fs: int,
48
+ drop_ground_electrodes: bool = True) -> mne.io.RawArray:
49
+ """Create MNE Raw object from signal data."""
50
+ df = pd.DataFrame(signals, columns=channels)
51
+
52
+ if drop_ground_electrodes:
53
+ # Drop auxiliary channels that should be excluded
54
+ aux_exclude = ('X3', 'X5')
55
+ columns_to_drop = [ch for ch in channels if ch in aux_exclude]
56
+
57
+ df = df.drop(columns=columns_to_drop, errors="ignore")
58
+ print(f"Dropped auxiliary channels {columns_to_drop}. Remaining channels: {len(df.columns)}")
59
+
60
+ eeg = df.values.T
61
+ ch_names = df.columns.tolist()
62
+
63
+ self.ch_names = ch_names
64
+ self.fs = fs
65
+
66
+ info = mne.create_info(ch_names=ch_names, sfreq=fs, ch_types="eeg")
67
+ raw = mne.io.RawArray(eeg, info)
68
+
69
+ return raw
70
+
71
+ def extract_events(self, labels: np.ndarray) -> np.ndarray:
72
+ """Extract events from label array."""
73
+ onsets = np.where((labels[1:] != 0) & (labels[:-1] == 0))[0] + 1
74
+ event_codes = labels[onsets].astype(int)
75
+ events = np.c_[onsets, np.zeros_like(onsets), event_codes]
76
+
77
+ # Keep only relevant events
78
+ mask = np.isin(events[:, 2], np.arange(1, 7))
79
+ events = events[mask]
80
+
81
+ return events
82
+
83
+ def create_epochs(self, raw: mne.io.RawArray, events: np.ndarray,
84
+ tmin: float = 0, tmax: float = 1.5) -> mne.Epochs:
85
+ """Create epochs from raw data and events."""
86
+ epochs = mne.Epochs(
87
+ raw,
88
+ events=events,
89
+ event_id=self.event_id,
90
+ tmin=tmin,
91
+ tmax=tmax,
92
+ baseline=None,
93
+ preload=True,
94
+ )
95
+
96
+ return epochs
97
+
98
+ def process_files(self, file_paths: List[str]) -> Tuple[np.ndarray, np.ndarray]:
99
+ """Process multiple EEG files and return combined data."""
100
+ all_epochs = []
101
+
102
+ for file_path in file_paths:
103
+ signals, labels, channels, fs = self.load_mat_file(file_path)
104
+ raw = self.create_raw_object(signals, channels, fs, drop_ground_electrodes=True)
105
+ events = self.extract_events(labels)
106
+ epochs = self.create_epochs(raw, events)
107
+ all_epochs.append(epochs)
108
+
109
+ if len(all_epochs) > 1:
110
+ epochs_combined = mne.concatenate_epochs(all_epochs)
111
+ else:
112
+ epochs_combined = all_epochs[0]
113
+
114
+ # Convert to arrays for model input
115
+ X = epochs_combined.get_data().astype("float32")
116
+ y = (epochs_combined.events[:, -1] - 1).astype("int64") # classes 0..5
117
+
118
+ return X, y
119
+
120
+ def load_continuous_data(self, file_paths: List[str]) -> Tuple[np.ndarray, int]:
121
+ """
122
+ Load continuous raw EEG data without epoching.
123
+
124
+ Args:
125
+ file_paths: List of .mat file paths
126
+
127
+ Returns:
128
+ raw_data: Continuous EEG data [n_channels, n_timepoints]
129
+ fs: Sampling frequency
130
+ """
131
+ all_raw_data = []
132
+
133
+ for file_path in file_paths:
134
+ signals, labels, channels, fs = self.load_mat_file(file_path)
135
+ raw = self.create_raw_object(signals, channels, fs, drop_ground_electrodes=True)
136
+
137
+ # Extract continuous data (no epoching)
138
+ continuous_data = raw.get_data() # [n_channels, n_timepoints]
139
+ all_raw_data.append(continuous_data)
140
+
141
+ # Concatenate all continuous data along time axis
142
+ if len(all_raw_data) > 1:
143
+ combined_raw = np.concatenate(all_raw_data, axis=1)
144
+ else:
145
+ combined_raw = all_raw_data[0]
146
+
147
+ return combined_raw, fs
148
+
149
+ def prepare_loso_split(self, file_paths: List[str], test_subject_idx: int = 0) -> Tuple:
150
+ """
151
+ Prepare Leave-One-Subject-Out (LOSO) split for EEG data.
152
+
153
+ Args:
154
+ file_paths: List of .mat file paths (one per subject)
155
+ test_subject_idx: Index of subject to use for testing
156
+
157
+ Returns:
158
+ X_train, y_train, X_test, y_test, subject_info
159
+ """
160
+ all_subjects_data = []
161
+ subject_info = []
162
+
163
+ # Load each subject separately
164
+ for i, file_path in enumerate(file_paths):
165
+ signals, labels, channels, fs = self.load_mat_file(file_path)
166
+ raw = self.create_raw_object(signals, channels, fs, drop_ground_electrodes=True)
167
+ events = self.extract_events(labels)
168
+ epochs = self.create_epochs(raw, events)
169
+
170
+ # Convert to arrays
171
+ X_subject = epochs.get_data().astype("float32")
172
+ y_subject = (epochs.events[:, -1] - 1).astype("int64")
173
+
174
+ all_subjects_data.append((X_subject, y_subject))
175
+ subject_info.append({
176
+ 'file_path': file_path,
177
+ 'subject_id': f"Subject_{i+1}",
178
+ 'n_epochs': len(X_subject),
179
+ 'channels': channels,
180
+ 'fs': fs
181
+ })
182
+
183
+ # LOSO split: one subject for test, others for train
184
+ test_subject = all_subjects_data[test_subject_idx]
185
+ train_subjects = [all_subjects_data[i] for i in range(len(all_subjects_data)) if i != test_subject_idx]
186
+
187
+ # Combine training subjects
188
+ if len(train_subjects) > 1:
189
+ X_train = np.concatenate([subj[0] for subj in train_subjects], axis=0)
190
+ y_train = np.concatenate([subj[1] for subj in train_subjects], axis=0)
191
+ else:
192
+ X_train, y_train = train_subjects[0]
193
+
194
+ X_test, y_test = test_subject
195
+
196
+ print("LOSO Split:")
197
+ print(f" Test Subject: {subject_info[test_subject_idx]['subject_id']} ({len(X_test)} epochs)")
198
+ print(f" Train Subjects: {len(train_subjects)} subjects ({len(X_train)} epochs)")
199
+
200
+ return X_train, y_train, X_test, y_test, subject_info
201
+
202
+ def simulate_real_time_data(self, X: np.ndarray, y: np.ndarray, mode: str = "random") -> Tuple[np.ndarray, int]:
203
+ """
204
+ Simulate real-time EEG data for demo purposes.
205
+
206
+ Args:
207
+ X: EEG data array (currently epoched data)
208
+ y: Labels array
209
+ mode: "random", "sequential", or "class_balanced"
210
+
211
+ Returns:
212
+ Single epoch and its true label
213
+ """
214
+ if mode == "random":
215
+ idx = np.random.randint(0, len(X))
216
+ elif mode == "sequential":
217
+ # Use a counter for sequential sampling (would need to store state)
218
+ idx = np.random.randint(0, len(X)) # Simplified for now
219
+ elif mode == "class_balanced":
220
+ # Sample ensuring we get different classes
221
+ available_classes = np.unique(y)
222
+ target_class = np.random.choice(available_classes)
223
+ class_indices = np.where(y == target_class)[0]
224
+ idx = np.random.choice(class_indices)
225
+ else:
226
+ idx = np.random.randint(0, len(X))
227
+
228
+ return X[idx], y[idx]
229
+
230
+ def simulate_continuous_stream(self, raw_data: np.ndarray, fs: int, window_size: float = 1.5) -> np.ndarray:
231
+ """
232
+ Simulate continuous EEG stream by extracting sliding windows from raw data.
233
+
234
+ Args:
235
+ raw_data: Continuous EEG data [n_channels, n_timepoints]
236
+ fs: Sampling frequency
237
+ window_size: Window size in seconds
238
+
239
+ Returns:
240
+ Single window of EEG data [n_channels, window_samples]
241
+ """
242
+ window_samples = int(window_size * fs) # e.g., 1.5s * 200Hz = 300 samples
243
+
244
+ # Ensure we don't go beyond the data
245
+ max_start = raw_data.shape[1] - window_samples
246
+ if max_start <= 0:
247
+ return raw_data # Return full data if too short
248
+
249
+ # Random starting point in the continuous stream
250
+ start_idx = np.random.randint(0, max_start)
251
+ end_idx = start_idx + window_samples
252
+
253
+ # Extract window
254
+ window = raw_data[:, start_idx:end_idx]
255
+
256
+ return window
257
+
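A hedged sketch of the two simulation paths above (epoched and continuous). It assumes the `.mat` files listed in `DEMO_DATA_PATHS` from `config.py` are present locally; the printed shapes are illustrative:

```python
# Illustrative only: exercises epoching and the continuous-stream simulation.
from data_processor import EEGDataProcessor
from config import DEMO_DATA_PATHS

proc = EEGDataProcessor()

# Epoched path: X has shape (n_epochs, n_channels, n_times); y holds labels 0..5
X, y = proc.process_files(DEMO_DATA_PATHS[:1])
print(X.shape, y.shape)

# Continuous path: a random 1.5 s window sliced from the raw recording
raw, fs = proc.load_continuous_data(DEMO_DATA_PATHS[:1])
window = proc.simulate_continuous_stream(raw, fs, window_size=1.5)
print(window.shape)   # roughly (19, 300) at 200 Hz
```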
demo.py ADDED
@@ -0,0 +1,96 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Demo script for the EEG Motor Imagery Music Composer
4
+ """
5
+
6
+ import sys
7
+ import os
8
+ from pathlib import Path
9
+
10
+ # Add the current directory to Python path
11
+ current_dir = Path(__file__).parent
12
+ sys.path.insert(0, str(current_dir))
13
+
14
+ from data_processor import EEGDataProcessor
15
+ from classifier import MotorImageryClassifier
16
+ from sound_library import SoundManager
17
+ from utils import setup_logging, create_classification_summary
18
+ import numpy as np
19
+
20
+ def run_demo():
21
+ """Run a simple demo of the system components."""
22
+ print("🧠 EEG Motor Imagery Music Composer - Demo")
23
+ print("=" * 50)
24
+
25
+ # Initialize components
26
+ print("Initializing components...")
27
+ data_processor = EEGDataProcessor()
28
+ classifier = MotorImageryClassifier()
29
+ sound_manager = SoundManager()
30
+
31
+ # Create mock data for demo
32
+ print("Creating mock EEG data...")
33
+ mock_data = np.random.randn(10, 32, 384).astype(np.float32) # 10 samples, 32 channels, 384 time points
34
+ mock_labels = np.random.randint(0, 6, 10)
35
+
36
+ # Initialize classifier
37
+ print("Loading classifier...")
38
+ classifier.load_model(n_chans=32, n_times=384)
39
+
40
+ print(f"Available sounds: {sound_manager.get_available_sounds()}")
41
+ print()
42
+
43
+ # Run classification demo
44
+ print("Running classification demo...")
45
+ print("-" * 30)
46
+
47
+ for i in range(5):
48
+ # Get random sample
49
+ sample_idx = np.random.randint(0, len(mock_data))
50
+ eeg_sample = mock_data[sample_idx]
51
+ true_label = mock_labels[sample_idx]
52
+
53
+ # Classify
54
+ predicted_class, confidence, probabilities = classifier.predict(eeg_sample)
55
+ predicted_name = classifier.class_names[predicted_class]
56
+ true_name = classifier.class_names[true_label]
57
+
58
+ # Create summary
59
+ summary = create_classification_summary(predicted_name, confidence, probabilities)
60
+
61
+ print(f"Sample {i+1}:")
62
+ print(f" True class: {true_name}")
63
+ print(f" Predicted: {summary['emoji']} {predicted_name} ({summary['confidence_percent']})")
64
+
65
+ # Add to composition if confidence is high
66
+ if confidence > 0.7 and predicted_name != 'neutral':
67
+ sound_manager.add_layer(predicted_name, confidence)
68
+ print(f" β™ͺ Added to composition: {predicted_name}")
69
+ else:
70
+ print(" - Not added (low confidence or neutral)")
71
+
72
+ print()
73
+
74
+ # Show composition summary
75
+ composition_info = sound_manager.get_composition_info()
76
+ print("Final Composition:")
77
+ print("-" * 20)
78
+ if composition_info:
79
+ for i, layer in enumerate(composition_info, 1):
80
+ print(f"{i}. {layer['class']} (confidence: {layer['confidence']:.2f})")
81
+ else:
82
+ print("No composition layers created")
83
+
84
+ print()
85
+ print("Demo completed! 🎡")
86
+ print("To run the full Gradio interface, execute: python app.py")
87
+
88
+ if __name__ == "__main__":
89
+ try:
90
+ run_demo()
91
+ except KeyboardInterrupt:
92
+ print("\nDemo interrupted by user")
93
+ except Exception as e:
94
+ print(f"Error running demo: {e}")
95
+ import traceback
96
+ traceback.print_exc()
enhanced_utils.py ADDED
@@ -0,0 +1,236 @@
1
+ """
2
+ Enhanced utilities incorporating useful functions from the original src/ templates
3
+ """
4
+
5
+ import torch
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+ import seaborn as sns
9
+ from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
10
+ from typing import Dict, List, Tuple, Optional
11
+ import logging
12
+
13
+ def evaluate_model_performance(model, dataloader, device, class_names: List[str]) -> Dict:
14
+ """
15
+ Comprehensive model evaluation with metrics and visualizations.
16
+ Enhanced version of src/evaluate.py
17
+ """
18
+ model.eval()
19
+ all_preds = []
20
+ all_labels = []
21
+ all_probs = []
22
+
23
+ with torch.no_grad():
24
+ for inputs, labels in dataloader:
25
+ inputs = inputs.to(device)
26
+ labels = labels.to(device)
27
+
28
+ outputs = model(inputs)
29
+ probs = torch.softmax(outputs, dim=1)
30
+ _, preds = torch.max(outputs, 1)
31
+
32
+ all_preds.extend(preds.cpu().numpy())
33
+ all_labels.extend(labels.cpu().numpy())
34
+ all_probs.extend(probs.cpu().numpy())
35
+
36
+ # Calculate metrics
37
+ accuracy = accuracy_score(all_labels, all_preds)
38
+ cm = confusion_matrix(all_labels, all_preds)
39
+ report = classification_report(all_labels, all_preds, target_names=class_names, output_dict=True)
40
+
41
+ return {
42
+ 'accuracy': accuracy,
43
+ 'predictions': all_preds,
44
+ 'labels': all_labels,
45
+ 'probabilities': all_probs,
46
+ 'confusion_matrix': cm,
47
+ 'classification_report': report
48
+ }
49
+
50
+ def plot_confusion_matrix(cm: np.ndarray, class_names: List[str], title: str = "Confusion Matrix") -> plt.Figure:
51
+ """
52
+ Plot confusion matrix with proper formatting.
53
+ Enhanced version from src/visualize.py
54
+ """
55
+ fig, ax = plt.subplots(figsize=(8, 6))
56
+
57
+ sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
58
+ xticklabels=class_names, yticklabels=class_names, ax=ax)
59
+
60
+ ax.set_title(title)
61
+ ax.set_ylabel('True Label')
62
+ ax.set_xlabel('Predicted Label')
63
+
64
+ plt.tight_layout()
65
+ return fig
66
+
67
+ def plot_classification_probabilities(probabilities: np.ndarray, class_names: List[str],
68
+ sample_indices: Optional[List[int]] = None) -> plt.Figure:
69
+ """
70
+ Plot classification probabilities for selected samples.
71
+ """
72
+ if sample_indices is None:
73
+ sample_indices = list(range(min(10, len(probabilities))))
74
+
75
+ fig, ax = plt.subplots(figsize=(12, 6))
76
+
77
+ x = np.arange(len(class_names))
78
+ width = 0.8 / len(sample_indices)
79
+
80
+ for i, sample_idx in enumerate(sample_indices):
81
+ offset = (i - len(sample_indices)/2) * width
82
+ ax.bar(x + offset, probabilities[sample_idx], width,
83
+ label=f'Sample {sample_idx}', alpha=0.8)
84
+
85
+ ax.set_xlabel('Motor Imagery Classes')
86
+ ax.set_ylabel('Probability')
87
+ ax.set_title('Classification Probabilities')
88
+ ax.set_xticks(x)
89
+ ax.set_xticklabels(class_names, rotation=45)
90
+ ax.legend()
91
+ ax.grid(True, alpha=0.3)
92
+
93
+ plt.tight_layout()
94
+ return fig
95
+
96
+ def plot_training_history(history: Dict[str, List[float]]) -> plt.Figure:
97
+ """
98
+ Plot training history (loss and accuracy).
99
+ Enhanced version from src/visualize.py
100
+ """
101
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
102
+
103
+ # Plot accuracy
104
+ if 'train_accuracy' in history and 'val_accuracy' in history:
105
+ ax1.plot(history['train_accuracy'], label='Train', linewidth=2)
106
+ ax1.plot(history['val_accuracy'], label='Validation', linewidth=2)
107
+ ax1.set_title('Model Accuracy')
108
+ ax1.set_xlabel('Epoch')
109
+ ax1.set_ylabel('Accuracy')
110
+ ax1.legend()
111
+ ax1.grid(True, alpha=0.3)
112
+
113
+ # Plot loss
114
+ if 'train_loss' in history and 'val_loss' in history:
115
+ ax2.plot(history['train_loss'], label='Train', linewidth=2)
116
+ ax2.plot(history['val_loss'], label='Validation', linewidth=2)
117
+ ax2.set_title('Model Loss')
118
+ ax2.set_xlabel('Epoch')
119
+ ax2.set_ylabel('Loss')
120
+ ax2.legend()
121
+ ax2.grid(True, alpha=0.3)
122
+
123
+ plt.tight_layout()
124
+ return fig
125
+
126
+ def plot_eeg_channels(eeg_data: np.ndarray, channel_names: Optional[List[str]] = None,
127
+ sample_rate: int = 256, title: str = "EEG Channels") -> plt.Figure:
128
+ """
129
+ Plot multiple EEG channels.
130
+ Enhanced visualization for EEG data.
131
+ """
132
+ n_channels, n_samples = eeg_data.shape
133
+ time_axis = np.arange(n_samples) / sample_rate
134
+
135
+ # Determine subplot layout
136
+ n_rows = int(np.ceil(np.sqrt(n_channels)))
137
+ n_cols = int(np.ceil(n_channels / n_rows))
138
+
139
+ fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 10))
140
+ if n_channels == 1:
141
+ axes = [axes]
142
+ else:
143
+ axes = axes.flatten()
144
+
145
+ for i in range(n_channels):
146
+ ax = axes[i]
147
+ ax.plot(time_axis, eeg_data[i], 'b-', linewidth=1)
148
+
149
+ channel_name = channel_names[i] if channel_names else f'Channel {i+1}'
150
+ ax.set_title(channel_name)
151
+ ax.set_xlabel('Time (s)')
152
+ ax.set_ylabel('Amplitude')
153
+ ax.grid(True, alpha=0.3)
154
+
155
+ # Hide unused subplots
156
+ for i in range(n_channels, len(axes)):
157
+ axes[i].set_visible(False)
158
+
159
+ plt.suptitle(title)
160
+ plt.tight_layout()
161
+ return fig
162
+
163
+ class EarlyStopping:
164
+ """
165
+ Early stopping utility from src/types/index.py
166
+ """
167
+ def __init__(self, patience=7, min_delta=0, restore_best_weights=True, verbose=False):
168
+ self.patience = patience
169
+ self.min_delta = min_delta
170
+ self.restore_best_weights = restore_best_weights
171
+ self.verbose = verbose
172
+ self.best_loss = None
173
+ self.counter = 0
174
+ self.best_weights = None
175
+
176
+ def __call__(self, val_loss, model):
177
+ if self.best_loss is None:
178
+ self.best_loss = val_loss
179
+ self.save_checkpoint(model)
180
+ elif val_loss < self.best_loss - self.min_delta:
181
+ self.best_loss = val_loss
182
+ self.counter = 0
183
+ self.save_checkpoint(model)
184
+ else:
185
+ self.counter += 1
186
+
187
+ if self.counter >= self.patience:
188
+ if self.verbose:
189
+ print(f'Early stopping triggered after {self.counter} epochs of no improvement')
190
+ if self.restore_best_weights:
191
+ model.load_state_dict(self.best_weights)
192
+ return True
193
+ return False
194
+
195
+ def save_checkpoint(self, model):
196
+ """Save model when validation loss decreases."""
197
+ if self.restore_best_weights:
198
+ # Clone tensors so later optimizer steps don't overwrite the saved snapshot
+ self.best_weights = {k: v.detach().clone() for k, v in model.state_dict().items()}
199
+
200
+ def create_enhanced_evaluation_report(model, test_loader, class_names: List[str],
201
+ device, save_plots: bool = True) -> Dict:
202
+ """
203
+ Create a comprehensive evaluation report with plots and metrics.
204
+ """
205
+ # Get evaluation results
206
+ results = evaluate_model_performance(model, test_loader, device, class_names)
207
+
208
+ # Create visualizations
209
+ plots = {}
210
+
211
+ # Confusion Matrix
212
+ plots['confusion_matrix'] = plot_confusion_matrix(
213
+ results['confusion_matrix'], class_names,
214
+ title="Motor Imagery Classification - Confusion Matrix"
215
+ )
216
+
217
+ # Classification Probabilities (sample)
218
+ plots['probabilities'] = plot_classification_probabilities(
219
+ np.array(results['probabilities']), class_names,
220
+ sample_indices=list(range(min(5, len(results['probabilities']))))
221
+ )
222
+
223
+ if save_plots:
224
+ for plot_name, fig in plots.items():
225
+ fig.savefig(f'{plot_name}.png', dpi=300, bbox_inches='tight')
226
+
227
+ return {
228
+ 'metrics': results,
229
+ 'plots': plots,
230
+ 'summary': {
231
+ 'accuracy': results['accuracy'],
232
+ 'n_samples': len(results['labels']),
233
+ 'n_classes': len(class_names),
234
+ 'class_names': class_names
235
+ }
236
+ }
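The `EarlyStopping` helper is called once per epoch with the validation loss and returns `True` when training should stop. A self-contained, hypothetical example on synthetic data (the tiny linear model and random tensors are placeholders, not code from this repo):

```python
# Hypothetical usage of enhanced_utils.EarlyStopping on synthetic data.
import torch
import torch.nn as nn
from enhanced_utils import EarlyStopping

model = nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
x_val, y_val = torch.randn(32, 10), torch.randint(0, 2, (32,))

stopper = EarlyStopping(patience=5, restore_best_weights=True, verbose=True)
for epoch in range(100):
    # One synthetic training step (stands in for the real EEG training loop)
    x, y = torch.randn(32, 10), torch.randint(0, 2, (32,))
    optimizer.zero_grad()
    criterion(model(x), y).backward()
    optimizer.step()

    # Early stopping decides on the validation loss
    with torch.no_grad():
        val_loss = criterion(model(x_val), y_val).item()
    if stopper(val_loss, model):
        break   # best-scoring weights are restored into `model`
```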
requirements.txt ADDED
@@ -0,0 +1,14 @@
1
+ gradio
2
+ pandas
3
+ numpy
4
+ scipy
5
+ matplotlib
6
+ torch
7
+ torchvision
8
+ mne
9
+ braindecode
10
+ scikit-learn
11
+ pydub
12
+ soundfile
13
+ librosa
14
+ # threading is part of the Python standard library; no pip package is required
sound_library.py ADDED
@@ -0,0 +1,783 @@
1
+ """
2
+ Sound Management System for EEG Motor Imagery Classification
3
+ -----------------------------------------------------------
4
+ Handles sound mapping, layering, and music composition based on motor imagery predictions.
5
+ Uses local SoundHelix audio files for different motor imagery classes.
6
+ """
7
+
8
+ import numpy as np
9
+ import soundfile as sf
10
+ import os
11
+ from typing import Dict, List, Optional, Tuple
12
+ import threading
13
+ import time
14
+ from pathlib import Path
15
+ import tempfile
16
+
17
+ # Professional audio effects processor using scipy and librosa
18
+ from scipy import signal
19
+ import librosa
20
+
21
+ class AudioEffectsProcessor:
22
+ """Professional audio effects for DJ mode using scipy and librosa."""
23
+
24
+ @staticmethod
25
+ def apply_volume_fade(data: np.ndarray, fade_type: str = "out", fade_length: float = 0.5) -> np.ndarray:
26
+ """Apply a linear volume fade in or out; fade_length is the fraction of the signal faded (modifies data in place)."""
27
+ try:
28
+ samples = len(data)
29
+ fade_samples = int(fade_length * samples)
30
+
31
+ if fade_type == "out":
32
+ # Fade out: linear decrease from 1.0 to 0.3
33
+ fade_curve = np.linspace(1.0, 0.3, fade_samples)
34
+ data[-fade_samples:] *= fade_curve
35
+ elif fade_type == "in":
36
+ # Fade in: linear increase from 0.3 to 1.0
37
+ fade_curve = np.linspace(0.3, 1.0, fade_samples)
38
+ data[:fade_samples] *= fade_curve
39
+
40
+ return data
41
+ except Exception as e:
42
+ print(f"Volume fade effect failed: {e}")
43
+ return data
44
+
45
+ @staticmethod
46
+ def apply_high_pass_filter(data: np.ndarray, samplerate: int, cutoff: float = 800.0) -> np.ndarray:
47
+ """Apply high-pass filter to emphasize highs and cut lows."""
48
+ try:
49
+ # Design butterworth high-pass filter
50
+ nyquist = samplerate / 2
51
+ normalized_cutoff = cutoff / nyquist
52
+ b, a = signal.butter(4, normalized_cutoff, btype='high', analog=False)
53
+
54
+ # Apply filter
55
+ filtered_data = signal.filtfilt(b, a, data)
56
+ return filtered_data
57
+ except Exception as e:
58
+ print(f"High-pass filter failed: {e}")
59
+ return data
60
+
61
+ @staticmethod
62
+ def apply_low_pass_filter(data: np.ndarray, samplerate: int, cutoff: float = 1200.0) -> np.ndarray:
63
+ """Apply low-pass filter to emphasize lows and cut highs."""
64
+ try:
65
+ # Design butterworth low-pass filter
66
+ nyquist = samplerate / 2
67
+ normalized_cutoff = cutoff / nyquist
68
+ b, a = signal.butter(4, normalized_cutoff, btype='low', analog=False)
69
+
70
+ # Apply filter
71
+ filtered_data = signal.filtfilt(b, a, data)
72
+ return filtered_data
73
+ except Exception as e:
74
+ print(f"Low-pass filter failed: {e}")
75
+ return data
76
+
77
+ @staticmethod
78
+ def apply_reverb(data: np.ndarray, samplerate: int, room_size: float = 0.5) -> np.ndarray:
79
+ """Apply simple reverb effect using delay and feedback."""
80
+ try:
81
+ # Simple reverb using multiple delayed copies
82
+ delay_samples = int(0.1 * samplerate) # 100ms delay
83
+ decay = 0.3 * room_size
84
+
85
+ # Create reverb buffer
86
+ reverb_data = np.copy(data)
87
+
88
+ # Add delayed copies with decay
89
+ for i in range(3):
90
+ delay = delay_samples * (i + 1)
91
+ if delay < len(data):
92
+ gain = decay ** (i + 1)
93
+ reverb_data[delay:] += data[:-delay] * gain
94
+
95
+ # Mix original with reverb
96
+ return 0.7 * data + 0.3 * reverb_data
97
+ except Exception as e:
98
+ print(f"Reverb effect failed: {e}")
99
+ return data
100
+
101
+ @staticmethod
102
+ def apply_bass_boost(data: np.ndarray, samplerate: int, boost_db: float = 6.0) -> np.ndarray:
103
+ """Apply bass boost using low-frequency shelving filter."""
104
+ try:
105
+ # Design low-shelf filter for bass boost
106
+ freq = 250.0 # Bass frequency cutoff
107
+ nyquist = samplerate / 2
108
+ normalized_freq = freq / nyquist
109
+
110
+ # Convert dB to linear gain
111
+ gain = 10 ** (boost_db / 20)
112
+
113
+ # Simple bass boost: amplify low frequencies
114
+ b, a = signal.butter(2, normalized_freq, btype='low', analog=False)
115
+ low_freq = signal.filtfilt(b, a, data)
116
+
117
+ # Mix boosted lows with original
118
+ boosted_data = data + (low_freq * (gain - 1))
119
+
120
+ # Normalize to prevent clipping
121
+ max_val = np.max(np.abs(boosted_data))
122
+ if max_val > 0.95:
123
+ boosted_data = boosted_data * 0.95 / max_val
124
+
125
+ return boosted_data
126
+ except Exception as e:
127
+ print(f"Bass boost failed: {e}")
128
+ return data
129
+
130
+ @staticmethod
131
+ def process_with_effects(audio_file: str, active_effects: Dict[str, bool]) -> str:
132
+ """Process audio file with active effects and return processed version."""
133
+ try:
134
+ # Check if audio_file is valid
135
+ if not audio_file or not os.path.exists(audio_file):
136
+ print(f"Invalid audio file: {audio_file}")
137
+ return audio_file
138
+
139
+ # Read audio file
140
+ data, samplerate = sf.read(audio_file)
141
+
142
+ # Handle stereo to mono conversion if needed
143
+ if len(data.shape) > 1:
144
+ data = np.mean(data, axis=1)
145
+
146
+ # Apply effects based on active states
147
+ processed_data = np.copy(data)
148
+ effect_names = []
149
+
150
+ if active_effects.get("left_hand", False): # Volume Fade
151
+ processed_data = AudioEffectsProcessor.apply_volume_fade(processed_data, "out")
152
+ effect_names.append("fade")
153
+
154
+ if active_effects.get("right_hand", False): # High Pass Filter
155
+ processed_data = AudioEffectsProcessor.apply_high_pass_filter(processed_data, samplerate)
156
+ effect_names.append("hpf")
157
+
158
+ if active_effects.get("left_leg", False): # Reverb
159
+ processed_data = AudioEffectsProcessor.apply_reverb(processed_data, samplerate)
160
+ effect_names.append("rev")
161
+
162
+ if active_effects.get("right_leg", False): # Low Pass Filter
163
+ processed_data = AudioEffectsProcessor.apply_low_pass_filter(processed_data, samplerate)
164
+ effect_names.append("lpf")
165
+
166
+ if active_effects.get("tongue", False): # Bass Boost
167
+ processed_data = AudioEffectsProcessor.apply_bass_boost(processed_data, samplerate)
168
+ effect_names.append("bass")
169
+
170
+ # Create unique filename based on active effects
171
+ base_name = os.path.splitext(audio_file)[0]
172
+ effects_suffix = "_".join(effect_names) if effect_names else "clean"
173
+ processed_file = f"{base_name}_fx_{effects_suffix}.wav"
174
+
175
+ # Save processed audio
176
+ # sf.write(processed_file, processed_data, samplerate)
177
+ print(f"πŸŽ›οΈ Audio processed with effects: {effects_suffix} (FILE SAVING DISABLED)")
178
+
179
+ # Return absolute path (using original file since saving is disabled)
180
+ return os.path.abspath(audio_file)
181
+
182
+ except Exception as e:
183
+ print(f"Audio processing failed: {e}")
184
+ return os.path.abspath(audio_file) if audio_file else None
185
+
186
+ class SoundManager:
187
+ """
188
+ Manages cyclic sound composition for motor imagery classification.
189
+ Supports full-cycle composition with user-customizable movement-sound mappings.
190
+ """
191
+
192
+ def __init__(self, sound_dir: str = "sounds", include_neutral_in_cycle: bool = False):
193
+ self.sound_dir = Path(sound_dir)
194
+ self.include_neutral_in_cycle = include_neutral_in_cycle
195
+
196
+ # Composition state
197
+ self.composition_layers = [] # All layers across all cycles
198
+ self.current_cycle = 0
199
+ self.current_step = 0 # Current step within cycle (0-5)
200
+ self.cycle_complete = False
201
+ self.completed_cycles = 0 # Track completed cycles for session management
202
+ self.max_cycles = 2 # Rehabilitation session limit
203
+
204
+ # DJ Effects phase management
205
+ self.current_phase = "building" # "building" or "dj_effects"
206
+ self.mixed_composition_file = None # Path to current mixed composition
207
+ self.active_effects = { # Track which effects are currently active
208
+ "left_hand": False, # Volume fade
209
+ "right_hand": False, # High-pass filter
210
+ "left_leg": False, # Reverb
211
+ "right_leg": False, # Low-pass filter
212
+ "tongue": False # Bass boost
213
+ }
214
+
215
+ # All possible movements (neutral is optional for composition)
216
+ self.all_movements = ["left_hand", "right_hand", "neutral", "left_leg", "tongue", "right_leg"]
217
+
218
+ # Active movements that contribute to composition (excluding neutral)
219
+ self.active_movements = ["left_hand", "right_hand", "left_leg", "tongue", "right_leg"]
220
+
221
+ # Current cycle's random movement sequence (shuffled each cycle)
222
+ self.current_movement_sequence = []
223
+ self.movements_completed = set() # Track which movements have been successfully completed
224
+ self._generate_new_sequence()
225
+
226
+ # User-customizable sound mapping (can be updated after each cycle)
227
+ self.current_sound_mapping = {
228
+ "left_hand": "1_SoundHelix-Song-6_(Bass).wav",
229
+ "right_hand": "1_SoundHelix-Song-6_(Drums).wav",
230
+ "neutral": None, # No sound for neutral/rest state
231
+ "left_leg": "1_SoundHelix-Song-6_(Other).wav",
232
+ "tongue": "1_SoundHelix-Song-6_(Vocals).wav",
233
+ "right_leg": "1_SoundHelix-Song-6_(Bass).wav" # Can be remapped
234
+ }
235
+
236
+ # Available sound files
237
+ self.available_sounds = [
238
+ "1_SoundHelix-Song-6_(Bass).wav",
239
+ "1_SoundHelix-Song-6_(Drums).wav",
240
+ "1_SoundHelix-Song-6_(Other).wav",
241
+ "1_SoundHelix-Song-6_(Vocals).wav"
242
+ ]
243
+
244
+ # Load sound files
245
+ self.loaded_sounds = {}
246
+ self._load_sound_files()
247
+
248
+ # Cycle statistics
249
+ self.cycle_stats = {
250
+ 'total_cycles': 0,
251
+ 'successful_classifications': 0,
252
+ 'total_attempts': 0
253
+ }
254
+
255
+ def _load_sound_files(self):
256
+ """Load all available sound files into memory."""
257
+ for class_name, filename in self.current_sound_mapping.items():
258
+ if filename is None:
259
+ continue
260
+
261
+ file_path = self.sound_dir / filename
262
+ if file_path.exists():
263
+ try:
264
+ data, sample_rate = sf.read(str(file_path))
265
+ self.loaded_sounds[class_name] = {
266
+ 'data': data,
267
+ 'sample_rate': sample_rate,
268
+ 'file_path': str(file_path)
269
+ }
270
+ print(f"Loaded sound for {class_name}: {filename}")
271
+ except Exception as e:
272
+ print(f"Error loading {filename}: {e}")
273
+ else:
274
+ print(f"Sound file not found: {file_path}")
275
+
276
+ def get_sound_for_class(self, class_name: str) -> Optional[Dict]:
277
+ """Get sound data for a specific motor imagery class."""
278
+ return self.loaded_sounds.get(class_name)
279
+
280
+ def _generate_new_sequence(self):
281
+ """Generate a new random sequence of movements for the current cycle."""
282
+ import random
283
+ # Choose which movements to include based on configuration
284
+ movements_for_cycle = self.all_movements.copy() if self.include_neutral_in_cycle else self.active_movements.copy()
285
+ random.shuffle(movements_for_cycle)
286
+ self.current_movement_sequence = movements_for_cycle
287
+ self.movements_completed = set()
288
+
289
+ cycle_size = len(movements_for_cycle)
290
+ print(f"🎯 New random sequence ({cycle_size} movements): {' β†’ '.join([m.replace('_', ' ').title() for m in self.current_movement_sequence])}")
291
+
292
+ def get_current_target_movement(self) -> str:
293
+ """Get the current movement the user should imagine."""
294
+ if self.current_step < len(self.current_movement_sequence):
295
+ return self.current_movement_sequence[self.current_step]
296
+ return "cycle_complete"
297
+
298
+ def get_next_random_movement(self) -> str:
299
+ """Get a random movement from those not yet completed in this cycle."""
300
+ # Use the same movement set as the current cycle
301
+ cycle_movements = self.all_movements if self.include_neutral_in_cycle else self.active_movements
302
+ remaining_movements = [m for m in cycle_movements if m not in self.movements_completed]
303
+ if not remaining_movements:
304
+ return "cycle_complete"
305
+
306
+ import random
307
+ return random.choice(remaining_movements)
308
+
309
+ def process_classification(self, predicted_class: str, confidence: float, threshold: float = 0.7) -> Dict:
310
+ """
311
+ Process a classification result in the context of the current cycle.
312
+ Uses random movement prompting - user can choose any movement at any time.
313
+
314
+ Args:
315
+ predicted_class: The predicted motor imagery class
316
+ confidence: Confidence score
317
+ threshold: Minimum confidence to add sound layer
318
+
319
+ Returns:
320
+ Dictionary with processing results and next action
321
+ """
322
+ self.cycle_stats['total_attempts'] += 1
323
+
324
+ # Check if this movement was already completed in this cycle
325
+ already_completed = predicted_class in self.movements_completed
326
+
327
+ result = {
328
+ 'predicted_class': predicted_class,
329
+ 'confidence': confidence,
330
+ 'above_threshold': confidence >= threshold,
331
+ 'already_completed': already_completed,
332
+ 'sound_added': False,
333
+ 'cycle_complete': False,
334
+ 'next_action': 'continue',
335
+ 'movements_remaining': len(self.current_movement_sequence) - len(self.movements_completed),
336
+ 'movements_completed': list(self.movements_completed)
337
+ }
338
+
339
+ # Check if prediction is above threshold and not already completed
340
+ if result['above_threshold'] and not already_completed:
341
+ # Get sound file for this movement
342
+ sound_file = self.current_sound_mapping.get(predicted_class)
343
+
344
+ if sound_file is None:
345
+ # No sound for this movement (e.g., neutral), but still count as completed
346
+ self.movements_completed.add(predicted_class)
347
+ result['sound_added'] = False # No sound, but movement completed
348
+ self.cycle_stats['successful_classifications'] += 1
349
+ cycle_size = len(self.current_movement_sequence)
350
+ print(f"βœ“ Cycle {self.current_cycle+1}: {predicted_class} completed (no sound) ({len(self.movements_completed)}/{cycle_size} complete)")
351
+ elif predicted_class in self.loaded_sounds:
352
+ # Add sound layer
353
+ layer_info = {
354
+ 'cycle': self.current_cycle,
355
+ 'step': len(self.movements_completed), # Step based on number completed
356
+ 'movement': predicted_class,
357
+ 'sound_file': sound_file,
358
+ 'confidence': confidence,
359
+ 'timestamp': time.time(),
360
+ 'sound_data': self.loaded_sounds[predicted_class]
361
+ }
362
+ self.composition_layers.append(layer_info)
363
+ self.movements_completed.add(predicted_class)
364
+ result['sound_added'] = True
365
+
366
+ print(f"DEBUG: Added layer {len(self.composition_layers)}, total layers now: {len(self.composition_layers)}")
367
+ print(f"DEBUG: Composition layers: {[layer['movement'] for layer in self.composition_layers]}")
368
+
369
+ # Return individual sound file for this classification
370
+ sound_path = os.path.join(self.sound_dir, sound_file)
371
+ result['audio_file'] = sound_path if os.path.exists(sound_path) else None
372
+
373
+ # Also create mixed composition for potential saving (but don't return it)
374
+ mixed_file = self.get_current_mixed_composition()
375
+ result['mixed_composition'] = mixed_file
376
+
377
+ self.cycle_stats['successful_classifications'] += 1
378
+ cycle_size = len(self.current_movement_sequence)
379
+ print(f"βœ“ Cycle {self.current_cycle+1}: Added {sound_file} for {predicted_class} ({len(self.movements_completed)}/{cycle_size} complete)")
380
+
381
+ # Check if cycle is complete (all movements in current sequence completed)
382
+ cycle_movements = self.all_movements if self.include_neutral_in_cycle else self.active_movements
383
+ if len(self.movements_completed) >= len(cycle_movements):
384
+ result['cycle_complete'] = True
385
+ result['next_action'] = 'remap_sounds'
386
+ self.cycle_complete = True
387
+ cycle_size = len(cycle_movements)
388
+ print(f"🎡 Cycle {self.current_cycle+1} complete! All {cycle_size} movements successfully classified!")
389
+
390
+ return result
391
+
392
+ def start_new_cycle(self, new_sound_mapping: Dict[str, str] = None):
393
+ """Start a new composition cycle with optional new sound mapping."""
394
+ if new_sound_mapping:
395
+ self.current_sound_mapping.update(new_sound_mapping)
396
+ print(f"Updated sound mapping for cycle {self.current_cycle+2}")
397
+
398
+ self.current_cycle += 1
399
+ self.current_step = 0
400
+ self.cycle_complete = False
401
+ self.cycle_stats['total_cycles'] += 1
402
+
403
+ # Generate new random sequence for this cycle
404
+ self._generate_new_sequence()
405
+
406
+ print(f"πŸ”„ Starting Cycle {self.current_cycle+1}")
407
+ if self.current_cycle == 1:
408
+ print("πŸ’ͺ Let's create your first brain-music composition!")
409
+ elif self.current_cycle == 2:
410
+ print("πŸŽ‰ Great progress! Let's create your second composition!")
411
+ else:
412
+ print("🎯 Imagine ANY movement - you can choose the order!")
413
+
414
+ def should_end_session(self) -> bool:
415
+ """Check if the rehabilitation session should end after max cycles."""
416
+ return self.completed_cycles >= self.max_cycles
417
+
418
+ def complete_current_cycle(self):
419
+ """Mark current cycle as complete and track progress."""
420
+ self.completed_cycles += 1
421
+ print(f"βœ… Cycle {self.current_cycle} completed! ({self.completed_cycles}/{self.max_cycles} compositions finished)")
422
+
423
+ def transition_to_dj_phase(self):
424
+ """Transition from building phase to DJ effects phase."""
425
+ if len(self.movements_completed) >= 5: # All movements completed
426
+ self.current_phase = "dj_effects"
427
+ # Create mixed composition for DJ effects
428
+ self._create_mixed_composition()
429
+ print("🎡 Composition Complete! Transitioning to DJ Effects Phase...")
430
+ print("🎧 You are now the DJ! Use movements to control effects:")
431
+ print(" πŸ‘ˆ Left Hand: Volume Fade")
432
+ print(" πŸ‘‰ Right Hand: High Pass Filter")
433
+ print(" 🦡 Left Leg: Reverb Effect")
434
+ print(" 🦡 Right Leg: Low Pass Filter")
435
+ print(" πŸ‘… Tongue: Bass Boost")
436
+ return True
437
+ return False
438
+
439
+ def _create_mixed_composition(self):
440
+ """Create a mixed audio file from all completed layers."""
441
+ try:
442
+ import hashlib
443
+ movement_hash = hashlib.md5(str(sorted(self.movements_completed)).encode()).hexdigest()[:8]
444
+ self.mixed_composition_file = os.path.abspath(f"mixed_composition_{movement_hash}.wav")
445
+
446
+ # FILE SAVING DISABLED: Use existing base audio file instead
447
+ # Try to use the first available completed movement's audio file
448
+ for movement in self.movements_completed:
449
+ if movement in self.current_sound_mapping and self.current_sound_mapping[movement] is not None:
450
+ sound_file = os.path.join(self.sound_dir, self.current_sound_mapping[movement])
451
+ if os.path.exists(sound_file):
452
+ self.mixed_composition_file = os.path.abspath(sound_file)
453
+ print(f"πŸ“€ Using existing audio as mixed composition: {self.mixed_composition_file} (FILE SAVING DISABLED)")
454
+ return
455
+
456
+ # If file already exists, use it
457
+ if os.path.exists(self.mixed_composition_file):
458
+ print(f"πŸ“€ Using existing mixed composition: {self.mixed_composition_file}")
459
+ return
460
+
461
+ # Create actual mixed composition by layering completed sounds
462
+ mixed_data = None
463
+ sample_rate = 44100 # Default sample rate
464
+
465
+ for movement in self.movements_completed:
466
+ if movement in self.current_sound_mapping and self.current_sound_mapping[movement] is not None:
467
+ sound_file = os.path.join(self.sound_dir, self.current_sound_mapping[movement])
468
+ if os.path.exists(sound_file):
469
+ try:
470
+ data, sr = sf.read(sound_file)
471
+ sample_rate = sr
472
+
473
+ # Convert stereo to mono
474
+ if len(data.shape) > 1:
475
+ data = np.mean(data, axis=1)
476
+
477
+ # Initialize or add to mixed data
478
+ if mixed_data is None:
479
+ mixed_data = data * 0.8 # Reduce volume to prevent clipping
480
+ else:
481
+ # Ensure same length by padding shorter audio
482
+ if len(data) > len(mixed_data):
483
+ mixed_data = np.pad(mixed_data, (0, len(data) - len(mixed_data)))
484
+ elif len(mixed_data) > len(data):
485
+ data = np.pad(data, (0, len(mixed_data) - len(data)))
486
+
487
+ # Mix the audio (layer them)
488
+ mixed_data += data * 0.8
489
+ except Exception as e:
490
+ print(f"Error mixing {sound_file}: {e}")
491
+
492
+ # Save mixed composition or create silent fallback
493
+ if mixed_data is not None:
494
+ # Normalize to prevent clipping
495
+ max_val = np.max(np.abs(mixed_data))
496
+ if max_val > 0.95:
497
+ mixed_data = mixed_data * 0.95 / max_val
498
+
499
+ # sf.write(self.mixed_composition_file, mixed_data, sample_rate)
500
+ print(f"πŸ“€ Mixed composition created: {self.mixed_composition_file} (FILE SAVING DISABLED)")
501
+ else:
502
+ # Create silent fallback file
503
+ silent_data = np.zeros(int(sample_rate * 2)) # 2 seconds of silence
504
+ # sf.write(self.mixed_composition_file, silent_data, sample_rate)
505
+ print(f"πŸ“€ Silent fallback composition created: {self.mixed_composition_file} (FILE SAVING DISABLED)")
506
+
507
+ except Exception as e:
508
+ print(f"Error creating mixed composition: {e}")
509
+ # Create minimal fallback file with actual content
510
+ self.mixed_composition_file = os.path.abspath("mixed_composition_fallback.wav")
511
+ try:
512
+ # Create a short silent audio file as fallback
513
+ sample_rate = 44100
514
+ silent_data = np.zeros(int(sample_rate * 2)) # 2 seconds of silence
515
+ # sf.write(self.mixed_composition_file, silent_data, sample_rate)
516
+ print(f"πŸ“€ Silent fallback composition created: {self.mixed_composition_file} (FILE SAVING DISABLED)")
517
+ except Exception as fallback_error:
518
+ print(f"Failed to create fallback file: {fallback_error}")
519
+ self.mixed_composition_file = None
520
+
521
+ def toggle_dj_effect(self, movement: str) -> dict:
522
+ """Toggle a DJ effect for the given movement and process audio."""
523
+ if self.current_phase != "dj_effects":
524
+ return {"effect_applied": False, "message": "Not in DJ effects phase"}
525
+
526
+ if movement not in self.active_effects:
527
+ return {"effect_applied": False, "message": f"Unknown movement: {movement}"}
528
+
529
+ # Toggle the effect
530
+ self.active_effects[movement] = not self.active_effects[movement]
531
+ effect_status = "ON" if self.active_effects[movement] else "OFF"
532
+
533
+ effect_names = {
534
+ "left_hand": "Volume Fade",
535
+ "right_hand": "High Pass Filter",
536
+ "left_leg": "Reverb Effect",
537
+ "right_leg": "Low Pass Filter",
538
+ "tongue": "Bass Boost"
539
+ }
540
+
541
+ effect_name = effect_names.get(movement, movement)
542
+ print(f"πŸŽ›οΈ {effect_name}: {effect_status}")
543
+
544
+ # Process audio with current active effects
545
+ if self.mixed_composition_file and os.path.exists(self.mixed_composition_file):
546
+ processed_file = AudioEffectsProcessor.process_with_effects(
547
+ self.mixed_composition_file,
548
+ self.active_effects
549
+ )
550
+ else:
551
+ # If no mixed composition exists, create one from current sounds
552
+ self._create_mixed_composition()
553
+ # Only process if we successfully created a mixed composition
554
+ if self.mixed_composition_file and os.path.exists(self.mixed_composition_file):
555
+ processed_file = AudioEffectsProcessor.process_with_effects(
556
+ self.mixed_composition_file,
557
+ self.active_effects
558
+ )
559
+ else:
560
+ print("Failed to create mixed composition, using fallback")
561
+ processed_file = self.mixed_composition_file
562
+
563
+ return {
564
+ "effect_applied": True,
565
+ "effect_name": effect_name,
566
+ "effect_status": effect_status,
567
+ "mixed_composition": processed_file
568
+ }
569
+
570
+ def get_cycle_success_rate(self) -> float:
571
+ """Get success rate for current cycle."""
572
+ if self.cycle_stats['total_attempts'] == 0:
573
+ return 0.0
574
+ return self.cycle_stats['successful_classifications'] / self.cycle_stats['total_attempts']
575
+
576
+ def clear_composition(self):
577
+ """Clear all composition layers."""
578
+ self.composition_layers = []
579
+ self.current_composition = None
580
+ print("Composition cleared")
581
+
582
+ def _get_audio_file_for_gradio(self, movement):
583
+ """Get the actual audio file path for Gradio audio output"""
584
+ if movement in self.current_sound_mapping:
585
+ sound_file = self.current_sound_mapping[movement]
586
+ audio_path = os.path.join(self.sound_dir, sound_file)
587
+ if os.path.exists(audio_path):
588
+ return audio_path
589
+ return None
590
+
591
+ def _mix_audio_files(self, audio_files: List[str]) -> str:
592
+ """Mix multiple audio files into a single layered audio file."""
593
+ if not audio_files:
594
+ return None
595
+
596
+ # Load all audio files and convert to mono
597
+ audio_data_list = []
598
+ sample_rate = None
599
+ max_length = 0
600
+
601
+ for file_path in audio_files:
602
+ if os.path.exists(file_path):
603
+ data, sr = sf.read(file_path)
604
+
605
+ # Ensure data is numpy array and handle shape properly
606
+ data = np.asarray(data)
607
+
608
+ # Convert to mono if stereo/multi-channel
609
+ if data.ndim > 1:
610
+ if data.shape[1] > 1: # Multi-channel
611
+ data = np.mean(data, axis=1)
612
+ else: # Single channel with extra dimension
613
+ data = data.flatten()
614
+
615
+ # Ensure final data is 1D
616
+ data = np.asarray(data).flatten()
617
+
618
+ audio_data_list.append(data)
619
+ if sample_rate is None:
620
+ sample_rate = sr
621
+ max_length = max(max_length, len(data))
622
+
623
+ if not audio_data_list:
624
+ return None
625
+
626
+ # Pad all audio to same length and mix
627
+ mixed_audio = np.zeros(max_length)
628
+ for data in audio_data_list:
629
+ # Ensure data is 1D (flatten any remaining multi-dimensional arrays)
630
+ data_flat = np.asarray(data).flatten()
631
+
632
+ # Pad or truncate to match max_length
633
+ if len(data_flat) < max_length:
634
+ padded = np.pad(data_flat, (0, max_length - len(data_flat)), 'constant')
635
+ else:
636
+ padded = data_flat[:max_length]
637
+
638
+ # Add to mix (normalize to prevent clipping)
639
+ mixed_audio += padded / len(audio_data_list)
640
+
641
+ # Create a unique identifier for this composition
642
+ import hashlib
643
+ composition_hash = hashlib.md5(''.join(sorted(audio_files)).encode()).hexdigest()[:8]
644
+
645
+ # FILE SAVING DISABLED: Return first available audio file instead of creating mixed composition
646
+ mixed_audio_path = os.path.join(self.sound_dir, f"mixed_composition_{composition_hash}.wav")
647
+
648
+ # Since file saving is disabled, use the first available audio file from the list
649
+ if audio_files:
650
+ # Use the first audio file as the "mixed" composition
651
+ first_audio_file = audio_files[0]  # entries are already full paths (see get_current_mixed_composition)
652
+ if os.path.exists(first_audio_file):
653
+ print(f"DEBUG: Using first audio file as mixed composition: {os.path.basename(first_audio_file)} (FILE SAVING DISABLED)")
654
+ return first_audio_file
655
+
656
+ # Fallback: if mixed composition file already exists, use it
657
+ if os.path.exists(mixed_audio_path):
658
+ print(f"DEBUG: Reusing existing mixed audio file: {os.path.basename(mixed_audio_path)}")
659
+ return mixed_audio_path
660
+
661
+ # Final fallback: return first available base audio file
662
+ base_files = ["1_SoundHelix-Song-6_(Vocals).wav", "1_SoundHelix-Song-6_(Drums).wav", "1_SoundHelix-Song-6_(Bass).wav", "1_SoundHelix-Song-6_(Other).wav"]
663
+ for base_file in base_files:
664
+ base_path = os.path.join(self.sound_dir, base_file)
665
+ if os.path.exists(base_path):
666
+ print(f"DEBUG: Using base audio file as fallback: {base_file} (FILE SAVING DISABLED)")
667
+ return base_path
668
+
669
+ return mixed_audio_path
670
+
671
+ def get_current_mixed_composition(self) -> str:
672
+ """Get the current composition as a mixed audio file."""
673
+ # Get all audio files from current composition layers
674
+ audio_files = []
675
+ for layer in self.composition_layers:
676
+ movement = layer.get('movement')
677
+ if movement and movement in self.current_sound_mapping:
678
+ sound_file = self.current_sound_mapping[movement]
679
+ audio_path = os.path.join(self.sound_dir, sound_file)
680
+ if os.path.exists(audio_path):
681
+ audio_files.append(audio_path)
682
+
683
+ # Debug: print current composition state
684
+ print(f"DEBUG: Current composition has {len(self.composition_layers)} layers: {[layer.get('movement') for layer in self.composition_layers]}")
685
+ print(f"DEBUG: Audio files to mix: {[os.path.basename(f) for f in audio_files]}")
686
+
687
+ return self._mix_audio_files(audio_files)
688
+
689
+ def get_composition_info(self) -> Dict:
690
+ """Get comprehensive information about current composition."""
691
+ layers_by_cycle = {}
692
+ for layer in self.composition_layers:
693
+ cycle = layer['cycle']
694
+ if cycle not in layers_by_cycle:
695
+ layers_by_cycle[cycle] = []
696
+ layers_by_cycle[cycle].append({
697
+ 'step': layer['step'],
698
+ 'movement': layer['movement'],
699
+ 'sound_file': layer['sound_file'],
700
+ 'confidence': layer['confidence']
701
+ })
702
+
703
+ # Also track completed movements without sounds (like neutral)
704
+ completed_without_sound = [mov for mov in self.movements_completed
705
+ if self.current_sound_mapping.get(mov) is None]
706
+
707
+ return {
708
+ 'total_cycles': self.current_cycle + (1 if self.composition_layers else 0),
709
+ 'current_cycle': self.current_cycle + 1,
710
+ 'current_step': len(self.movements_completed) + 1, # Current step in cycle
711
+ 'target_movement': self.get_current_target_movement(),
712
+ 'movements_completed': len(self.movements_completed),
713
+ 'movements_remaining': len(self.current_movement_sequence) - len(self.movements_completed),
714
+ 'cycle_complete': self.cycle_complete,
715
+ 'total_layers': len(self.composition_layers),
716
+ 'completed_movements': list(self.movements_completed),
717
+ 'completed_without_sound': completed_without_sound,
718
+ 'layers_by_cycle': layers_by_cycle,
719
+ 'current_mapping': self.current_sound_mapping.copy(),
720
+ 'success_rate': self.get_cycle_success_rate(),
721
+ 'stats': self.cycle_stats.copy()
722
+ }
723
+
724
+ def get_sound_mapping_options(self) -> Dict:
725
+ """Get available sound mapping options for user customization."""
726
+ return {
727
+ 'movements': self.all_movements, # All possible movements for mapping
728
+ 'available_sounds': self.available_sounds,
729
+ 'current_mapping': self.current_sound_mapping.copy()
730
+ }
731
+
732
+ def update_sound_mapping(self, new_mapping: Dict[str, str]) -> bool:
733
+ """Update the sound mapping for future cycles."""
734
+ try:
735
+ # Validate that all sounds exist
736
+ for movement, sound_file in new_mapping.items():
737
+ if sound_file not in self.available_sounds:
738
+ print(f"Warning: Sound file {sound_file} not available")
739
+ return False
740
+
741
+ self.current_sound_mapping.update(new_mapping)
742
+ print("βœ“ Sound mapping updated successfully")
743
+ return True
744
+ except Exception as e:
745
+ print(f"Error updating sound mapping: {e}")
746
+ return False
747
+
748
+ def reset_composition(self):
749
+ """Reset the entire composition to start fresh."""
750
+ self.composition_layers = []
751
+ self.current_cycle = 0
752
+ self.current_step = 0
753
+ self.cycle_complete = False
754
+ self.cycle_stats = {
755
+ 'total_cycles': 0,
756
+ 'successful_classifications': 0,
757
+ 'total_attempts': 0
758
+ }
759
+ print("πŸ”„ Composition reset - ready for new session")
760
+
761
+ def save_composition(self, output_path: str = "composition.wav"):
762
+ """Save the current composition as a mixed audio file."""
763
+ if not self.composition_layers:
764
+ print("No layers to save")
765
+ return False
766
+
767
+ try:
768
+ # For simplicity, we'll just save the latest layer
769
+ # In a real implementation, you'd mix multiple audio tracks
770
+ latest_layer = self.composition_layers[-1]
771
+ sound_data = latest_layer['sound_data']
772
+
773
+ sf.write(output_path, sound_data['data'], sound_data['sample_rate'])
774
+ print(f"Composition saved to {output_path}")
775
+ return True
776
+
777
+ except Exception as e:
778
+ print(f"Error saving composition: {e}")
779
+ return False
780
+
781
+ def get_available_sounds(self) -> List[str]:
782
+ """Get list of available sound classes."""
783
+ return list(self.loaded_sounds.keys())
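Below is a minimal, hypothetical sketch of how the composition and DJ-effects API defined above might be driven from the app layer. The name `composer` is a stand-in for the sound_library.py class instance, whose class name and constructor are defined earlier in the file and are not shown in this hunk.

```python
# Hypothetical driver code; "composer" stands for the sound_library.py
# composition object created elsewhere in the app (constructor not shown here).
info = composer.get_composition_info()
print(f"Cycle {info['current_cycle']}, step {info['current_step']}, "
      f"success rate {info['success_rate']:.0%}")

# Toggle a DJ effect once the dj_effects phase is active
# (left_leg is mapped to the Reverb Effect above).
result = composer.toggle_dj_effect("left_leg")
if result["effect_applied"]:
    print(result["effect_name"], result["effect_status"])
    processed_path = result["mixed_composition"]  # path handed to the Gradio audio player

# Or retrieve the currently layered mix directly.
mixed_path = composer.get_current_mixed_composition()
```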
source/eeg_motor_imagery.py ADDED
@@ -0,0 +1,142 @@
1
+ """
2
+ EEG Motor Imagery Classification with Shallow ConvNet
3
+ -----------------------------------------------------
4
+ This script trains and evaluates a ShallowFBCSPNet model
5
+ on motor imagery EEG data stored in .mat files.
6
+ """
7
+
8
+ # === 1. Imports ===
9
+ import scipy.io
10
+ import numpy as np
11
+ import mne
12
+ import torch
13
+ from torch.utils.data import TensorDataset, DataLoader
14
+ from sklearn.model_selection import train_test_split
15
+ from braindecode.models import ShallowFBCSPNet
16
+ import torch.nn as nn
17
+ import pandas as pd
18
+
19
+ # === 2. Data Loading and Epoching ===
20
+ # Load .mat EEG files, create Raw objects, extract events, and epoch the data
21
+ files = [
22
+ "../data/raw_mat/HaLTSubjectA1602236StLRHandLegTongue.mat",
23
+ "../data/raw_mat/HaLTSubjectA1603086StLRHandLegTongue.mat",
24
+
25
+ ]
26
+
27
+ all_epochs = []
28
+ for f in files:
29
+ mat = scipy.io.loadmat(f)
30
+ content = mat['o'][0, 0]
31
+
32
+ labels = content[4].flatten()
33
+ signals = content[5]
34
+ chan_names_raw = content[6]
35
+ channels = [ch[0][0] for ch in chan_names_raw]
36
+ fs = int(content[2][0, 0])
37
+
38
+ df = pd.DataFrame(signals, columns=channels).drop(columns=["X5"], errors="ignore")
39
+ eeg = df.values.T
40
+ ch_names = df.columns.tolist()
41
+
42
+ info = mne.create_info(ch_names=ch_names, sfreq=fs, ch_types="eeg")
43
+ raw = mne.io.RawArray(eeg, info)
44
+
45
+ # Create events
46
+ onsets = np.where((labels[1:] != 0) & (labels[:-1] == 0))[0] + 1
47
+ event_codes = labels[onsets].astype(int)
48
+ events = np.c_[onsets, np.zeros_like(onsets), event_codes]
49
+
50
+ # Keep only relevant events
51
+ mask = np.isin(events[:, 2], np.arange(1, 7))
52
+ events = events[mask]
53
+
54
+ event_id = {
55
+ "left_hand": 1,
56
+ "right_hand": 2,
57
+ "neutral": 3,
58
+ "left_leg": 4,
59
+ "tongue": 5,
60
+ "right_leg": 6,
61
+ }
62
+
63
+ # Epoching
64
+ epochs = mne.Epochs(
65
+ raw,
66
+ events=events,
67
+ event_id=event_id,
68
+ tmin=0,
69
+ tmax=1.5,
70
+ baseline=None,
71
+ preload=True,
72
+ )
73
+
74
+ all_epochs.append(epochs)
75
+
76
+ epochs_all = mne.concatenate_epochs(all_epochs)
77
+
78
+ # === 3. Minimal Preprocessing + Train/Validation Split ===
79
+ # Convert epochs to numpy arrays (N, C, T) and split into train/val sets
80
+ X = epochs_all.get_data().astype("float32")
81
+ y = (epochs_all.events[:, -1] - 1).astype("int64") # classes 0..5
82
+
83
+ X_train, X_val, y_train, y_val = train_test_split(
84
+ X, y, test_size=0.2, random_state=42, stratify=y
85
+ )
86
+
87
+ # === 4. Torch DataLoaders ===
88
+ train_ds = TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
89
+ val_ds = TensorDataset(torch.from_numpy(X_val), torch.from_numpy(y_val))
90
+
91
+ train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
92
+ val_loader = DataLoader(val_ds, batch_size=32)
93
+
94
+ # === 5. Model – Shallow ConvNet ===
95
+ # Reference: Schirrmeister et al. (2017)
96
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
97
+
98
+ model = ShallowFBCSPNet(
99
+ n_chans=X.shape[1],
100
+ n_outputs=len(np.unique(y)),
101
+ n_times=X.shape[2],
102
+ final_conv_length="auto"
103
+ ).to(device)
104
+
105
+ # Load pretrained weights
106
+ state_dict = torch.load("shallow_weights_all.pth", map_location=device)
107
+ model.load_state_dict(state_dict)
108
+
109
+ # === 6. Training ===
110
+ criterion = nn.CrossEntropyLoss()
111
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
112
+
113
+ for epoch in range(1, 21):
114
+ # Training
115
+ model.train()
116
+ correct, total = 0, 0
117
+ for xb, yb in train_loader:
118
+ xb, yb = xb.to(device), yb.to(device)
119
+ optimizer.zero_grad()
120
+ out = model(xb)
121
+ loss = criterion(out, yb)
122
+ loss.backward()
123
+ optimizer.step()
124
+
125
+ pred = out.argmax(dim=1)
126
+ correct += (pred == yb).sum().item()
127
+ total += yb.size(0)
128
+ train_acc = correct / total
129
+
130
+ # Validation
131
+ model.eval()
132
+ correct, total = 0, 0
133
+ with torch.no_grad():
134
+ for xb, yb in val_loader:
135
+ xb, yb = xb.to(device), yb.to(device)
136
+ out = model(xb)
137
+ pred = out.argmax(dim=1)
138
+ correct += (pred == yb).sum().item()
139
+ total += yb.size(0)
140
+ val_acc = correct / total
141
+
142
+ print(f"Epoch {epoch:02d} | Train acc: {train_acc:.3f} | Val acc: {val_acc:.3f}")
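For reference, a minimal inference sketch using the model trained above; it reuses `X_val`, `device`, and the class ordering implied by `event_id` (labels 1–6 shifted to 0–5). This is an illustrative addition, not part of the original script.

```python
# Minimal single-epoch inference sketch (assumes the variables defined above).
class_names = ["left_hand", "right_hand", "neutral", "left_leg", "tongue", "right_leg"]

model.eval()
with torch.no_grad():
    epoch_tensor = torch.from_numpy(X_val[:1]).to(device)  # shape (1, n_chans, n_times)
    logits = model(epoch_tensor)                            # raw class scores
    probs = torch.softmax(logits, dim=1)[0]
    pred = int(probs.argmax())

print(f"Predicted: {class_names[pred]} ({probs[pred].item():.1%})")
```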
utils.py ADDED
@@ -0,0 +1,226 @@
1
+ """
2
+ Utility functions for the EEG Motor Imagery Music Composer
3
+ """
4
+
5
+ import numpy as np
6
+ import logging
7
+ import time
8
+ from pathlib import Path
9
+ from typing import Dict, List, Optional, Tuple
10
+ import json
11
+
12
+ from config import LOG_LEVEL, LOG_FILE, CLASS_NAMES, CLASS_DESCRIPTIONS
13
+
14
+ def setup_logging():
15
+ """Set up logging configuration."""
16
+ LOG_FILE.parent.mkdir(parents=True, exist_ok=True)
17
+
18
+ logging.basicConfig(
19
+ level=getattr(logging, LOG_LEVEL),
20
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
21
+ handlers=[
22
+ logging.FileHandler(LOG_FILE),
23
+ logging.StreamHandler()
24
+ ]
25
+ )
26
+
27
+ return logging.getLogger(__name__)
28
+
29
+ def validate_eeg_data(data: np.ndarray) -> bool:
30
+ """
31
+ Validate EEG data format and dimensions.
32
+
33
+ Args:
34
+ data: EEG data array
35
+
36
+ Returns:
37
+ bool: True if data is valid, False otherwise
38
+ """
39
+ if not isinstance(data, np.ndarray):
40
+ return False
41
+
42
+ if data.ndim not in [2, 3]:
43
+ return False
44
+
45
+ if data.ndim == 2 and data.shape[0] == 0:
46
+ return False
47
+
48
+ if data.ndim == 3 and (data.shape[0] == 0 or data.shape[1] == 0):
49
+ return False
50
+
51
+ return True
52
+
53
+ def format_confidence(confidence: float) -> str:
54
+ """Format confidence score as percentage string."""
55
+ return f"{confidence * 100:.1f}%"
56
+
57
+ def format_timestamp(timestamp: float) -> str:
58
+ """Format timestamp for display."""
59
+ return time.strftime("%H:%M:%S", time.localtime(timestamp))
60
+
61
+ def get_class_emoji(class_name: str) -> str:
62
+ """Get emoji representation for motor imagery class."""
63
+ emoji_map = {
64
+ "left_hand": "🀚",
65
+ "right_hand": "🀚",
66
+ "neutral": "😐",
67
+ "left_leg": "🦡",
68
+ "tongue": "πŸ‘…",
69
+ "right_leg": "🦡"
70
+ }
71
+ return emoji_map.get(class_name, "❓")
72
+
73
+ def create_classification_summary(
74
+ predicted_class: str,
75
+ confidence: float,
76
+ probabilities: Dict[str, float],
77
+ timestamp: Optional[float] = None
78
+ ) -> Dict:
79
+ """
80
+ Create a formatted summary of classification results.
81
+
82
+ Args:
83
+ predicted_class: Predicted motor imagery class
84
+ confidence: Confidence score (0-1)
85
+ probabilities: Dictionary of class probabilities
86
+ timestamp: Optional timestamp
87
+
88
+ Returns:
89
+ Dict: Formatted classification summary
90
+ """
91
+ if timestamp is None:
92
+ timestamp = time.time()
93
+
94
+ return {
95
+ "predicted_class": predicted_class,
96
+ "confidence": confidence,
97
+ "confidence_percent": format_confidence(confidence),
98
+ "probabilities": probabilities,
99
+ "timestamp": timestamp,
100
+ "formatted_time": format_timestamp(timestamp),
101
+ "emoji": get_class_emoji(predicted_class),
102
+ "description": CLASS_DESCRIPTIONS.get(predicted_class, predicted_class)
103
+ }
104
+
105
+ def save_session_data(session_data: Dict, filepath: str) -> bool:
106
+ """
107
+ Save session data to JSON file.
108
+
109
+ Args:
110
+ session_data: Dictionary containing session information
111
+ filepath: Path to save the file
112
+
113
+ Returns:
114
+ bool: True if successful, False otherwise
115
+ """
116
+ try:
117
+ with open(filepath, 'w') as f:
118
+ json.dump(session_data, f, indent=2, default=str)
119
+ return True
120
+ except Exception as e:
121
+ logging.error(f"Error saving session data: {e}")
122
+ return False
123
+
124
+ def load_session_data(filepath: str) -> Optional[Dict]:
125
+ """
126
+ Load session data from JSON file.
127
+
128
+ Args:
129
+ filepath: Path to the JSON file
130
+
131
+ Returns:
132
+ Dict or None: Session data if successful, None otherwise
133
+ """
134
+ try:
135
+ with open(filepath, 'r') as f:
136
+ return json.load(f)
137
+ except Exception as e:
138
+ logging.error(f"Error loading session data: {e}")
139
+ return None
140
+
141
+ def calculate_classification_statistics(history: List[Dict]) -> Dict:
142
+ """
143
+ Calculate statistics from classification history.
144
+
145
+ Args:
146
+ history: List of classification results
147
+
148
+ Returns:
149
+ Dict: Statistics summary
150
+ """
151
+ if not history:
152
+ return {"total": 0, "class_counts": {}, "average_confidence": 0.0}
153
+
154
+ class_counts = {}
155
+ total_confidence = 0.0
156
+
157
+ for item in history:
158
+ class_name = item.get("predicted_class", "unknown")
159
+ confidence = item.get("confidence", 0.0)
160
+
161
+ class_counts[class_name] = class_counts.get(class_name, 0) + 1
162
+ total_confidence += confidence
163
+
164
+ return {
165
+ "total": len(history),
166
+ "class_counts": class_counts,
167
+ "average_confidence": total_confidence / len(history),
168
+ "most_common_class": max(class_counts, key=class_counts.get) if class_counts else None
169
+ }
170
+
171
+ def create_progress_bar(value: float, max_value: float = 1.0, width: int = 20) -> str:
172
+ """
173
+ Create a text-based progress bar.
174
+
175
+ Args:
176
+ value: Current value
177
+ max_value: Maximum value
178
+ width: Width of progress bar in characters
179
+
180
+ Returns:
181
+ str: Progress bar string
182
+ """
183
+ percentage = min(value / max_value, 1.0)
184
+ filled = int(width * percentage)
185
+ bar = "β–ˆ" * filled + "β–‘" * (width - filled)
186
+ return f"[{bar}] {percentage * 100:.1f}%"
187
+
188
+ def validate_audio_file(file_path: str) -> bool:
189
+ """
190
+ Validate if an audio file exists and is readable.
191
+
192
+ Args:
193
+ file_path: Path to audio file
194
+
195
+ Returns:
196
+ bool: True if file is valid, False otherwise
197
+ """
198
+ path = Path(file_path)
199
+ if not path.exists():
200
+ return False
201
+
202
+ if not path.is_file():
203
+ return False
204
+
205
+ # Check file extension
206
+ valid_extensions = ['.wav', '.mp3', '.flac', '.ogg']
207
+ if path.suffix.lower() not in valid_extensions:
208
+ return False
209
+
210
+ return True
211
+
212
+ def generate_composition_filename(prefix: str = "composition") -> str:
213
+ """
214
+ Generate a unique filename for composition exports.
215
+
216
+ Args:
217
+ prefix: Filename prefix
218
+
219
+ Returns:
220
+ str: Unique filename with timestamp
221
+ """
222
+ timestamp = time.strftime("%Y%m%d_%H%M%S")
223
+ return f"{prefix}_{timestamp}.wav"
224
+
225
+ # Initialize logger when module is imported
226
+ logger = setup_logging()
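A short usage sketch for the helpers above, with illustrative inputs; the probability values are placeholders standing in for real classifier output.

```python
from utils import (create_classification_summary, create_progress_bar,
                   calculate_classification_statistics)

# Placeholder probabilities standing in for real classifier output.
probs = {"left_hand": 0.71, "right_hand": 0.12, "neutral": 0.05,
         "left_leg": 0.05, "tongue": 0.04, "right_leg": 0.03}

summary = create_classification_summary("left_hand", 0.71, probs)
print(summary["emoji"], summary["predicted_class"], summary["confidence_percent"])
print(create_progress_bar(summary["confidence"]))

stats = calculate_classification_statistics([summary])
print(stats["most_common_class"], stats["average_confidence"])
```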