sofieff committed on
Commit b906dc7 · 1 Parent(s): d7ffdfa

functioning app

Files changed (8)
  1. .gitignore +9 -7
  2. app.py +386 -1013
  3. classifier.py +24 -4
  4. config.py +1 -1
  5. data_processor.py +11 -5
  6. sound_library.py +0 -803
  7. sound_manager.py +237 -0
  8. source/eeg_motor_imagery.py +1 -1
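
Note: sound_library.py is deleted in this commit and replaced by the new sound_manager.py (+237 lines), which is not rendered in the diffs below. As orientation only, here is a minimal sketch of the interface the rewritten app.py appears to rely on, inferred purely from its call sites (current_sound_mapping, sound_dir, current_phase, active_effects, movements_completed, start_new_cycle, get_current_target_movement, process_classification, get_composition_info, AudioEffectsProcessor.process_layer_with_effects); every detail beyond those names is an assumption, not the contents of the actual file.

# Hypothetical sketch of sound_manager.py, inferred from app.py call sites; not the committed module.
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np


class AudioEffectsProcessor:
    """Applies per-movement DJ effects to a mono audio buffer (placeholder)."""

    @staticmethod
    def process_layer_with_effects(data: np.ndarray, sr: int, movement: str,
                                   active_effects: Dict[str, bool]) -> np.ndarray:
        # Stand-in for the real fade/filter/reverb processing.
        if active_effects.get(movement, False):
            return data * 0.8
        return data


class SoundManager:
    """Tracks the movement-to-sound mapping and composition/DJ session state."""

    MOVEMENTS = ["left_hand", "right_hand", "left_leg", "right_leg"]

    def __init__(self, sound_dir: str = "sounds") -> None:
        self.sound_dir = Path(sound_dir)
        self.current_sound_mapping: Dict[str, Optional[str]] = {m: None for m in self.MOVEMENTS}
        self.current_phase = "building"          # or "dj_effects"
        self.active_effects: Dict[str, bool] = {}
        self.movements_completed: set = set()
        self.composition_layers: List[Dict] = []

    def start_new_cycle(self) -> None:
        self.movements_completed.clear()
        self.composition_layers.clear()
        self.current_phase = "building"

    def get_current_target_movement(self) -> str:
        remaining = [m for m in self.MOVEMENTS if m not in self.movements_completed]
        return remaining[0] if remaining else "cycle_complete"

    def process_classification(self, movement: str, confidence: float,
                               threshold: float, force_add: bool = False) -> Dict:
        added = (force_add or confidence > threshold) and movement in self.MOVEMENTS
        if added:
            self.movements_completed.add(movement)
            self.composition_layers.append({"movement": movement, "confidence": confidence})
        return {"sound_added": added}

    def get_composition_info(self) -> Dict:
        return {"layers_by_cycle": {0: self.composition_layers}}
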
.gitignore CHANGED
@@ -47,13 +47,15 @@ Thumbs.db
  *.wav
  *.mp3
  *.pth
- mixed_composition_*.wav
- *_fx_*.wav
- app_old.py
- app_new.py
- test_mixed.wav
- silent.wav
+

  # Data files
  data/
- *.mat
+ *.mat
+
+ otherfiles/
+
+ app.log
+
+ source/
+ sounds/
app.py CHANGED
@@ -1,1098 +1,471 @@
1
  """
2
- EEG Motor Imagery Music Composer - Redesigned Interface
3
- =======================================================
4
- Brain-Computer Interface that creates music compositions based on imagined movements.
 
5
  """
6
 
7
- import gradio as gr
8
- import numpy as np
 
9
  import matplotlib.pyplot as plt
10
- import time
11
- import threading
12
  import os
13
- from typing import Dict, Tuple, Any, List
14
-
15
- # Import our custom modules
16
- from sound_library import SoundManager
17
  from data_processor import EEGDataProcessor
18
  from classifier import MotorImageryClassifier
19
- from config import DEMO_DATA_PATHS, CLASS_NAMES, CONFIDENCE_THRESHOLD
20
-
21
- def validate_data_setup() -> str:
22
- """Validate that required data files are available."""
23
- missing_files = []
24
-
25
- for subject_id, path in DEMO_DATA_PATHS.items():
26
- try:
27
- import os
28
- if not os.path.exists(path):
29
- missing_files.append(f"Subject {subject_id}: {path}")
30
- except Exception as e:
31
- missing_files.append(f"Subject {subject_id}: Error checking {path}")
32
-
33
- if missing_files:
34
- return f"❌ Missing data files:\n" + "\n".join(missing_files)
35
- return "✅ All data files found"
36
 
37
- # Global app state
38
  app_state = {
39
  'is_running': False,
40
  'demo_data': None,
41
  'demo_labels': None,
42
- 'classification_history': [],
43
  'composition_active': False,
44
  'auto_mode': False
45
  }
46
 
47
- # Initialize components
48
- print("🧠 EEG Motor Imagery Music Composer")
49
- print("=" * 50)
50
- print("Starting Gradio application...")
51
 
52
- try:
53
- sound_manager = SoundManager()
54
- data_processor = EEGDataProcessor()
55
- classifier = MotorImageryClassifier()
56
-
57
- # Load demo data
58
- import os
59
- existing_files = [f for f in DEMO_DATA_PATHS if os.path.exists(f)]
60
- if existing_files:
61
- app_state['demo_data'], app_state['demo_labels'] = data_processor.process_files(existing_files)
62
- else:
63
- app_state['demo_data'], app_state['demo_labels'] = None, None
64
-
65
- if app_state['demo_data'] is not None:
66
- # Initialize classifier with proper dimensions
67
- classifier.load_model(n_chans=app_state['demo_data'].shape[1], n_times=app_state['demo_data'].shape[2])
68
- print(f"✅ Pre-trained model loaded successfully from {classifier.model_path}")
69
- print(f"Pre-trained Demo: {len(app_state['demo_data'])} samples from {len(existing_files)} subjects")
70
- else:
71
- print("⚠️ No demo data loaded - check your .mat files")
72
-
73
- print(f"Available sound classes: {list(sound_manager.current_sound_mapping.keys())}")
74
-
75
- except Exception as e:
76
- print(f"❌ Error during initialization: {e}")
77
- raise RuntimeError(
78
- "Cannot initialize app without real EEG data. "
79
- "Please check your data files and paths."
80
- )
81
 
 
82
  def get_movement_sounds() -> Dict[str, str]:
83
  """Get the current sound files for each movement."""
84
  sounds = {}
 
85
  for movement, sound_file in sound_manager.current_sound_mapping.items():
86
- if movement in ['left_hand', 'right_hand', 'left_leg', "right_leg"]: # Only show main movements
87
- if sound_file is not None: # Check if sound_file is not None
88
  sound_path = sound_manager.sound_dir / sound_file
89
  if sound_path.exists():
90
- # Convert to absolute path for Gradio audio components
91
- sounds[movement] = str(sound_path.resolve())
 
 
 
 
92
  return sounds
 
 
 
 
 
94
  def start_composition():
95
- """Start the composition process and perform initial classification."""
 
 
96
  global app_state
97
-
98
- # Only start new cycle if not already active
99
  if not app_state['composition_active']:
100
  app_state['composition_active'] = True
101
- sound_manager.start_new_cycle() # Reset composition only when starting fresh
102
-
103
  if app_state['demo_data'] is None:
104
  return "❌ No data", "❌ No data", "❌ No data", None, None, None, None, None, None, "No EEG data available"
105
-
106
- # Get current target
107
- target_movement = sound_manager.get_current_target_movement()
108
- print(f"DEBUG start_composition: current target = {target_movement}")
109
-
110
- # Check if cycle is complete
111
- if target_movement == "cycle_complete":
112
- return "🎵 Cycle Complete!", "🎵 Complete", "Remap sounds to continue", None, None, None, None, None, None, "Cycle complete - remap sounds to continue"
113
-
114
- # Perform initial EEG classification
115
- epoch_data, true_label = data_processor.simulate_real_time_data(
116
- app_state['demo_data'], app_state['demo_labels'], mode="class_balanced"
117
- )
118
-
119
- # Classify the epoch
 
 
120
  predicted_class, confidence, probabilities = classifier.predict(epoch_data)
121
  predicted_name = classifier.class_names[predicted_class]
122
-
123
- # Process classification
124
- result = sound_manager.process_classification(predicted_name, confidence, CONFIDENCE_THRESHOLD)
125
-
126
- # Debug: Print composition layers before DJ mode transition
127
- print(f"DEBUG: composition_layers before DJ check: {[layer['sound_file'] for layer in sound_manager.composition_layers]}")
128
- # Always check for DJ mode transition after classification
129
- dj_transitioned = sound_manager.transition_to_dj_phase()
130
- print(f"DEBUG: DJ mode transitioned: {dj_transitioned}, current_phase: {sound_manager.current_phase}")
131
-
132
- # Stop EEG visualization if all 4 unique sound layers are present
133
- unique_sounds = set([layer['sound_file'] for layer in sound_manager.composition_layers if layer.get('sound_file')])
134
- if len(unique_sounds) >= 4 and sound_manager.current_phase != "dj_effects":
135
- fig = None
136
- print("DEBUG: EEG visualization stopped (all sounds present)")
137
- else:
138
- fig = create_eeg_plot(epoch_data, target_movement, predicted_name, confidence, result['sound_added'])
139
- if sound_manager.current_phase == "dj_effects":
140
- print("DEBUG: EEG visualization restarted for DJ mode")
141
-
142
- # Initialize all audio components to None (no sound by default)
143
- left_hand_audio = None
144
- right_hand_audio = None
145
- left_leg_audio = None
146
- right_leg_audio = None
147
- #tongue_audio = None
148
-
149
- # Debug: Print classification result
150
- print(f"DEBUG start_composition: predicted={predicted_name}, confidence={confidence:.3f}, sound_added={result['sound_added']}")
151
-
152
- # Only play the sound if it was just added and matches the prediction
153
- if result['sound_added']: # might add confidence threshold for sound output here
154
- sounds = get_movement_sounds()
155
- print(f"DEBUG: Available sounds: {list(sounds.keys())}")
156
- if predicted_name == 'left_hand' and 'left_hand' in sounds:
157
- left_hand_audio = sounds['left_hand']
158
- print(f"DEBUG: Setting left_hand_audio to {sounds['left_hand']}")
159
- elif predicted_name == 'right_hand' and 'right_hand' in sounds:
160
- right_hand_audio = sounds['right_hand']
161
- print(f"DEBUG: Setting right_hand_audio to {sounds['right_hand']}")
162
- elif predicted_name == 'left_leg' and 'left_leg' in sounds:
163
- left_leg_audio = sounds['left_leg']
164
- print(f"DEBUG: Setting left_leg_audio to {sounds['left_leg']}")
165
- elif predicted_name == 'right_leg' and 'right_leg' in sounds:
166
- right_leg_audio = sounds['right_leg']
167
- print(f"DEBUG: Setting right_leg_audio to {sounds['right_leg']}")
168
- # elif predicted_name == 'tongue' and 'tongue' in sounds:
169
- # tongue_audio = sounds['tongue']
170
- # print(f"DEBUG: Setting tongue_audio to {sounds['tongue']}")
171
  else:
172
- print("DEBUG: No sound added - confidence too low or other issue")
173
-
174
- # Format next target with progress information
175
- next_target = sound_manager.get_current_target_movement()
176
- completed_count = len(sound_manager.movements_completed)
177
- total_count = len(sound_manager.current_movement_sequence)
178
-
179
- if next_target == "cycle_complete":
180
- target_text = "🎵 Cycle Complete!"
181
- else:
182
- target_text = f"🎯 Any Movement ({completed_count}/{total_count} complete) - Use 'Classify Epoch' button to continue"
183
-
184
- predicted_text = f"🧠 Predicted: {predicted_name.replace('_', ' ').title()} ({confidence:.2f})"
185
 
186
- # Get composition info
187
- composition_info = sound_manager.get_composition_info()
188
- status_text = format_composition_summary(composition_info)
189
-
190
- return (
191
- target_text,
192
- predicted_text,
193
- "2-3 seconds",
194
- fig,
195
- left_hand_audio,
196
- right_hand_audio,
197
- left_leg_audio,
198
- right_leg_audio,
199
- #tongue_audio,
200
- status_text
201
- )
202
 
203
- def stop_composition():
204
- """Stop the composition process."""
205
- global app_state
206
- app_state['composition_active'] = False
207
- app_state['auto_mode'] = False
208
- return (
209
- "Composition stopped. Click 'Start Composing' to begin again",
210
- "--",
211
- "--",
212
- "Stopped - click Start to resume"
213
- )
214
 
215
- def start_automatic_composition():
216
- """Start automatic composition with continuous classification."""
217
- global app_state
 
 
 
218
 
219
- # Only start new cycle if not already active
220
- if not app_state['composition_active']:
221
- app_state['composition_active'] = True
222
- app_state['auto_mode'] = True
223
- sound_manager.start_new_cycle() # Reset composition only when starting fresh
224
 
225
- if app_state['demo_data'] is None:
226
- return "❌ No data", "❌ No data", "❌ No data", "❌ No data", None, None, None, None, None, None, "No EEG data available", gr.update(visible=True), gr.update(visible=False)
227
 
228
- # Get current target
229
- target_movement = sound_manager.get_current_target_movement()
230
- print(f"DEBUG start_automatic_composition: current target = {target_movement}")
231
-
232
- # Check if cycle is complete
233
- if target_movement == "cycle_complete":
234
- # Mark current cycle as complete
235
- sound_manager.complete_current_cycle()
236
-
237
- # Check if rehabilitation session should end
238
- if sound_manager.should_end_session():
239
- app_state['auto_mode'] = False # Stop automatic mode
240
- return (
241
- "🎉 Session Complete!",
242
- "🏆 Amazing Progress!",
243
- "Rehabilitation session finished!",
244
- "🌟 Congratulations! You've created 2 unique brain-music compositions!\n\n" +
245
- "💪 Your motor imagery skills are improving!\n\n" +
246
- "🎵 You can review your compositions above, or start a new session anytime.\n\n" +
247
- "Would you like to continue with more cycles, or take a well-deserved break?",
248
- None, None, None, None, None, None,
249
- f"✅ Session Complete: {sound_manager.completed_cycles}/{sound_manager.max_cycles} compositions finished!"
250
- )
251
- else:
252
- # Start next cycle automatically
253
- sound_manager.start_new_cycle()
254
- print("🔄 Cycle completed! Starting new cycle automatically...")
255
- target_movement = sound_manager.get_current_target_movement() # Get new target
256
-
257
- # Show user prompt - encouraging start message
258
- cycle_num = sound_manager.current_cycle
259
- if cycle_num == 1:
260
- prompt_text = "🌟 Welcome to your rehabilitation session! Let's start with any movement you can imagine..."
261
- elif cycle_num == 2:
262
- prompt_text = "💪 Excellent work on your first composition! Ready for composition #2?"
263
- else:
264
- prompt_text = "🧠 Let's continue - imagine any movement now..."
265
-
266
- # Perform initial EEG classification
267
- epoch_data, true_label = data_processor.simulate_real_time_data(
268
- app_state['demo_data'], app_state['demo_labels'], mode="class_balanced"
269
- )
270
-
271
- # Classify the epoch
272
- predicted_class, confidence, probabilities = classifier.predict(epoch_data)
273
- predicted_name = classifier.class_names[predicted_class]
274
-
275
- # Handle DJ effects or building phase
276
- if sound_manager.current_phase == "dj_effects" and confidence > CONFIDENCE_THRESHOLD:
277
- # DJ Effects Mode - toggle effects instead of adding sounds
278
- dj_result = sound_manager.toggle_dj_effect(predicted_name)
279
- result = {
280
- 'sound_added': dj_result['effect_applied'],
281
- 'mixed_composition': dj_result.get('mixed_composition'),
282
- 'effect_name': dj_result.get('effect_name', ''),
283
- 'effect_status': dj_result.get('effect_status', '')
284
- }
285
- else:
286
- # Building Mode - process classification normally
287
- result = sound_manager.process_classification(predicted_name, confidence, CONFIDENCE_THRESHOLD)
288
-
289
- # Create visualization
290
- fig = create_eeg_plot(epoch_data, target_movement, predicted_name, confidence, result['sound_added'])
291
-
292
-
293
- # Initialize all audio components to None (no sound by default)
294
- left_hand_audio = None
295
- right_hand_audio = None
296
- left_leg_audio = None
297
- right_leg_audio = None
298
- #tongue_audio = None
299
-
300
- # Debug: Print classification result
301
- print(f"DEBUG start_automatic_composition: predicted={predicted_name}, confidence={confidence:.3f}, sound_added={result['sound_added']}")
302
-
303
- # Handle audio display based on current phase
304
- if sound_manager.current_phase == "dj_effects":
305
- # DJ Effects Phase - continue showing individual sounds in their players
306
- sounds = get_movement_sounds()
307
- # Always check for phase transition to DJ effects after each classification
308
- sound_manager.transition_to_dj_phase()
309
- completed_movements = sound_manager.movements_completed
310
- # Display each completed movement sound in its respective player (same as building mode)
311
- if 'left_hand' in completed_movements and 'left_hand' in sounds:
312
- left_hand_audio = sounds['left_hand']
313
- print(f"DEBUG DJ: Left hand playing: {sounds['left_hand']}")
314
- if 'right_hand' in completed_movements and 'right_hand' in sounds:
315
- right_hand_audio = sounds['right_hand']
316
- print(f"DEBUG DJ: Right hand playing: {sounds['right_hand']}")
317
- if 'left_leg' in completed_movements and 'left_leg' in sounds:
318
- left_leg_audio = sounds['left_leg']
319
- print(f"DEBUG DJ: Left leg playing: {sounds['left_leg']}")
320
- if 'right_leg' in completed_movements and 'right_leg' in sounds:
321
- right_leg_audio = sounds['right_leg']
322
- print(f"DEBUG DJ: Right leg playing: {sounds['right_leg']}")
323
- # if 'tongue' in completed_movements and 'tongue' in sounds:
324
- # tongue_audio = sounds['tongue']
325
- # print(f"DEBUG DJ: Tongue playing: {sounds['tongue']}")
326
- print(f"DEBUG DJ: {len(completed_movements)} individual sounds playing with effects applied")
327
- else:
328
- # Building Phase - create and show layered composition
329
- sounds = get_movement_sounds()
330
- completed_movements = sound_manager.movements_completed
331
- print(f"DEBUG: Available sounds: {list(sounds.keys())}")
332
- print(f"DEBUG: Completed movements: {completed_movements}")
333
-
334
- # Display individual sounds in their respective players for layered composition
335
- # All completed movement sounds will play simultaneously, creating natural layering
336
- if len(completed_movements) > 0:
337
- print(f"DEBUG: Showing individual sounds that will layer together: {list(completed_movements)}")
338
-
339
- # Display each completed movement sound in its respective player
340
- if 'left_hand' in completed_movements and 'left_hand' in sounds:
341
- left_hand_audio = sounds['left_hand']
342
- print(f"DEBUG: Left hand playing: {sounds['left_hand']}")
343
- if 'right_hand' in completed_movements and 'right_hand' in sounds:
344
- right_hand_audio = sounds['right_hand']
345
- print(f"DEBUG: Right hand playing: {sounds['right_hand']}")
346
- if 'left_leg' in completed_movements and 'left_leg' in sounds:
347
- left_leg_audio = sounds['left_leg']
348
- print(f"DEBUG: Left leg playing: {sounds['left_leg']}")
349
- if 'right_leg' in completed_movements and 'right_leg' in sounds:
350
- right_leg_audio = sounds['right_leg']
351
- print(f"DEBUG: Right leg playing: {sounds['right_leg']}")
352
- # if 'tongue' in completed_movements and 'tongue' in sounds:
353
- # tongue_audio = sounds['tongue']
354
- # print(f"DEBUG: Tongue playing: {sounds['tongue']}")
355
-
356
- print(f"DEBUG: {len(completed_movements)} individual sounds will play together creating layered composition")
357
-
358
- # Check for phase transition to DJ effects
359
- completed_count = len(sound_manager.movements_completed)
360
- total_count = len(sound_manager.current_movement_sequence)
361
-
362
- # Transition to DJ effects if all movements completed but still in building phase
363
- if completed_count >= total_count and sound_manager.current_phase == "building":
364
- sound_manager.transition_to_dj_phase()
365
-
366
- # Format display based on current phase
367
- if sound_manager.current_phase == "dj_effects":
368
- target_text = "🎧 DJ Mode Active - Use movements to control effects!"
369
- else:
370
- next_target = sound_manager.get_current_target_movement()
371
- if next_target == "cycle_complete":
372
- target_text = "🎵 Composition Complete!"
373
- else:
374
- target_text = f"🎯 Building Composition ({completed_count}/{total_count} layers)"
375
-
376
- # Update display text based on phase
377
- if sound_manager.current_phase == "dj_effects":
378
- if result.get('effect_name') and result.get('effect_status'):
379
- predicted_text = f"🎛️ {result['effect_name']}: {result['effect_status']}"
380
- else:
381
- predicted_text = f"🧠 Detected: {predicted_name.replace('_', ' ').title()} ({confidence:.2f})"
382
- timer_text = "🎧 DJ Mode - Effects updating every 3 seconds..."
383
- else:
384
- predicted_text = f"🧠 Predicted: {predicted_name.replace('_', ' ').title()} ({confidence:.2f})"
385
- timer_text = "⏱️ Next trial in 2-3 seconds..."
386
-
387
- # Get composition info
388
- composition_info = sound_manager.get_composition_info()
389
- status_text = format_composition_summary(composition_info)
390
-
391
- # Phase-based instruction visibility
392
- building_visible = sound_manager.current_phase == "building"
393
- dj_visible = sound_manager.current_phase == "dj_effects"
394
-
395
- return (
396
- target_text,
397
- predicted_text,
398
- timer_text,
399
- prompt_text,
400
- fig,
401
- left_hand_audio,
402
- right_hand_audio,
403
- left_leg_audio,
404
- right_leg_audio,
405
- #tongue_audio,
406
- status_text,
407
- gr.update(visible=building_visible), # building_instructions
408
- gr.update(visible=dj_visible) # dj_instructions
409
- )
410
 
411
- def manual_classify():
412
- """Manual classification for testing purposes."""
413
- global app_state
414
-
415
- if app_state['demo_data'] is None:
416
- return "❌ No data", "❌ No data", "Manual mode", None, "No EEG data available", None, None, None, None, None
417
-
418
- # Get EEG data sample
419
- epoch_data, true_label = data_processor.simulate_real_time_data(
420
- app_state['demo_data'], app_state['demo_labels'], mode="class_balanced"
421
- )
422
-
423
- # Classify the epoch
424
- predicted_class, confidence, probabilities = classifier.predict(epoch_data)
425
- predicted_name = classifier.class_names[predicted_class]
426
-
427
- # Create visualization (without composition context)
428
- fig = create_eeg_plot(epoch_data, "manual_test", predicted_name, confidence, False)
429
-
430
- # Format results
431
- target_text = "🎯 Manual Test Mode"
432
- predicted_text = f"🧠 {predicted_name.replace('_', ' ').title()} ({confidence:.2f})"
433
-
434
- # Update results log
435
- import time
436
- timestamp = time.strftime("%H:%M:%S")
437
- result_entry = f"[{timestamp}] Predicted: {predicted_name.replace('_', ' ').title()} (confidence: {confidence:.3f})"
438
-
439
- # Get sound files for preview (no autoplay)
440
- sounds = get_movement_sounds()
441
- left_hand_audio = sounds.get('left_hand', None)
442
- right_hand_audio = sounds.get('right_hand', None)
443
- left_leg_audio = sounds.get('left_leg', None)
444
- right_leg_audio = sounds.get('right_leg', None)
445
- #tongue_audio = sounds.get('tongue', None)
446
-
447
- return (
448
- target_text,
449
- predicted_text,
450
- "Manual mode - click button to classify",
451
- fig,
452
- result_entry,
453
- left_hand_audio,
454
- right_hand_audio,
455
- left_leg_audio,
456
- right_leg_audio,
457
- #tongue_audio
458
- )
459
 
460
- def clear_manual():
461
- """Clear manual testing results."""
462
- return (
463
- "🎯 Manual Test Mode",
464
- "--",
465
- "Manual mode",
466
- None,
467
- "Manual classification results cleared...",
468
- None, None, None, None, None
469
- )
470
 
471
- def continue_automatic_composition():
472
- """Continue automatic composition - called for subsequent trials."""
473
- global app_state
474
-
475
- if not app_state['composition_active'] or not app_state['auto_mode']:
476
- return "🛑 Stopped", "--", "--", "Automatic composition stopped", None, None, None, None, None, None, "Stopped", gr.update(visible=True), gr.update(visible=False)
477
-
478
- if app_state['demo_data'] is None:
479
- return "❌ No data", "❌ No data", "❌ No data", "❌ No data", None, None, None, None, None, None, "No EEG data available", gr.update(visible=True), gr.update(visible=False)
480
-
481
- # Get current target
482
- target_movement = sound_manager.get_current_target_movement()
483
- print(f"DEBUG continue_automatic_composition: current target = {target_movement}")
484
-
485
- # Check if cycle is complete
486
- if target_movement == "cycle_complete":
487
- # Mark current cycle as complete
488
- sound_manager.complete_current_cycle()
489
-
490
- # Check if rehabilitation session should end
491
- if sound_manager.should_end_session():
492
- app_state['auto_mode'] = False # Stop automatic mode
493
- return (
494
- "🎉 Session Complete!",
495
- "🏆 Amazing Progress!",
496
- "Rehabilitation session finished!",
497
- "🌟 Congratulations! You've created 2 unique brain-music compositions!\n\n" +
498
- "💪 Your motor imagery skills are improving!\n\n" +
499
- "🎵 You can review your compositions above, or start a new session anytime.\n\n" +
500
- "Would you like to continue with more cycles, or take a well-deserved break?",
501
- None, None, None, None, None, None,
502
- f"✅ Session Complete: {sound_manager.completed_cycles}/{sound_manager.max_cycles} compositions finished!",
503
- gr.update(visible=True), gr.update(visible=False)
504
- )
505
- else:
506
- # Start next cycle automatically
507
- sound_manager.start_new_cycle()
508
- print("🔄 Cycle completed! Starting new cycle automatically...")
509
- target_movement = sound_manager.get_current_target_movement() # Get new target
510
-
511
- # Show next user prompt - rehabilitation-focused messaging
512
- prompts = [
513
- "💪 Great work! Imagine your next movement...",
514
- "🎯 You're doing amazing! Focus and imagine any movement...",
515
- "✨ Excellent progress! Ready for the next movement?",
516
- "🌟 Keep it up! Concentrate and imagine now...",
517
- "🏆 Fantastic! Next trial - imagine any movement..."
518
- ]
519
- import random
520
- prompt_text = random.choice(prompts)
521
-
522
- # Add progress encouragement
523
- completed_count = len(sound_manager.movements_completed)
524
- total_count = len(sound_manager.current_movement_sequence)
525
- if completed_count > 0:
526
- prompt_text += f" ({completed_count}/{total_count} movements completed this cycle)"
527
-
528
- # Perform EEG classification
529
- epoch_data, true_label = data_processor.simulate_real_time_data(
530
- app_state['demo_data'], app_state['demo_labels'], mode="class_balanced"
531
- )
532
-
533
- # Classify the epoch
534
- predicted_class, confidence, probabilities = classifier.predict(epoch_data)
535
- predicted_name = classifier.class_names[predicted_class]
536
-
537
- # Handle DJ effects or building phase
538
- if sound_manager.current_phase == "dj_effects" and confidence > CONFIDENCE_THRESHOLD:
539
- # DJ Effects Mode - toggle effects instead of adding sounds
540
- dj_result = sound_manager.toggle_dj_effect(predicted_name)
541
- result = {
542
- 'sound_added': dj_result['effect_applied'],
543
- 'mixed_composition': dj_result.get('mixed_composition'),
544
- 'effect_name': dj_result.get('effect_name', ''),
545
- 'effect_status': dj_result.get('effect_status', '')
546
- }
547
- else:
548
- # Building Mode - process classification normally
549
- result = sound_manager.process_classification(predicted_name, confidence, CONFIDENCE_THRESHOLD)
550
-
551
- # Check if we should transition to DJ phase
552
- completed_count = len(sound_manager.movements_completed)
553
- if completed_count >= 5 and sound_manager.current_phase == "building":
554
- if sound_manager.transition_to_dj_phase():
555
- print(f"DEBUG: Successfully transitioned to DJ phase with {completed_count} completed movements")
556
-
557
- # Create visualization
558
- fig = create_eeg_plot(epoch_data, target_movement, predicted_name, confidence, result['sound_added'])
559
-
560
- # Initialize all audio components to None (no sound by default)
561
- left_hand_audio = None
562
- right_hand_audio = None
563
- left_leg_audio = None
564
- right_leg_audio = None
565
- #tongue_audio = None
566
-
567
- # Handle audio differently based on phase
568
- if sound_manager.current_phase == "dj_effects":
569
- # DJ Mode: Show only the full mixed track and route effects to it
570
- mixed_track = sound_manager.get_current_mixed_composition()
571
- print(f"DEBUG continue: DJ Mode - Showing full mixed track: {mixed_track}")
572
- # Hide individual movement sounds by setting them to None
573
- left_hand_audio = None
574
- right_hand_audio = None
575
- left_leg_audio = None
576
- right_leg_audio = None
577
- #tongue_audio = None
578
- # The mixed_track will be shown in a dedicated gr.Audio component in the UI (update UI accordingly)
579
- else:
580
- # Building Mode: Display individual sounds in their respective players for layered composition
581
- # All completed movement sounds will play simultaneously, creating natural layering
582
- sounds = get_movement_sounds()
583
- completed_movements = sound_manager.movements_completed
584
- print(f"DEBUG continue: Available sounds: {list(sounds.keys())}")
585
- print(f"DEBUG continue: Completed movements: {completed_movements}")
586
-
587
- if len(completed_movements) > 0:
588
- # Track and print only the sounds that have been added
589
- sounds_added = [sounds[m] for m in completed_movements if m in sounds]
590
- print(f"DEBUG: Sounds added to composition: {sounds_added}")
591
- # Display each completed movement sound in its respective player
592
- if 'left_hand' in completed_movements and 'left_hand' in sounds:
593
- left_hand_audio = sounds['left_hand']
594
- if 'right_hand' in completed_movements and 'right_hand' in sounds:
595
- right_hand_audio = sounds['right_hand']
596
- if 'left_leg' in completed_movements and 'left_leg' in sounds:
597
- left_leg_audio = sounds['left_leg']
598
- if 'right_leg' in completed_movements and 'right_leg' in sounds:
599
- right_leg_audio = sounds['right_leg']
600
-
601
- # Format display with progress information
602
- completed_count = len(sound_manager.movements_completed)
603
- total_count = len(sound_manager.current_movement_sequence)
604
-
605
- if sound_manager.current_phase == "dj_effects":
606
- target_text = f"🎧 DJ Mode - Control Effects with Movements"
607
- predicted_text = f"🧠 Predicted: {predicted_name.replace('_', ' ').title()} ({confidence:.2f})"
608
- if result.get('effect_applied'):
609
- effect_name = result.get('effect_name', '')
610
- effect_status = result.get('effect_status', '')
611
- timer_text = f"🎛️ {effect_name}: {effect_status}"
612
- else:
613
- timer_text = "🎵 Move to control effects..."
614
- else:
615
- target_text = f"🎯 Any Movement ({completed_count}/{total_count} complete)"
616
- predicted_text = f"🧠 Predicted: {predicted_name.replace('_', ' ').title()} ({confidence:.2f})"
617
- timer_text = "⏱️ Next trial in 2-3 seconds..." if app_state['auto_mode'] else "Stopped"
618
-
619
- # Get composition info
620
  composition_info = sound_manager.get_composition_info()
621
  status_text = format_composition_summary(composition_info)
622
-
623
- # Phase-based instruction visibility
624
- building_visible = sound_manager.current_phase == "building"
625
- dj_visible = sound_manager.current_phase == "dj_effects"
626
-
627
  return (
628
- target_text,
629
- predicted_text,
630
- timer_text,
631
- prompt_text,
632
  fig,
633
  left_hand_audio,
634
- right_hand_audio,
635
  left_leg_audio,
636
  right_leg_audio,
637
- #tongue_audio,
638
- status_text,
639
- gr.update(visible=building_visible), # building_instructions
640
- gr.update(visible=dj_visible) # dj_instructions
641
  )
642
 
643
- def classify_epoch():
644
- """Classify a single EEG epoch and update composition."""
645
  global app_state
646
-
647
  if not app_state['composition_active']:
648
  return "❌ Not active", "❌ Not active", "❌ Not active", None, None, None, None, None, None, "Click 'Start Composing' first"
649
-
650
  if app_state['demo_data'] is None:
651
  return "❌ No data", "❌ No data", "❌ No data", None, None, None, None, None, None, "No EEG data available"
652
-
653
- # Get current target
654
- target_movement = sound_manager.get_current_target_movement()
655
- print(f"DEBUG classify_epoch: current target = {target_movement}")
656
-
657
- if target_movement == "cycle_complete":
658
- return "🎵 Cycle Complete!", "🎵 Complete", "Remap sounds to continue", None, None, None, None, None, None, "Cycle complete - remap sounds to continue"
659
-
660
- # Get EEG data sample
661
- epoch_data, true_label = data_processor.simulate_real_time_data(
662
- app_state['demo_data'], app_state['demo_labels'], mode="class_balanced"
663
- )
664
-
665
- # Classify the epoch
666
  predicted_class, confidence, probabilities = classifier.predict(epoch_data)
667
  predicted_name = classifier.class_names[predicted_class]
668
-
669
- # Process classification
670
- result = sound_manager.process_classification(predicted_name, confidence, CONFIDENCE_THRESHOLD)
671
-
672
- # Check if we should transition to DJ phase
673
- completed_count = len(sound_manager.movements_completed)
674
- if completed_count >= 5 and sound_manager.current_phase == "building":
675
- if sound_manager.transition_to_dj_phase():
676
- print(f"DEBUG: Successfully transitioned to DJ phase with {completed_count} completed movements")
677
-
678
- # Create visualization
679
- fig = create_eeg_plot(epoch_data, target_movement, predicted_name, confidence, result['sound_added'])
680
-
681
- # Initialize all audio components to None (no sound by default)
682
- left_hand_audio = None
683
- right_hand_audio = None
684
- left_leg_audio = None
685
- right_leg_audio = None
686
- #tongue_audio = None
687
-
688
- # Always assign all completed movement sounds to their respective audio slots
689
  sounds = get_movement_sounds()
690
  completed_movements = sound_manager.movements_completed
691
- if 'left_hand' in completed_movements and 'left_hand' in sounds:
692
- left_hand_audio = sounds['left_hand']
693
- if 'right_hand' in completed_movements and 'right_hand' in sounds:
694
- right_hand_audio = sounds['right_hand']
695
- if 'left_leg' in completed_movements and 'left_leg' in sounds:
696
- left_leg_audio = sounds['left_leg']
697
- if 'right_leg' in completed_movements and 'right_leg' in sounds:
698
- right_leg_audio = sounds['right_leg']
699
- # if 'tongue' in completed_movements and 'tongue' in sounds:
700
- # tongue_audio = sounds['tongue']
701
-
702
- # Format next target
703
- next_target = sound_manager.get_current_target_movement()
704
- target_text = f"🎯 Target: {next_target.replace('_', ' ').title()}" if next_target != "cycle_complete" else "🎵 Cycle Complete!"
705
-
706
- predicted_text = f"🧠 Predicted: {predicted_name.replace('_', ' ').title()} ({confidence:.2f})"
707
-
708
- # Get composition info
 
 
709
  composition_info = sound_manager.get_composition_info()
710
  status_text = format_composition_summary(composition_info)
711
-
 
712
  return (
713
- target_text,
714
- predicted_text,
715
- "2-3 seconds",
716
- fig,
717
- left_hand_audio,
718
- right_hand_audio,
719
- left_leg_audio,
720
- right_leg_audio,
721
- ##tongue_audio,
722
- status_text
723
  )
724
 
725
- def create_eeg_plot(eeg_data: np.ndarray, target_movement: str, predicted_name: str, confidence: float, sound_added: bool) -> plt.Figure:
726
- """Create EEG plot with target movement and classification result."""
727
- fig, axes = plt.subplots(2, 2, figsize=(12, 8))
728
- axes = axes.flatten()
729
-
730
- # Plot 4 channels
731
- time_points = np.arange(eeg_data.shape[1]) / 200 # 200 Hz sampling rate
732
- channel_names = ['C3', 'C4', 'T3', 'T4'] # Motor cortex channels
733
-
734
- for i in range(min(4, eeg_data.shape[0])):
735
- color = 'green' if sound_added else 'blue'
736
- axes[i].plot(time_points, eeg_data[i], color=color, linewidth=1)
737
-
738
- if i < len(channel_names):
739
- axes[i].set_title(f'{channel_names[i]} (Ch {i+1})')
740
- else:
741
- axes[i].set_title(f'Channel {i+1}')
742
-
743
- axes[i].set_xlabel('Time (s)')
744
- axes[i].set_ylabel('Amplitude (µV)')
745
- axes[i].grid(True, alpha=0.3)
746
-
747
- # Add overall title with status
748
- status = "✓ SOUND ADDED" if sound_added else "○ No sound"
749
- title = f"Target: {target_movement.replace('_', ' ').title()} | Predicted: {predicted_name.replace('_', ' ').title()} ({confidence:.2f}) | {status}"
750
- fig.suptitle(title, fontsize=12, fontweight='bold')
751
- fig.tight_layout()
752
- return fig
753
 
754
- def format_composition_summary(composition_info: Dict) -> str:
755
- """Format composition information for display."""
756
- if not composition_info.get('layers_by_cycle'):
757
- return "No composition layers yet"
758
-
759
- summary = []
760
- for cycle, layers in composition_info['layers_by_cycle'].items():
761
- summary.append(f"Cycle {cycle + 1}: {len(layers)} layers")
762
- for layer in layers:
763
- movement = layer.get('movement', 'unknown')
764
- confidence = layer.get('confidence', 0)
765
- summary.append(f" • {movement.replace('_', ' ').title()} ({confidence:.2f})")
766
-
767
- return "\n".join(summary) if summary else "No composition layers"
768
 
769
- # Create Gradio interface
770
- def create_interface():
771
- with gr.Blocks(title="EEG Motor Imagery Music Composer", theme=gr.themes.Soft()) as demo:
772
- gr.Markdown("# 🧠🎵 EEG Motor Imagery Rehabilitation Composer")
773
- gr.Markdown("**Therapeutic Brain-Computer Interface for Motor Recovery**\n\nCreate beautiful music compositions using your brain signals! This rehabilitation tool helps strengthen motor imagery skills while creating personalized musical pieces.")
774
-
775
- with gr.Tabs() as tabs:
776
- # Main Composition Tab
777
- with gr.TabItem("🎵 Automatic Composition"):
778
  with gr.Row():
779
- # Left side - Task and EEG information
780
  with gr.Column(scale=2):
781
- # Task instructions - Building Phase
782
- with gr.Group() as building_instructions:
783
- gr.Markdown("### 🎯 Rehabilitation Session Instructions")
784
- gr.Markdown("""
785
- **Motor Imagery Training:**
786
- - **Imagine** opening or closing your **right or left hand**
787
- - **Visualize** briefly moving your **right or left leg or foot**
788
-
789
- *🌟 Each successful imagination creates a musical layer!*
790
-
791
- **Session Structure:** Build composition, then control DJ effects
792
- *Press Start to begin your personalized rehabilitation session*
793
- """)
794
-
795
- # DJ Instructions - Effects Phase (initially hidden)
796
- with gr.Group(visible=False) as dj_instructions:
797
- gr.Markdown("### 🎧 DJ Controller Mode")
798
- gr.Markdown("""
799
- **🎉 Composition Complete! You are now the DJ!**
800
-
801
- **Use the same movements to control audio effects:**
802
- - 👈 **Left Hand**: Volume Fade On/Off
803
- - 👉 **Right Hand**: High Pass Filter On/Off
804
- - 🦵 **Left Leg**: Reverb Effect On/Off
805
- - 🦵 **Right Leg**: Low Pass Filter On/Off
806
-
807
-
808
- *🎛️ Each movement toggles an effect - Mix your creation!*
809
- """)
810
-
811
- # Start button
812
- with gr.Row():
813
- start_btn = gr.Button("🎵 Start Composing", variant="primary", size="lg")
814
- continue_btn = gr.Button("⏭️ Continue", variant="primary", size="lg", visible=False)
815
- stop_btn = gr.Button("🛑 Stop", variant="secondary", size="lg")
816
-
817
- # Session completion options (shown after 2 cycles)
818
- with gr.Row(visible=False) as session_complete_row:
819
- new_session_btn = gr.Button("🔄 Start New Session", variant="primary", size="lg")
820
- extend_session_btn = gr.Button("➕ Continue Session", variant="secondary", size="lg")
821
-
822
- # Timer for automatic progression (hidden from user)
823
- timer = gr.Timer(value=3.0, active=False) # 3 second intervals
824
-
825
- # User prompt display
826
- user_prompt = gr.Textbox(label="💭 User Prompt", interactive=False, value="Click 'Start Composing' to begin",
827
- elem_classes=["prompt-display"])
828
-
829
- # Current status
830
- with gr.Row():
831
- target_display = gr.Textbox(label="🎯 Current Target", interactive=False, value="Ready to start")
832
- predicted_display = gr.Textbox(label="🧠 Predicted", interactive=False, value="--")
833
-
834
- timer_display = gr.Textbox(label="⏱️ Next Trial In", interactive=False, value="--")
835
-
836
  eeg_plot = gr.Plot(label="EEG Data Visualization")
837
-
838
- # Right side - Compositional layers
839
  with gr.Column(scale=1):
840
- gr.Markdown("### 🎵 Compositional Layers")
841
-
842
- # Show 5 movement sounds
843
  left_hand_sound = gr.Audio(label="👈 Left Hand", interactive=False, autoplay=True, visible=True)
844
- right_hand_sound = gr.Audio(label="👉 Right Hand", interactive=False, autoplay=True, visible=True)
845
  left_leg_sound = gr.Audio(label="🦵 Left Leg", interactive=False, autoplay=True, visible=True)
846
  right_leg_sound = gr.Audio(label="🦵 Right Leg", interactive=False, autoplay=True, visible=True)
847
- #tongue_sound = gr.Audio(label="👅 Tongue", interactive=False, autoplay=True, visible=True)
848
-
849
- # Composition status
850
  composition_status = gr.Textbox(label="Composition Status", interactive=False, lines=5)
851
-
852
- # Manual Testing Tab
853
- with gr.TabItem("🧠 Manual Testing"):
854
- with gr.Row():
855
- with gr.Column(scale=2):
856
- gr.Markdown("### 🔬 Manual EEG Classification Testing")
857
- gr.Markdown("Use this tab to manually test the EEG classifier without the composition system.")
858
-
859
- with gr.Row():
860
- classify_btn = gr.Button("🧠 Classify Single Epoch", variant="primary")
861
- clear_btn = gr.Button("🗑️ Clear", variant="secondary")
862
-
863
- # Manual status displays
864
- manual_target_display = gr.Textbox(label="🎯 Current Target", interactive=False, value="Ready")
865
- manual_predicted_display = gr.Textbox(label="🧠 Predicted", interactive=False, value="--")
866
- manual_timer_display = gr.Textbox(label="⏱️ Status", interactive=False, value="Manual mode")
867
-
868
- manual_eeg_plot = gr.Plot(label="EEG Data Visualization")
869
-
870
- with gr.Column(scale=1):
871
- gr.Markdown("### 📊 Classification Results")
872
- manual_results = gr.Textbox(label="Results Log", interactive=False, lines=10, value="Manual classification results will appear here...")
873
-
874
- # Individual sound previews (no autoplay in manual mode)
875
- gr.Markdown("### 🔊 Sound Preview")
876
- manual_left_hand_sound = gr.Audio(label="👈 Left Hand", interactive=False, autoplay=False, visible=True)
877
- manual_right_hand_sound = gr.Audio(label="👉 Right Hand", interactive=False, autoplay=False, visible=True)
878
- manual_left_leg_sound = gr.Audio(label="🦵 Left Leg", interactive=False, autoplay=False, visible=True)
879
- manual_right_leg_sound = gr.Audio(label="🦵 Right Leg", interactive=False, autoplay=False, visible=True)
880
- #manual_tongue_sound = gr.Audio(label="👅 Tongue", interactive=False, autoplay=False, visible=True)
881
-
882
- # Session management functions
883
- def start_new_session():
884
- """Reset everything and start a completely new rehabilitation session"""
885
- global sound_manager
886
- sound_manager.completed_cycles = 0
887
- sound_manager.current_cycle = 0
888
- sound_manager.movements_completed = set()
889
- sound_manager.composition_layers = []
890
-
891
- # Start fresh session
892
- result = start_automatic_composition()
893
- return (
894
- result[0], # target_display
895
- result[1], # predicted_display
896
- result[2], # timer_display
897
- result[3], # user_prompt
898
- result[4], # eeg_plot
899
- result[5], # left_hand_sound
900
- result[6], # right_hand_sound
901
- result[7], # left_leg_sound
902
- result[8], # right_leg_sound
903
- #result[9], # tongue_sound
904
- result[9], # composition_status
905
- result[10], # building_instructions
906
- result[11], # dj_instructions
907
- gr.update(visible=True), # continue_btn - show it
908
- gr.update(active=True), # timer - activate it
909
- gr.update(visible=False) # session_complete_row - hide it
910
- )
911
-
912
- def extend_current_session():
913
- """Continue current session beyond the 2-cycle limit"""
914
- sound_manager.max_cycles += 2 # Add 2 more cycles
915
-
916
- # Continue with current session
917
- result = continue_automatic_composition()
918
- return (
919
- result[0], # target_display
920
- result[1], # predicted_display
921
- result[2], # timer_display
922
- result[3], # user_prompt
923
- result[4], # eeg_plot
924
- result[5], # left_hand_sound
925
- result[6], # right_hand_sound
926
- result[7], # left_leg_sound
927
- result[8], # right_leg_sound
928
- #result[9], # tongue_sound
929
- result[9], # composition_status
930
- result[10], # building_instructions
931
- result[11], # dj_instructions
932
- gr.update(visible=True), # continue_btn - show it
933
- gr.update(active=True), # timer - activate it
934
- gr.update(visible=False) # session_complete_row - hide it
935
- )
936
 
937
- # Wrapper functions for timer control
938
- def start_with_timer():
939
- """Start composition and activate automatic timer"""
940
- result = start_automatic_composition()
941
- # Show continue button and activate timer
942
- return (
943
- result[0], # target_display
944
- result[1], # predicted_display
945
- result[2], # timer_display
946
- result[3], # user_prompt
947
- result[4], # eeg_plot
948
- result[5], # left_hand_sound
949
- result[6], # right_hand_sound
950
- result[7], # left_leg_sound
951
- result[8], # right_leg_sound
952
- #result[9], # tongue_sound
953
- result[9], # composition_status
954
- result[10], # building_instructions
955
- result[11], # dj_instructions
956
- gr.update(visible=True), # continue_btn - show it
957
- gr.update(active=True) # timer - activate it
958
- )
959
-
960
- def continue_with_timer():
961
- """Continue composition and manage timer state"""
962
- result = continue_automatic_composition()
963
-
964
- # Check if session is complete (rehabilitation session finished)
965
- if "🎉 Session Complete!" in result[0]:
966
- # Show session completion options
967
- return (
968
- result[0], # target_display
969
- result[1], # predicted_display
970
- result[2], # timer_display
971
- result[3], # user_prompt
972
- result[4], # eeg_plot
973
- result[5], # left_hand_sound
974
- result[6], # right_hand_sound
975
- result[7], # left_leg_sound
976
- result[8], # right_leg_sound
977
- #result[9], # tongue_sound
978
- result[9], # composition_status
979
- result[10], # building_instructions
980
- result[11], # dj_instructions
981
- gr.update(active=False), # timer - deactivate it
982
- gr.update(visible=True) # session_complete_row - show options
 
 
 
 
 
 
 
 
983
  )
984
- # Check if composition is complete (old logic for other cases)
985
- elif "🎵 Cycle Complete!" in result[0]:
986
- # Stop the timer when composition is complete
987
- return (
988
- result[0], # target_display
989
- result[1], # predicted_display
990
- result[2], # timer_display
991
- result[3], # user_prompt
992
- result[4], # eeg_plot
993
- result[5], # left_hand_sound
994
- result[6], # right_hand_sound
995
- result[7], # left_leg_sound
996
- result[8], # right_leg_sound
997
- #result[9], # tongue_sound
998
- result[9], # composition_status
999
- result[10], # building_instructions
1000
- result[11], # dj_instructions
1001
- gr.update(active=False), # timer - deactivate it
1002
- gr.update(visible=False) # session_complete_row - keep hidden
1003
  )
1004
- else:
1005
- # Keep timer active for next iteration
1006
- return (
1007
- result[0], # target_display
1008
- result[1], # predicted_display
1009
- result[2], # timer_display
1010
- result[3], # user_prompt
1011
- result[4], # eeg_plot
1012
- result[5], # left_hand_sound
1013
- result[6], # right_hand_sound
1014
- result[7], # left_leg_sound
1015
- result[8], # right_leg_sound
1016
- #result[9], # tongue_sound
1017
- result[9], # composition_status
1018
- result[10], # building_instructions
1019
- result[11], # dj_instructions
1020
- gr.update(active=True), # timer - keep active
1021
- gr.update(visible=False) # session_complete_row - keep hidden
 
 
 
 
 
 
 
 
 
1022
  )
1023
-
1024
- # Event handlers for automatic composition tab
1025
- start_btn.click(
1026
- fn=start_with_timer,
1027
- outputs=[target_display, predicted_display, timer_display, user_prompt, eeg_plot,
1028
- left_hand_sound, right_hand_sound, left_leg_sound, right_leg_sound, composition_status,
1029
- building_instructions, dj_instructions, continue_btn, timer]
1030
- )
1031
-
1032
- continue_btn.click(
1033
- fn=continue_with_timer,
1034
- outputs=[target_display, predicted_display, timer_display, user_prompt, eeg_plot,
1035
- left_hand_sound, right_hand_sound, left_leg_sound, right_leg_sound, composition_status,
1036
- building_instructions, dj_instructions, timer, session_complete_row]
1037
- )
1038
-
1039
- # Timer automatically triggers continuation
1040
- timer.tick(
1041
- fn=continue_with_timer,
1042
- outputs=[target_display, predicted_display, timer_display, user_prompt, eeg_plot,
1043
- left_hand_sound, right_hand_sound, left_leg_sound, right_leg_sound, composition_status,
1044
- building_instructions, dj_instructions, timer, session_complete_row]
1045
- )
1046
-
1047
- # Session completion event handlers
1048
- new_session_btn.click(
1049
- fn=start_new_session,
1050
- outputs=[target_display, predicted_display, timer_display, user_prompt, eeg_plot,
1051
- left_hand_sound, right_hand_sound, left_leg_sound, right_leg_sound, composition_status,
1052
- building_instructions, dj_instructions, continue_btn, timer, session_complete_row]
1053
- )
1054
-
1055
- extend_session_btn.click(
1056
- fn=extend_current_session,
1057
- outputs=[target_display, predicted_display, timer_display, user_prompt, eeg_plot,
1058
- left_hand_sound, right_hand_sound, left_leg_sound, right_leg_sound, composition_status,
1059
- building_instructions, dj_instructions, continue_btn, timer, session_complete_row]
1060
- )
1061
-
1062
- def stop_with_timer():
1063
- """Stop composition and deactivate timer"""
1064
- result = stop_composition()
1065
- return (
1066
- result[0], # target_display
1067
- result[1], # predicted_display
1068
- result[2], # timer_display
1069
- result[3], # user_prompt
1070
- gr.update(visible=False), # continue_btn - hide it
1071
- gr.update(active=False) # timer - deactivate it
1072
- )
1073
-
1074
- stop_btn.click(
1075
- fn=stop_with_timer,
1076
- outputs=[target_display, predicted_display, timer_display, user_prompt, continue_btn, timer]
1077
- )
1078
-
1079
- # Event handlers for manual testing tab
1080
- classify_btn.click(
1081
- fn=manual_classify,
1082
- outputs=[manual_target_display, manual_predicted_display, manual_timer_display, manual_eeg_plot, manual_results,
1083
- manual_left_hand_sound, manual_right_hand_sound, manual_left_leg_sound, manual_right_leg_sound]
1084
- )
1085
-
1086
- clear_btn.click(
1087
- fn=clear_manual,
1088
- outputs=[manual_target_display, manual_predicted_display, manual_timer_display, manual_eeg_plot, manual_results,
1089
- manual_left_hand_sound, manual_right_hand_sound, manual_left_leg_sound, manual_right_leg_sound]
1090
- )
1091
-
1092
- # Note: No auto-loading of sounds to prevent playing all sounds on startup
1093
-
1094
  return demo
1095
 
1096
  if __name__ == "__main__":
1097
  demo = create_interface()
1098
- demo.launch(server_name="0.0.0.0", server_port=7867)
 
1
  """
2
+ EEG Motor Imagery Music Composer - Clean Transition Version
3
+ =========================================================
4
+ This version implements a clear separation between the building phase (layering sounds) and the DJ phase (effect control),
5
+ with seamless playback of all layered sounds throughout both phases.
6
  """
7
 
8
+ # Set matplotlib backend to non-GUI for server/web use
9
+ import matplotlib
10
+ matplotlib.use('Agg') # Set backend BEFORE importing pyplot
11
  import matplotlib.pyplot as plt
 
 
12
  import os
13
+ import gradio as gr
14
+ import numpy as np
15
+ from typing import Dict
16
+ from sound_manager import SoundManager
17
  from data_processor import EEGDataProcessor
18
  from classifier import MotorImageryClassifier
19
+ from config import DEMO_DATA_PATHS, CONFIDENCE_THRESHOLD
 
 
 
 
 
 
20
 
21
+ # --- Initialization ---
22
  app_state = {
23
  'is_running': False,
24
  'demo_data': None,
25
  'demo_labels': None,
 
26
  'composition_active': False,
27
  'auto_mode': False
28
  }
29
 
30
+ sound_manager = SoundManager()
31
+ data_processor = EEGDataProcessor()
32
+ classifier = MotorImageryClassifier()
 
33
 
34
+ # Load demo data
35
+ existing_files = [f for f in DEMO_DATA_PATHS if os.path.exists(f)]
36
+ if existing_files:
37
+ app_state['demo_data'], app_state['demo_labels'] = data_processor.process_files(existing_files)
38
+ else:
39
+ app_state['demo_data'], app_state['demo_labels'] = None, None
40
+
41
+ if app_state['demo_data'] is not None:
42
+ classifier.load_model(n_chans=app_state['demo_data'].shape[1], n_times=app_state['demo_data'].shape[2])
 
 
 
 
 
 
43
 
44
+ # --- Helper Functions ---
45
  def get_movement_sounds() -> Dict[str, str]:
46
  """Get the current sound files for each movement."""
47
  sounds = {}
48
+ from sound_manager import AudioEffectsProcessor
49
+ import tempfile
50
+ import soundfile as sf
51
+ # If in DJ mode, use effect-processed file if effect is ON
52
+ dj_mode = getattr(sound_manager, 'current_phase', None) == 'dj_effects'
53
  for movement, sound_file in sound_manager.current_sound_mapping.items():
54
+ if movement in ['left_hand', 'right_hand', 'left_leg', 'right_leg']:
55
+ if sound_file is not None:
56
  sound_path = sound_manager.sound_dir / sound_file
57
  if sound_path.exists():
58
+ if dj_mode and sound_manager.active_effects.get(movement, False):
59
+ # Load audio, apply effect, save to temp file
60
+ data, sr = sf.read(str(sound_path))
61
+ if len(data.shape) > 1:
62
+ data = np.mean(data, axis=1)
63
+ processed = AudioEffectsProcessor.process_layer_with_effects(
64
+ data, sr, movement, sound_manager.active_effects
65
+ )
66
+ # Save to temp file
67
+ tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
68
+ sf.write(tmp.name, processed, sr)
69
+ print(f"DEBUG: Playing PROCESSED audio for {movement}: {tmp.name}")
70
+ sounds[movement] = tmp.name
71
+ else:
72
+ print(f"DEBUG: Playing ORIGINAL audio for {movement}: {sound_path.resolve()}")
73
+ sounds[movement] = str(sound_path.resolve())
74
  return sounds
75
 
76
+ def create_eeg_plot(eeg_data: np.ndarray, target_movement: str, predicted_name: str, confidence: float, sound_added: bool) -> plt.Figure:
77
+ fig, axes = plt.subplots(1, 2, figsize=(10, 4))
78
+ axes = axes.flatten()
79
+ time_points = np.arange(eeg_data.shape[1]) / 200
80
+ channel_names = ['C3', 'C4']
81
+ for i in range(min(2, eeg_data.shape[0])):
82
+ color = 'green' if sound_added else 'blue'
83
+ axes[i].plot(time_points, eeg_data[i], color=color, linewidth=1)
84
+ axes[i].set_title(f'{channel_names[i] if i < len(channel_names) else f"Channel {i+1}"}')
85
+ axes[i].set_xlabel('Time (s)')
86
+ axes[i].set_ylabel('Amplitude (µV)')
87
+ axes[i].grid(True, alpha=0.3)
88
+ title = f"Target: {target_movement.replace('_', ' ').title()} | Predicted: {predicted_name.replace('_', ' ').title()} ({confidence:.2f})"
89
+ fig.suptitle(title, fontsize=12, fontweight='bold')
90
+ fig.tight_layout()
91
+ plt.close(fig)
92
+ return fig
93
+
94
+ def format_composition_summary(composition_info: Dict) -> str:
95
+ if not composition_info.get('layers_by_cycle'):
96
+ return "No composition layers yet"
97
+ summary = []
98
+ for cycle, layers in composition_info['layers_by_cycle'].items():
99
+ summary.append(f"Cycle {cycle + 1}: {len(layers)} layers")
100
+ for layer in layers:
101
+ movement = layer.get('movement', 'unknown')
102
+ confidence = layer.get('confidence', 0)
103
+ summary.append(f" • {movement.replace('_', ' ').title()} ({confidence:.2f})")
104
+ # DJ Effects Status removed from status tab as requested
105
+ return "\n".join(summary) if summary else "No composition layers"
106
+
107
+ # --- Main Logic ---
108
  def start_composition():
109
+ '''
110
+ Start the composition process.
111
+ '''
112
  global app_state
 
 
113
  if not app_state['composition_active']:
114
  app_state['composition_active'] = True
115
+ sound_manager.start_new_cycle()
116
+ print(f"DEBUG: [start_composition] current_phase={sound_manager.current_phase}, movements_completed={sound_manager.movements_completed}")
117
  if app_state['demo_data'] is None:
118
  return "❌ No data", "❌ No data", "❌ No data", None, None, None, None, None, None, "No EEG data available"
119
+ # Force first trial to always be left_hand/instrumental
120
+ if len(sound_manager.movements_completed) == 0:
121
+ next_movement = 'left_hand'
122
+ left_hand_label = [k for k, v in classifier.class_names.items() if v == 'left_hand'][0]
123
+ import numpy as np
124
+ matching_indices = np.where(app_state['demo_labels'] == left_hand_label)[0]
125
+ chosen_idx = np.random.choice(matching_indices)
126
+ epoch_data = app_state['demo_data'][chosen_idx]
127
+ true_label = left_hand_label
128
+ true_label_name = 'left_hand'
129
+ else:
130
+ epoch_data, true_label = data_processor.simulate_real_time_data(app_state['demo_data'], app_state['demo_labels'], mode="class_balanced")
131
+ true_label_name = classifier.class_names[true_label]
132
+ next_movement = sound_manager.get_current_target_movement()
133
+ if next_movement == "cycle_complete":
134
+ print("DEBUG: [start_composition] Transitioning to DJ mode!")
135
+ return continue_dj_phase()
136
  predicted_class, confidence, probabilities = classifier.predict(epoch_data)
137
  predicted_name = classifier.class_names[predicted_class]
138
+ print(f"TRIAL: true_label={true_label_name}, presented_target={next_movement}, predicted={predicted_name}")
139
+ # Only add sound if confidence > threshold, predicted == true label, and true label matches the prompt
140
+ if confidence > CONFIDENCE_THRESHOLD and predicted_name == true_label_name:
141
+ result = sound_manager.process_classification(predicted_name, confidence, CONFIDENCE_THRESHOLD, force_add=True)
 
 
 
 
 
 
 
 
142
  else:
143
+ result = {'sound_added': False}
144
+ fig = create_eeg_plot(epoch_data, true_label_name, predicted_name, confidence, result['sound_added'])
 
 
 
 
145
 
146
+ # Only play completed movement sounds (layered)
147
+ sounds = get_movement_sounds()
148
+ completed_movements = sound_manager.movements_completed
 
 
 
 
 
149
 
150
+ # Assign audio paths only for completed movements
151
+ left_hand_audio = sounds.get('left_hand') if 'left_hand' in completed_movements else None
152
+ right_hand_audio = sounds.get('right_hand') if 'right_hand' in completed_movements else None
153
+ left_leg_audio = sounds.get('left_leg') if 'left_leg' in completed_movements else None
154
+ right_leg_audio = sounds.get('right_leg') if 'right_leg' in completed_movements else None
 
 
 
155
 
156
+ print("DEBUG: movement sound paths:", sounds)
157
+ print("DEBUG: completed movements:", completed_movements)
158
+ print("DEBUG: left_hand_audio:", left_hand_audio, "exists:", os.path.exists(left_hand_audio) if left_hand_audio else None)
159
+ print("DEBUG: right_hand_audio:", right_hand_audio, "exists:", os.path.exists(right_hand_audio) if right_hand_audio else None)
160
+ print("DEBUG: left_leg_audio:", left_leg_audio, "exists:", os.path.exists(left_leg_audio) if left_leg_audio else None)
161
+ print("DEBUG: right_leg_audio:", right_leg_audio, "exists:", os.path.exists(right_leg_audio) if right_leg_audio else None)
166
+ # 2. Movement Commands: show mapping for all movements
167
+ movement_emojis = {
168
+ "left_hand": "👈",
169
+ "right_hand": "👉",
170
+ "left_leg": "🦵",
171
+ "right_leg": "🦵",
172
+ }
173
+ movement_command_lines = []
174
+ for movement in ["left_hand", "right_hand", "left_leg", "right_leg"]:
175
+ sound_file = sound_manager.current_sound_mapping.get(movement, "")
176
+ instrument_type = ""
177
+ for key in ["bass", "drums", "instruments", "vocals"]:
178
+ if key in sound_file.lower():
179
+ instrument_type = key if key != "instruments" else "instrument"
180
+ break
181
+ pretty_movement = movement.replace("_", " ").title()
182
+ pretty_instrument = instrument_type.capitalize() if instrument_type else "--"
183
+ emoji = movement_emojis.get(movement, "")
184
+ movement_command_lines.append(f"{emoji} {pretty_movement}: {pretty_instrument}")
185
+ movement_command_text = "🎼 Composition Mode - Movement to Layers Mapping\n" + "\n".join(movement_command_lines)
187
+ # 3. Next Trial: will be set dynamically in timer_tick
188
+ next_trial_text = ""
190
  composition_info = sound_manager.get_composition_info()
191
  status_text = format_composition_summary(composition_info)
192
  return (
193
+ movement_command_text,
194
+ next_trial_text,
 
 
195
  fig,
196
  left_hand_audio,
197
+ right_hand_audio,
198
  left_leg_audio,
199
  right_leg_audio,
200
+ status_text
 
 
 
201
  )
202
 
203
+ def continue_dj_phase():
 
204
  global app_state
205
+ print(f"DEBUG: [continue_dj_phase] Entered DJ mode. current_phase={sound_manager.current_phase}")
206
  if not app_state['composition_active']:
207
  return "❌ Not active", "❌ Not active", "❌ Not active", None, None, None, None, None, None, "Click 'Start Composing' first"
 
208
  if app_state['demo_data'] is None:
209
  return "❌ No data", "❌ No data", "❌ No data", None, None, None, None, None, None, "No EEG data available"
210
+ # DJ phase: classify and apply effects, but always play all layered sounds
211
+ epoch_data, true_label = data_processor.simulate_real_time_data(app_state['demo_data'], app_state['demo_labels'], mode="class_balanced")
212
  predicted_class, confidence, probabilities = classifier.predict(epoch_data)
213
  predicted_name = classifier.class_names[predicted_class]
214
+ # Toggle the DJ effect once when confidence is high; reuse the result for the plot colour
215
+ if confidence > CONFIDENCE_THRESHOLD:
216
+ print(f"DEBUG: [continue_dj_phase] Toggling DJ effect for movement: {predicted_name}")
217
+ result = sound_manager.toggle_dj_effect(predicted_name, brief=True, duration=1.0)
218
+ else:
219
+ result = None
220
+ true_label_name = classifier.class_names[true_label]
221
+ # Only turn the plot green if the effect was actually applied
222
+ effect_applied = result.get("effect_applied", False) if result else False
226
+ fig = create_eeg_plot(epoch_data, true_label_name, predicted_name, confidence, effect_applied)
227
+ # Always play all completed movement sounds (layered)
228
  sounds = get_movement_sounds()
229
  completed_movements = sound_manager.movements_completed
230
+ left_hand_audio = sounds.get('left_hand') if 'left_hand' in completed_movements else None
231
+ right_hand_audio = sounds.get('right_hand') if 'right_hand' in completed_movements else None
232
+ left_leg_audio = sounds.get('left_leg') if 'left_leg' in completed_movements else None
233
+ right_leg_audio = sounds.get('right_leg') if 'right_leg' in completed_movements else None
234
+ # Show DJ effect mapping for each movement with ON/OFF status and correct instrument mapping
235
+ movement_map = {
236
+ "left_hand": {"effect": "Echo", "instrument": "Instrument"},
237
+ "right_hand": {"effect": "Low Pass", "instrument": "Bass"},
238
+ "left_leg": {"effect": "Compressor", "instrument": "Drums"},
239
+ "right_leg": {"effect": "High Pass", "instrument": "Vocals"},
240
+ }
241
+ emoji_map = {"left_hand": "👈", "right_hand": "👉", "left_leg": "🦵", "right_leg": "🦵"}
242
+ # Get effect ON/OFF status from sound_manager.active_effects
243
+ movement_command_lines = []
244
+ for m in ["left_hand", "right_hand", "left_leg", "right_leg"]:
245
+ status = "ON" if sound_manager.active_effects.get(m, False) else "off"
246
+ movement_command_lines.append(f"{emoji_map[m]} {m.replace('_', ' ').title()}: {movement_map[m]['effect']} [{status}] → {movement_map[m]['instrument']}")
247
+ target_text = "🎧 DJ Mode - Movement to Effect Mapping\n" + "\n".join(movement_command_lines)
248
+ # In DJ mode, Next Trial should only show the prompt, not the predicted/target movement
249
+ predicted_text = "Imagine next movement"
250
  composition_info = sound_manager.get_composition_info()
251
  status_text = format_composition_summary(composition_info)
252
+ # Ensure exactly 10 outputs: [textbox, textbox, plot, audio, audio, audio, audio, textbox, timer, button]
253
+ # Use fig for the plot, and fill all outputs with correct types
254
  return (
255
+ target_text, # Movement Commands (textbox)
256
+ predicted_text, # Next Trial (textbox)
257
+ fig, # EEG Plot (plot)
258
+ left_hand_audio, # Left Hand (audio)
259
+ right_hand_audio, # Right Hand (audio)
260
+ left_leg_audio, # Left Leg (audio)
261
+ right_leg_audio, # Right Leg (audio)
262
+ status_text, # Composition Status (textbox)
263
+ gr.update(), # Timer (update object)
264
+ gr.update() # Continue DJ Button (update object)
265
  )
266
 
267
+ # --- Gradio UI ---
268
+ def create_interface():
269
+ with gr.Blocks(title="EEG Motor Imagery Music Composer", theme=gr.themes.Citrus()) as demo:
270
+ with gr.Tabs():
271
+ with gr.TabItem("Automatic Music Composition"):
272
+ gr.Markdown("# 🧠🎵 EEG Motor Imagery Rehabilitation Composer")
273
+ #gr.Markdown("**Therapeutic Brain-Computer Interface for Motor Recovery**\n\nCreate beautiful music compositions using your brain signals! This rehabilitation tool helps strengthen motor imagery skills while creating personalized musical pieces.")
274
+ gr.Markdown("""
275
+ **How the Task Works**
277
+ This app has **two stages**:
279
+ 1. **Music Composition Stage**: Use your motor imagery (imagine moving your left hand, right hand, left leg, or right leg) to add musical layers. Each correct, high-confidence brain signal prediction adds a new sound to your composition. The system will prompt you with a movement to imagine, and you should focus on that movement until the next prompt.
280
+
281
+ 2. **DJ Effects Stage**: After all four movements are completed, you enter DJ mode. Here, you can apply effects and control playback of your own composition using new commands. The interface and available controls will change to let you experiment with your music.
282
+
283
+ > **Note:** In DJ mode, each effect is only triggered every 4th time you perform the same movement. This prevents tracks from reloading too frequently.
284
+
285
+ **Commands and controls will change between stages.** Follow the on-screen instructions for each phase.
286
+ """)
287
+
288
  with gr.Row():
 
289
  with gr.Column(scale=2):
290
+ start_btn = gr.Button("🎵 Start Composing", variant="primary", size="lg")
291
+ stop_btn = gr.Button("🛑 Stop", variant="stop", size="md")
292
+ continue_btn = gr.Button("⏭️ Continue DJ Phase", variant="primary", size="lg", visible=False)
293
+ timer = gr.Timer(value=1.0, active=False) # 1-second ticks; timer_tick paces trials (3 s blank, 1 s prompt, then trial)
294
+ predicted_display = gr.Textbox(label="🧠 Movement Commands", interactive=False, value="--", lines=4)
295
+ timer_display = gr.Textbox(label="⏱️ Next Trial", interactive=False, value="--")
296
  eeg_plot = gr.Plot(label="EEG Data Visualization")
 
 
297
  with gr.Column(scale=1):
 
 
 
298
  left_hand_sound = gr.Audio(label="👈 Left Hand", interactive=False, autoplay=True, visible=True)
299
+ right_hand_sound = gr.Audio(label="👉 Right Hand", interactive=False, autoplay=True, visible=True)
300
  left_leg_sound = gr.Audio(label="🦵 Left Leg", interactive=False, autoplay=True, visible=True)
301
  right_leg_sound = gr.Audio(label="🦵 Right Leg", interactive=False, autoplay=True, visible=True)
 
 
 
302
  composition_status = gr.Textbox(label="Composition Status", interactive=False, lines=5)
304
+ def start_and_activate_timer():
305
+ result = start_composition()
306
+ last_trial_result[:] = result # Initialize with first trial result
307
+ if "DJ Mode" not in result[0]:
308
+ return (*result, gr.update(active=True), gr.update(visible=False))
309
+ else:
310
+ return (*result, gr.update(active=False), gr.update(visible=True))
311
+
312
+ # ITI logic: 3s blank, 1s prompt, then trial
313
+ timer_counter = {"count": 0}
314
+ last_trial_result = [None] * 9 # Placeholder; replaced by the 8-item (composition) or 10-item (DJ) output tuple before the timer runs
315
+ def timer_tick():
316
+ # 0,1,2: blank, 3: prompt, 4: trial
317
+ if timer_counter["count"] < 3:
318
+ timer_counter["count"] += 1
319
+ # Show blank prompt, keep last outputs
320
+ if len(last_trial_result) == 8:
321
+ return (*last_trial_result, gr.update(active=True), gr.update(visible=False))
322
+ elif len(last_trial_result) == 10:
323
+ # DJ mode: blank prompt
324
+ result = list(last_trial_result)
325
+ result[1] = ""
326
+ return tuple(result)
327
+ else:
328
+ raise ValueError(f"Unexpected last_trial_result length: {len(last_trial_result)}")
329
+ elif timer_counter["count"] == 3:
330
+ timer_counter["count"] += 1
331
+ # Show prompt
332
+ result = list(last_trial_result)
333
+ result[1] = "Imagine next movement"
334
+ if len(result) == 8:
335
+ return (*result, gr.update(active=True), gr.update(visible=False))
336
+ elif len(result) == 10:
337
+ return tuple(result)
338
+ else:
339
+ raise ValueError(f"Unexpected result length in prompt: {len(result)}")
340
+ else:
341
+ timer_counter["count"] = 0
342
+ # Run trial
343
+ result = list(start_composition())
344
+ last_trial_result[:] = result # Save for next blanks/prompts
345
+ if len(result) == 8:
346
+ # Pre-DJ mode: add timer and button updates
347
+ if any(isinstance(x, str) and "DJ Mode" in x for x in result):
348
+ print("DEBUG: [timer_tick] DJ mode detected in outputs, stopping timer and showing continue button.")
349
+ return (*result, gr.update(active=False), gr.update(visible=True))
350
+ else:
351
+ print("DEBUG: [timer_tick] Not in DJ mode, continuing trials.")
352
+ return (*result, gr.update(active=True), gr.update(visible=False))
353
+ elif len(result) == 10:
354
+ print("DEBUG: [timer_tick] Already in DJ mode, returning result as is.")
355
+ return tuple(result)
356
+ else:
357
+ raise ValueError(f"Unexpected result length in timer_tick: {len(result)}")
358
+
359
+ def continue_dj():
360
+ result = continue_dj_phase()
361
+ if len(result) == 8:
362
+ return (*result, gr.update(active=False), gr.update(visible=True))
363
+ elif len(result) == 10:
364
+ return result
365
+ else:
366
+ raise ValueError(f"Unexpected result length in continue_dj: {len(result)}")
367
+ start_btn.click(
368
+ fn=start_and_activate_timer,
369
+ outputs=[predicted_display, timer_display, eeg_plot,
370
+ left_hand_sound, right_hand_sound, left_leg_sound, right_leg_sound, composition_status, timer, continue_btn]
371
+ )
372
+ timer_event = timer.tick(
373
+ fn=timer_tick,
374
+ outputs=[predicted_display, timer_display, eeg_plot,
375
+ left_hand_sound, right_hand_sound, left_leg_sound, right_leg_sound, composition_status, timer, continue_btn]
376
+ )
377
+ def stop_composing():
378
+ timer_counter["count"] = 0
379
+ last_trial_result[:] = ["--"] * 9
380
+ app_state['composition_active'] = False # Ensure new cycle on next start
381
+ # Clear UI and deactivate timer, hide continue button
382
+ return ("--", "Stopped", None, None, None, None, None, "Stopped", gr.update(active=False), gr.update(visible=False))
383
+
384
+ stop_btn.click(
385
+ fn=stop_composing,
386
+ outputs=[predicted_display, timer_display, eeg_plot,
387
+ left_hand_sound, right_hand_sound, left_leg_sound, right_leg_sound, composition_status, timer, continue_btn],
388
+ cancels=[timer_event]
389
  )
390
+ continue_btn.click(
391
+ fn=continue_dj,
392
+ outputs=[predicted_display, timer_display, eeg_plot,
393
+ left_hand_sound, right_hand_sound, left_leg_sound, right_leg_sound, composition_status, timer, continue_btn
394
  )
395
+
396
+ with gr.TabItem("Manual Classifier"):
397
+ gr.Markdown("# 🧑‍💻 Manual Classifier Test")
398
+ gr.Markdown("Select a movement and run the classifier manually on a random epoch for that movement. Results will be accumulated below.")
399
+ movement_dropdown = gr.Dropdown(choices=["left_hand", "right_hand", "left_leg", "right_leg"], label="Select Movement")
400
+ manual_btn = gr.Button("Run Classifier", variant="primary")
401
+ manual_predicted = gr.Textbox(label="Predicted Class", interactive=False)
402
+ manual_confidence = gr.Textbox(label="Confidence", interactive=False)
403
+ manual_plot = gr.Plot(label="EEG Data Visualization")
404
+ manual_probs = gr.Plot(label="Class Probabilities")
405
+ manual_confmat = gr.Plot(label="Confusion Matrix (Session)")
406
+
407
+ # Session state for confusion matrix
408
+ from collections import defaultdict
409
+ session_confmat = defaultdict(lambda: defaultdict(int))
410
+
411
+ def manual_classify(selected_movement):
412
+ import matplotlib.pyplot as plt
413
+ import numpy as np
414
+ if app_state['demo_data'] is None or app_state['demo_labels'] is None:
415
+ return "No data", "No data", None, None, None
416
+ label_idx = [k for k, v in classifier.class_names.items() if v == selected_movement][0]
417
+ matching_indices = np.where(app_state['demo_labels'] == label_idx)[0]
418
+ if len(matching_indices) == 0:
419
+ return "No data for this movement", "", None, None, None
420
+ chosen_idx = np.random.choice(matching_indices)
421
+ epoch_data = app_state['demo_data'][chosen_idx]
422
+ predicted_class, confidence, probs = classifier.predict(epoch_data)
423
+ predicted_name = classifier.class_names[predicted_class]
424
+ # Update confusion matrix
425
+ session_confmat[selected_movement][predicted_name] += 1
426
+ # Plot confusion matrix
427
+ classes = ["left_hand", "right_hand", "left_leg", "right_leg"]
428
+ confmat = np.zeros((4, 4), dtype=int)
429
+ for i, true_m in enumerate(classes):
430
+ for j, pred_m in enumerate(classes):
431
+ confmat[i, j] = session_confmat[true_m][pred_m]
432
+ fig_confmat, ax = plt.subplots(figsize=(4, 4))
433
+ ax.imshow(confmat, cmap="Blues")
434
+ ax.set_xticks(np.arange(4))
435
+ ax.set_yticks(np.arange(4))
436
+ ax.set_xticklabels(classes, rotation=45, ha="right")
437
+ ax.set_yticklabels(classes)
438
+ ax.set_xlabel("Predicted")
439
+ ax.set_ylabel("True")
440
+ for i in range(4):
441
+ for j in range(4):
442
+ ax.text(j, i, str(confmat[i, j]), ha="center", va="center", color="black")
443
+ fig_confmat.tight_layout()
444
+ # Plot class probabilities
445
+ if isinstance(probs, dict):
446
+ probs_list = [probs.get(cls, 0.0) for cls in classes]
447
+ else:
448
+ probs_list = list(probs)
449
+ fig_probs, ax_probs = plt.subplots(figsize=(4, 2))
450
+ ax_probs.bar(classes, probs_list)
451
+ ax_probs.set_ylabel("Probability")
452
+ ax_probs.set_ylim(0, 1)
453
+ fig_probs.tight_layout()
454
+ # EEG plot
455
+ fig = create_eeg_plot(epoch_data, selected_movement, predicted_name, confidence, False)
456
+ # Close all open figures to avoid warnings
457
+ plt.close(fig_confmat)
458
+ plt.close(fig_probs)
459
+ plt.close(fig)
460
+ return predicted_name, f"{confidence:.2f}", fig, fig_probs, fig_confmat
461
+
462
+ manual_btn.click(
463
+ fn=manual_classify,
464
+ inputs=[movement_dropdown],
465
+ outputs=[manual_predicted, manual_confidence, manual_plot, manual_probs, manual_confmat]
466
  )
467
  return demo
468
 
469
  if __name__ == "__main__":
470
  demo = create_interface()
471
+ demo.launch(server_name="0.0.0.0", server_port=7867)
classifier.py CHANGED
@@ -8,7 +8,8 @@ Based on the ShallowFBCSPNet architecture from the original eeg_motor_imagery.py
8
  import torch
9
  import torch.nn as nn
10
  import numpy as np
11
- from braindecode.models import ShallowFBCSPNet
 
12
  from typing import Dict, Tuple, Optional
13
  import os
14
  from sklearn.metrics import accuracy_score
@@ -20,7 +21,7 @@ class MotorImageryClassifier:
20
  Motor imagery classifier using ShallowFBCSPNet model.
21
  """
22
 
23
- def __init__(self, model_path: str = "shallow_weights_all.pth"):
24
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
  self.model = None
26
  self.model_path = model_path
@@ -46,8 +47,27 @@ class MotorImageryClassifier:
46
 
47
  if os.path.exists(self.model_path):
48
  try:
49
- state_dict = torch.load(self.model_path, map_location=self.device)
50
- self.model.load_state_dict(state_dict)
51
  self.model.eval()
52
  self.is_loaded = True
53
  print(f"✅ Pre-trained model loaded successfully from {self.model_path}")
 
8
  import torch
9
  import torch.nn as nn
10
  import numpy as np
11
+ from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
12
+ from braindecode.modules.layers import Ensure4d # necessary for loading
13
  from typing import Dict, Tuple, Optional
14
  import os
15
  from sklearn.metrics import accuracy_score
 
21
  Motor imagery classifier using ShallowFBCSPNet model.
22
  """
23
 
24
+ def __init__(self, model_path: str = "model.pth"):
25
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
  self.model = None
27
  self.model_path = model_path
 
47
 
48
  if os.path.exists(self.model_path):
49
  try:
50
+ # Load only the state_dict, using weights_only=True and allowlist ShallowFBCSPNet
51
+ with torch.serialization.safe_globals([Ensure4d, ShallowFBCSPNet]):
52
+ checkpoint = torch.load(
53
+ self.model_path,
54
+ map_location=self.device,
55
+ weights_only=False # must be False to allow objects
56
+ )
57
+
58
+ # If checkpoint is a state_dict (dict of tensors)
59
+ if isinstance(checkpoint, dict) and all(isinstance(k, str) for k in checkpoint.keys()):
60
+ self.model.load_state_dict(checkpoint)
61
+
62
+ # If checkpoint is the full model object
63
+ elif isinstance(checkpoint, ShallowFBCSPNet):
64
+ self.model = checkpoint.to(self.device)
65
+
66
+ else:
67
+ raise ValueError("Unknown checkpoint format")
68
+
69
+
70
+ #self.model.load_state_dict(state_dict)
71
  self.model.eval()
72
  self.is_loaded = True
73
  print(f"✅ Pre-trained model loaded successfully from {self.model_path}")
config.py CHANGED
@@ -16,7 +16,7 @@ SOUND_DIR = BASE_DIR / "sounds"
16
  MODEL_DIR = BASE_DIR
17
 
18
  # Model settings
19
- MODEL_PATH = MODEL_DIR / "shallow_weights_all.pth"
20
  # Model architecture: Always uses ShallowFBCSPNet from braindecode
21
  # If pre-trained weights not found, will train using LOSO on available data
22
 
 
16
  MODEL_DIR = BASE_DIR
17
 
18
  # Model settings
19
+ MODEL_PATH = MODEL_DIR / "model.pth"
20
  # Model architecture: Always uses ShallowFBCSPNet from braindecode
21
  # If pre-trained weights not found, will train using LOSO on available data
22
 
data_processor.py CHANGED
@@ -81,29 +81,35 @@ class EEGDataProcessor:
81
  return events
82
 
83
  def create_epochs(self, raw: mne.io.RawArray, events: np.ndarray,
84
- tmin: float = 0, tmax: float = 1.5) -> mne.Epochs:
85
  """Create epochs from raw data and events."""
 
 
86
  epochs = mne.Epochs(
87
  raw,
88
  events=events,
89
- event_id=self.event_id,
90
  tmin=tmin,
91
  tmax=tmax,
92
  baseline=None,
93
  preload=True,
94
  )
95
-
96
  return epochs
97
 
98
  def process_files(self, file_paths: List[str]) -> Tuple[np.ndarray, np.ndarray]:
99
  """Process multiple EEG files and return combined data."""
100
  all_epochs = []
101
-
 
 
102
  for file_path in file_paths:
103
  signals, labels, channels, fs = self.load_mat_file(file_path)
104
  raw = self.create_raw_object(signals, channels, fs, drop_ground_electrodes=True)
105
  events = self.extract_events(labels)
106
- epochs = self.create_epochs(raw, events)
 
 
 
107
  all_epochs.append(epochs)
108
 
109
  if len(all_epochs) > 1:
 
81
  return events
82
 
83
  def create_epochs(self, raw: mne.io.RawArray, events: np.ndarray,
84
+ tmin: float = 0, tmax: float = 1.5, event_id=None) -> mne.Epochs:
85
  """Create epochs from raw data and events."""
86
+ if event_id is None:
87
+ event_id = self.event_id
88
  epochs = mne.Epochs(
89
  raw,
90
  events=events,
91
+ event_id=event_id,
92
  tmin=tmin,
93
  tmax=tmax,
94
  baseline=None,
95
  preload=True,
96
  )
 
97
  return epochs
98
 
99
  def process_files(self, file_paths: List[str]) -> Tuple[np.ndarray, np.ndarray]:
100
  """Process multiple EEG files and return combined data."""
101
  all_epochs = []
102
+ allowed_labels = {1, 2, 4, 6}
103
+ allowed_event_id = {k: v for k, v in self.event_id.items() if v in allowed_labels}
104
+
105
  for file_path in file_paths:
106
  signals, labels, channels, fs = self.load_mat_file(file_path)
107
  raw = self.create_raw_object(signals, channels, fs, drop_ground_electrodes=True)
108
  events = self.extract_events(labels)
109
+ # only keep allowed labels
110
+ events = events[np.isin(events[:, -1], list(allowed_labels))]
111
+ # create epochs only for allowed labels
112
+ epochs = self.create_epochs(raw, events, event_id=allowed_event_id)
113
  all_epochs.append(epochs)
114
 
115
  if len(all_epochs) > 1:
sound_library.py DELETED
@@ -1,803 +0,0 @@
1
- """
2
- Sound Management System for EEG Motor Imagery Classification
3
- -----------------------------------------------------------
4
- Handles sound mapping, layering, and music composition based on motor imagery predictions.
5
- Uses local SoundHelix audio files for different motor imagery classes.
6
- """
7
-
8
- import numpy as np
9
- import soundfile as sf
10
- import os
11
- from typing import Dict, List, Optional, Tuple
12
- import threading
13
- import time
14
- from pathlib import Path
15
- import tempfile
16
-
17
- # Professional audio effects processor using scipy and librosa
18
- from scipy import signal
19
- import librosa
20
-
21
- class AudioEffectsProcessor:
22
- @staticmethod
23
- def apply_echo(data: np.ndarray, samplerate: int, delay_time: float = 0.3, feedback: float = 0.4) -> np.ndarray:
24
- """Echo/delay effect (tempo-sync if delay_time is set to fraction of beat)."""
25
- try:
26
- delay_samples = int(delay_time * samplerate)
27
- echo_data = np.copy(data)
28
- for i in range(delay_samples, len(data)):
29
- echo_data[i] += feedback * echo_data[i - delay_samples]
30
- return 0.7 * data + 0.3 * echo_data
31
- except Exception as e:
32
- print(f"Echo failed: {e}")
33
- return data
34
- """Professional audio effects for DJ mode using scipy and librosa."""
35
-
36
- @staticmethod
37
- def apply_high_pass_filter(data: np.ndarray, samplerate: int, cutoff: float = 800.0) -> np.ndarray:
38
- """Apply high-pass filter to emphasize highs and cut lows."""
39
- try:
40
- # Design butterworth high-pass filter
41
- nyquist = samplerate / 2
42
- normalized_cutoff = cutoff / nyquist
43
- b, a = signal.butter(4, normalized_cutoff, btype='high', analog=False)
44
-
45
- # Apply filter
46
- filtered_data = signal.filtfilt(b, a, data)
47
- return filtered_data
48
- except Exception as e:
49
- print(f"High-pass filter failed: {e}")
50
- return data
51
-
52
- @staticmethod
53
- def apply_low_pass_filter(data: np.ndarray, samplerate: int, cutoff: float = 1200.0) -> np.ndarray:
54
- """Apply low-pass filter to emphasize lows and cut highs."""
55
- try:
56
- # Design butterworth low-pass filter
57
- nyquist = samplerate / 2
58
- normalized_cutoff = cutoff / nyquist
59
- b, a = signal.butter(4, normalized_cutoff, btype='low', analog=False)
60
-
61
- # Apply filter
62
- filtered_data = signal.filtfilt(b, a, data)
63
- return filtered_data
64
- except Exception as e:
65
- print(f"Low-pass filter failed: {e}")
66
- return data
67
-
68
- @staticmethod
69
- def apply_reverb(data: np.ndarray, samplerate: int, room_size: float = 0.5) -> np.ndarray:
70
- """Apply simple reverb effect using delay and feedback."""
71
- try:
72
- # Simple reverb using multiple delayed copies
73
- delay_samples = int(0.08 * samplerate) # 80ms delay
74
- decay = 0.4 * room_size
75
-
76
- # Create reverb buffer
77
- reverb_data = np.copy(data)
78
-
79
- # Add delayed copies with decay
80
- for i in range(3):
81
- delay = delay_samples * (i + 1)
82
- if delay < len(data):
83
- gain = decay ** (i + 1)
84
- reverb_data[delay:] += data[:-delay] * gain
85
-
86
- # Mix original with reverb
87
- return 0.7 * data + 0.3 * reverb_data
88
- except Exception as e:
89
- print(f"Reverb effect failed: {e}")
90
- return data
91
-
92
- @staticmethod
93
- def apply_echo(data: np.ndarray, samplerate: int, delay_time: float = 0.3, feedback: float = 0.4) -> np.ndarray:
94
- """Echo/delay effect (tempo-sync if delay_time is set to fraction of beat)."""
95
- try:
96
- delay_samples = int(delay_time * samplerate)
97
- echo_data = np.copy(data)
98
- for i in range(delay_samples, len(data)):
99
- echo_data[i] += feedback * echo_data[i - delay_samples]
100
- return 0.7 * data + 0.3 * echo_data
101
- except Exception as e:
102
- print(f"Echo failed: {e}")
103
- return data
104
-
105
- @staticmethod
106
- def apply_bass_boost(data: np.ndarray, samplerate: int, boost_db: float = 6.0) -> np.ndarray:
107
- """Apply bass boost using low-frequency shelving filter."""
108
- try:
109
- # Design low-shelf filter for bass boost
110
- freq = 200.0 # Bass frequency cutoff
111
- nyquist = samplerate / 2
112
- normalized_freq = freq / nyquist
113
-
114
- # Convert dB to linear gain
115
- gain = 10 ** (boost_db / 20)
116
-
117
- # Simple bass boost: amplify low frequencies
118
- b, a = signal.butter(2, normalized_freq, btype='low', analog=False)
119
- low_freq = signal.filtfilt(b, a, data)
120
-
121
- # Mix boosted lows with original
122
- boosted_data = data + (low_freq * (gain - 1))
123
-
124
- # Normalize to prevent clipping
125
- max_val = np.max(np.abs(boosted_data))
126
- if max_val > 0.95:
127
- boosted_data = boosted_data * 0.95 / max_val
128
-
129
- return boosted_data
130
- except Exception as e:
131
- print(f"Bass boost failed: {e}")
132
- return data
133
-
134
- # --- DJ MODE WRAPPER ---
135
- @staticmethod
136
- def process_layer_with_effects(audio_file: str, movement: str, active_effects: Dict[str, bool]) -> str:
137
- """Process a single layer with its corresponding DJ effect if active."""
138
- try:
139
- if not audio_file or not os.path.exists(audio_file):
140
- print(f"Invalid audio file: {audio_file}")
141
- return audio_file
142
- data, samplerate = sf.read(audio_file)
143
- if len(data.shape) > 1:
144
- data = np.mean(data, axis=1)
145
- processed_data = np.copy(data)
146
-
147
- # Map movement to effect
148
- effect_map = {
149
- "left_hand": AudioEffectsProcessor.apply_echo,
150
- "right_hand": AudioEffectsProcessor.apply_high_pass_filter,
151
- "left_leg": AudioEffectsProcessor.apply_reverb,
152
- "right_leg": AudioEffectsProcessor.apply_low_pass_filter,
153
- }
154
- effect_names = {
155
- "left_hand": "echo",
156
- "right_hand": "hpf",
157
- "left_leg": "rev",
158
- "right_leg": "lpf",
159
- }
160
- effect_func = effect_map.get(movement)
161
- effect_name = effect_names.get(movement, "clean")
162
- if active_effects.get(movement, False) and effect_func:
163
- processed_data = effect_func(processed_data, samplerate)
164
- suffix = f"_fx_{effect_name}"
165
- else:
166
- suffix = "_fx_clean"
167
-
168
- base_name = os.path.splitext(audio_file)[0]
169
- processed_file = f"{base_name}{suffix}.wav"
170
- try:
171
- sf.write(processed_file, processed_data, samplerate)
172
- print(f"🎛️ Layer processed: {os.path.basename(processed_file)}")
173
- return os.path.abspath(processed_file)
174
- except Exception as e:
175
- print(f"Failed to save processed layer: {e}")
176
- return os.path.abspath(audio_file)
177
- except Exception as e:
178
- print(f"Layer processing failed: {e}")
179
- return os.path.abspath(audio_file) if audio_file else None
180
-
181
- class SoundManager:
182
- """
183
- Manages cyclic sound composition for motor imagery classification.
184
- Supports full-cycle composition with user-customizable movement-sound mappings.
185
- """
186
-
187
- def __init__(self, sound_dir: str = "sounds", include_neutral_in_cycle: bool = False):
188
- # Available sound files (define FIRST)
189
- self.available_sounds = [
190
- "1_SoundHelix-Song-6_(Bass).wav",
191
- "1_SoundHelix-Song-6_(Drums).wav",
192
- "1_SoundHelix-Song-6_(Other).wav",
193
- "1_SoundHelix-Song-6_(Vocals).wav"
194
- ]
195
-
196
- self.sound_dir = Path(sound_dir)
197
- self.include_neutral_in_cycle = include_neutral_in_cycle
198
-
199
- # Composition state
200
- self.composition_layers = [] # All layers across all cycles
201
- self.current_cycle = 0
202
- self.current_step = 0 # Current step within cycle (0-5)
203
- self.cycle_complete = False
204
- self.completed_cycles = 0 # Track completed cycles for session management
205
- self.max_cycles = 2 # Rehabilitation session limit
206
-
207
- # DJ Effects phase management
208
- self.current_phase = "building" # "building" or "dj_effects"
209
- # self.mixed_composition_file = None # Path to current mixed composition
210
- self.active_effects = { # Track which effects are currently active
211
- "left_hand": False, # Volume fade
212
- "right_hand": False, # Filter sweep
213
- "left_leg": False, # Reverb
214
- "right_leg": False, # Tempo modulation
215
- #"tongue": False # Bass boost
216
- }
217
-
218
- # All possible movements (neutral is optional for composition)
219
- self.all_movements = ["left_hand", "right_hand", "neutral", "left_leg", "tongue", "right_leg"]
220
-
221
- # Active movements that contribute to composition (excluding neutral and tongue)
222
- self.active_movements = ["left_hand", "right_hand", "left_leg", "right_leg"]
223
-
224
- # Current cycle's random movement sequence (shuffled each cycle)
225
- self.current_movement_sequence = []
226
- self.movements_completed = set() # Track which movements have been successfully completed
227
- self._generate_new_sequence()
228
-
229
- # User-customizable sound mapping (can be updated after each cycle)
230
- # Assign each movement a unique sound file from available_sounds
231
- import random
232
- movements = ["left_hand", "right_hand", "left_leg", "right_leg"]
233
- sounds = self.available_sounds.copy()
234
- random.shuffle(sounds)
235
- self.current_sound_mapping = {movement: sound for movement, sound in zip(movements, sounds)}
236
- self.current_sound_mapping["neutral"] = None # No sound for neutral/rest state
237
-
238
- # Available sound files
239
- self.available_sounds = [
240
- "1_SoundHelix-Song-6_(Bass).wav",
241
- "1_SoundHelix-Song-6_(Drums).wav",
242
- "1_SoundHelix-Song-6_(Other).wav",
243
- "1_SoundHelix-Song-6_(Vocals).wav"
244
- ]
245
-
246
- # Load sound files
247
- self.loaded_sounds = {}
248
- self._load_sound_files()
249
-
250
- # Cycle statistics
251
- self.cycle_stats = {
252
- 'total_cycles': 0,
253
- 'successful_classifications': 0,
254
- 'total_attempts': 0
255
- }
256
-
257
- def _load_sound_files(self):
258
- """Load all available sound files into memory."""
259
- for class_name, filename in self.current_sound_mapping.items():
260
- if filename is None:
261
- continue
262
-
263
- file_path = self.sound_dir / filename
264
- if file_path.exists():
265
- try:
266
- data, sample_rate = sf.read(str(file_path))
267
- self.loaded_sounds[class_name] = {
268
- 'data': data,
269
- 'sample_rate': sample_rate,
270
- 'file_path': str(file_path)
271
- }
272
- print(f"Loaded sound for {class_name}: {filename}")
273
- except Exception as e:
274
- print(f"Error loading {filename}: {e}")
275
- else:
276
- print(f"Sound file not found: {file_path}")
277
-
278
- def get_sound_for_class(self, class_name: str) -> Optional[Dict]:
279
- """Get sound data for a specific motor imagery class."""
280
- return self.loaded_sounds.get(class_name)
281
-
282
- def _generate_new_sequence(self):
283
- """Generate a new random sequence of movements for the current cycle."""
284
- import random
285
- # Choose which movements to include based on configuration
286
- # Always use exactly 5 movements per cycle
287
- movements_for_cycle = self.active_movements.copy()
288
- random.shuffle(movements_for_cycle)
289
- self.current_movement_sequence = movements_for_cycle[:5]
290
- self.movements_completed = set()
291
- cycle_size = len(self.current_movement_sequence)
292
- print(f"🎯 New random sequence ({cycle_size} movements): {' → '.join([m.replace('_', ' ').title() for m in self.current_movement_sequence])}")
293
- print(f"DEBUG: Movement sequence length = {len(self.current_movement_sequence)}")
294
-
295
- def get_current_target_movement(self) -> str:
296
- """Get the current movement the user should imagine."""
297
- if self.current_step < len(self.current_movement_sequence):
298
- return self.current_movement_sequence[self.current_step]
299
- return "cycle_complete"
300
-
301
- def get_next_random_movement(self) -> str:
302
- """Get a random movement from those not yet completed in this cycle."""
303
- # Use the same movement set as the current cycle
304
- cycle_movements = self.all_movements if self.include_neutral_in_cycle else self.active_movements
305
- remaining_movements = [m for m in cycle_movements if m not in self.movements_completed]
306
- if not remaining_movements:
307
- return "cycle_complete"
308
-
309
- import random
310
- return random.choice(remaining_movements)
311
-
312
- def process_classification(self, predicted_class: str, confidence: float, threshold: float = 0.7) -> Dict:
313
- """
314
- Process a classification result in the context of the current cycle.
315
- Uses random movement prompting - user can choose any movement at any time.
316
-
317
- Args:
318
- predicted_class: The predicted motor imagery class
319
- confidence: Confidence score
320
- threshold: Minimum confidence to add sound layer
321
-
322
- Returns:
323
- Dictionary with processing results and next action
324
- """
325
- self.cycle_stats['total_attempts'] += 1
326
-
327
- # Check if this movement was already completed in this cycle
328
- already_completed = predicted_class in self.movements_completed
329
-
330
- result = {
331
- 'predicted_class': predicted_class,
332
- 'confidence': confidence,
333
- 'above_threshold': confidence >= threshold,
334
- 'already_completed': already_completed,
335
- 'sound_added': False,
336
- 'cycle_complete': False,
337
- 'next_action': 'continue',
338
- 'movements_remaining': len(self.current_movement_sequence) - len(self.movements_completed),
339
- 'movements_completed': list(self.movements_completed)
340
- }
341
-
342
- # Check if prediction is above threshold and not already completed
343
- if result['above_threshold'] and not already_completed:
344
- # Get sound file for this movement
345
- sound_file = self.current_sound_mapping.get(predicted_class)
346
-
347
- if sound_file is None:
348
- # No sound for this movement (e.g., neutral), but still count as completed
349
- self.movements_completed.add(predicted_class)
350
- result['sound_added'] = False # No sound, but movement completed
351
- self.cycle_stats['successful_classifications'] += 1
352
- cycle_size = len(self.current_movement_sequence)
353
- print(f"✓ Cycle {self.current_cycle+1}: {predicted_class} completed (no sound) ({len(self.movements_completed)}/{cycle_size} complete)")
354
- elif predicted_class in self.loaded_sounds:
355
- # Add sound layer
356
- layer_info = {
357
- 'cycle': self.current_cycle,
358
- 'step': len(self.movements_completed), # Step based on number completed
359
- 'movement': predicted_class,
360
- 'sound_file': sound_file,
361
- 'confidence': confidence,
362
- 'timestamp': time.time(),
363
- 'sound_data': self.loaded_sounds[predicted_class]
364
- }
365
- self.composition_layers.append(layer_info)
366
- self.movements_completed.add(predicted_class)
367
- result['sound_added'] = True
368
-
369
- print(f"DEBUG: Added layer {len(self.composition_layers)}, total layers now: {len(self.composition_layers)}")
370
- print(f"DEBUG: Composition layers: {[layer['movement'] for layer in self.composition_layers]}")
371
-
372
- # Return individual sound file for this classification
373
- sound_path = os.path.join(self.sound_dir, sound_file)
374
- result['audio_file'] = sound_path if os.path.exists(sound_path) else None
375
-
376
- # Also create mixed composition for potential saving (but don't return it)
377
- mixed_file = self.get_current_mixed_composition()
378
- result['mixed_composition'] = mixed_file
379
-
380
- self.cycle_stats['successful_classifications'] += 1
381
- cycle_size = len(self.current_movement_sequence)
382
- print(f"✓ Cycle {self.current_cycle+1}: Added {sound_file} for {predicted_class} ({len(self.movements_completed)}/{cycle_size} complete)")
383
-
384
- # Check if cycle is complete (all movements in current sequence completed)
385
- cycle_movements = self.all_movements if self.include_neutral_in_cycle else self.active_movements
386
- if len(self.movements_completed) >= len(cycle_movements):
387
- result['cycle_complete'] = True
388
- result['next_action'] = 'remap_sounds'
389
- self.cycle_complete = True
390
- cycle_size = len(cycle_movements)
391
- print(f"🎵 Cycle {self.current_cycle+1} complete! All {cycle_size} movements successfully classified!")
392
-
393
- return result
394
-
395
- def start_new_cycle(self, new_sound_mapping: Dict[str, str] = None):
396
- """Start a new composition cycle with optional new sound mapping."""
397
- if new_sound_mapping:
398
- self.current_sound_mapping.update(new_sound_mapping)
399
- print(f"Updated sound mapping for cycle {self.current_cycle+2}")
400
-
401
- self.current_cycle += 1
402
- self.current_step = 0
403
- self.cycle_complete = False
404
- self.cycle_stats['total_cycles'] += 1
405
-
406
- # Generate new random sequence for this cycle
407
- self._generate_new_sequence()
408
-
409
- print(f"🔄 Starting Cycle {self.current_cycle+1}")
410
- if self.current_cycle == 1:
411
- print("💪 Let's create your first brain-music composition!")
412
- elif self.current_cycle == 2:
413
- print("Great progress! Let's create your second composition!")
414
- else:
415
- print("🎯 Imagine ANY movement - you can choose the order!")
416
-
417
- def should_end_session(self) -> bool:
418
- """Check if the rehabilitation session should end after max cycles."""
419
- return self.completed_cycles >= self.max_cycles
420
-
421
- def complete_current_cycle(self):
422
- """Mark current cycle as complete and track progress."""
423
- self.completed_cycles += 1
424
- print(f"✅ Cycle {self.current_cycle} completed! ({self.completed_cycles}/{self.max_cycles} compositions finished)")
425
-
426
- def transition_to_dj_phase(self):
427
- """Transition from building phase to DJ effects phase."""
428
- # Only start DJ mode if all 4 sound layers are present (not just movements completed)
429
- unique_sounds = set()
430
- for layer in self.composition_layers:
431
- if layer.get('sound_file'):
432
- unique_sounds.add(layer['sound_file'])
433
- print(f"DEBUG: Unique sound files in composition_layers: {unique_sounds}")
434
- print(f"DEBUG: Number of unique sounds: {len(unique_sounds)}")
435
- if len(unique_sounds) >= 4:
436
- self.current_phase = "dj_effects"
437
- # self._create_mixed_composition()
438
- # just keep current stems running
439
- print("🎵 Composition Complete! Transitioning to DJ Effects Phase...")
440
- print("🎧 You are now the DJ! Use movements to control effects:")
441
- print(" 👈 Left Hand: Volume Fade")
442
- print(" 👉 Right Hand: High Pass Filter")
443
- print(" 🦵 Left Leg: Reverb Effect")
444
- print(" 🦵 Right Leg: Low Pass Filter")
445
- #print(" 👅 Tongue: Bass Boost")
446
- return True
447
- else:
448
- print("DEBUG: Not enough unique sounds to transition to DJ mode.")
449
- return False
450
-
451
- # def _create_mixed_composition(self):
452
- # """Create a mixed audio file from all completed layers."""
453
- # try:
454
- # import hashlib
455
- # movement_hash = hashlib.md5(str(sorted(self.movements_completed)).encode()).hexdigest()[:8]
456
- # self.mixed_composition_file = os.path.abspath(f"mixed_composition_{movement_hash}.wav")
457
-
458
- # # Try to use the first available completed movement's audio file
459
- # for movement in self.movements_completed:
460
- # if movement in self.current_sound_mapping and self.current_sound_mapping[movement] is not None:
461
- # sound_file = os.path.join(self.sound_dir, self.current_sound_mapping[movement])
462
- # if os.path.exists(sound_file):
463
- # self.mixed_composition_file = os.path.abspath(sound_file)
464
- # print(f"📀 Using existing audio as mixed composition: {self.mixed_composition_file}")
465
- # return
466
-
467
- # # If file already exists, use it
468
- # if os.path.exists(self.mixed_composition_file):
469
- # print(f"📀 Using existing mixed composition: {self.mixed_composition_file}")
470
- # return
471
-
472
- # # Create actual mixed composition by layering completed sounds
473
- # mixed_data = None
474
- # sample_rate = 44100 # Default sample rate
475
-
476
- # for movement in self.movements_completed:
477
- # if movement in self.current_sound_mapping and self.current_sound_mapping[movement] is not None:
478
- # sound_file = os.path.join(self.sound_dir, self.current_sound_mapping[movement])
479
- # if os.path.exists(sound_file):
480
- # try:
481
- # data, sr = sf.read(sound_file)
482
- # sample_rate = sr
483
-
484
- # # Convert stereo to mono
485
- # if len(data.shape) > 1:
486
- # data = np.mean(data, axis=1)
487
-
488
- # # Initialize or add to mixed data
489
- # if mixed_data is None:
490
- # mixed_data = data * 0.8 # Reduce volume to prevent clipping
491
- # else:
492
- # # Ensure same length by padding shorter audio
493
- # if len(data) > len(mixed_data):
494
- # mixed_data = np.pad(mixed_data, (0, len(data) - len(mixed_data)))
495
- # elif len(mixed_data) > len(data):
496
- # data = np.pad(data, (0, len(mixed_data) - len(data)))
497
-
498
- # # Mix the audio (layer them)
499
- # mixed_data += data * 0.8
500
- # except Exception as e:
501
- # print(f"Error mixing {sound_file}: {e}")
502
-
503
- # # Save mixed composition or create silent fallback
504
- # if mixed_data is not None:
505
- # # Normalize to prevent clipping
506
- # max_val = np.max(np.abs(mixed_data))
507
- # if max_val > 0.95:
508
- # mixed_data = mixed_data * 0.95 / max_val
509
-
510
- # sf.write(self.mixed_composition_file, mixed_data, sample_rate)
511
- # print(f"📀 Mixed composition created: {self.mixed_composition_file} (FILE SAVING ENABLED)")
512
- # else:
513
- # # Create silent fallback file
514
- # silent_data = np.zeros(int(sample_rate * 2)) # 2 seconds of silence
515
- # sf.write(self.mixed_composition_file, silent_data, sample_rate)
516
- # print(f"📀 Silent fallback composition created: {self.mixed_composition_file}")
517
-
518
- # except Exception as e:
519
- # print(f"Error creating mixed composition: {e}")
520
- # # Create minimal fallback file with actual content
521
- # self.mixed_composition_file = os.path.abspath("mixed_composition_fallback.wav")
522
- # try:
523
- # # Create a short silent audio file as fallback
524
- # sample_rate = 44100
525
- # silent_data = np.zeros(int(sample_rate * 2)) # 2 seconds of silence
526
- # sf.write(self.mixed_composition_file, silent_data, sample_rate)
527
- # print(f"📀 Silent fallback composition created: {self.mixed_composition_file}")
528
- # except Exception as fallback_error:
529
- # print(f"Failed to create fallback file: {fallback_error}")
530
- # self.mixed_composition_file = None
531
-
532
- def toggle_dj_effect(self, movement: str) -> dict:
533
- """Toggle a DJ effect for the given movement and process the corresponding layer."""
534
- if self.current_phase != "dj_effects":
535
- return {"effect_applied": False, "message": "Not in DJ effects phase"}
536
-
537
- if movement not in self.active_effects:
538
- return {"effect_applied": False, "message": f"Unknown movement: {movement}"}
539
-
540
- # Toggle effect
541
- self.active_effects[movement] = not self.active_effects[movement]
542
- effect_status = "ON" if self.active_effects[movement] else "OFF"
543
- effect_names = {
544
- "left_hand": "Echo",
545
- "right_hand": "High Pass Filter",
546
- "left_leg": "Reverb Effect",
547
- "right_leg": "Low Pass Filter",
548
- }
549
- effect_name = effect_names.get(movement, movement)
550
- print(f"🎛️ {effect_name}: {effect_status}")
551
-
552
- # Find the audio file for this movement
553
- sound_file = self.current_sound_mapping.get(movement)
554
- audio_path = os.path.join(self.sound_dir, sound_file) if sound_file else None
555
- processed_file = None
556
- if audio_path and os.path.exists(audio_path):
557
- processed_file = AudioEffectsProcessor.process_layer_with_effects(
558
- audio_path, movement, self.active_effects
559
- )
560
- else:
561
- print(f"No audio file found for movement: {movement}")
562
- processed_file = None
563
-
564
- # For DJ phase, always play all base layers (with effects if toggled)
565
- all_layers = {}
566
- for m in self.active_movements:
567
- sf_name = self.current_sound_mapping.get(m)
568
- apath = os.path.join(self.sound_dir, sf_name) if sf_name else None
569
- if apath and os.path.exists(apath):
570
- all_layers[m] = AudioEffectsProcessor.process_layer_with_effects(
571
- apath, m, self.active_effects
572
- )
573
- return {
574
- "effect_applied": True,
575
- "effect_name": effect_name,
576
- "effect_status": effect_status,
577
- "processed_layer": processed_file,
578
- "all_layers": all_layers
579
- }
580
-
581
- def get_cycle_success_rate(self) -> float:
582
- """Get success rate for current cycle."""
583
- if self.cycle_stats['total_attempts'] == 0:
584
- return 0.0
585
- return self.cycle_stats['successful_classifications'] / self.cycle_stats['total_attempts']
586
-
587
- def clear_composition(self):
588
- """Clear all composition layers."""
589
- self.composition_layers = []
590
- self.current_composition = None
591
- print("Composition cleared")
592
-
593
- def _get_audio_file_for_gradio(self, movement):
594
- """Get the actual audio file path for Gradio audio output"""
595
- if movement in self.current_sound_mapping:
596
- sound_file = self.current_sound_mapping[movement]
597
- audio_path = os.path.join(self.sound_dir, sound_file)
598
- if os.path.exists(audio_path):
599
- return audio_path
600
- return None
601
-
602
- def _mix_audio_files(self, audio_files: List[str]) -> str:
603
- """Mix multiple audio files into a single layered audio file."""
604
- if not audio_files:
605
- return None
606
-
607
- # Load all audio files and convert to mono
608
- audio_data_list = []
609
- sample_rate = None
610
- max_length = 0
611
-
612
- for file_path in audio_files:
613
- if os.path.exists(file_path):
614
- data, sr = sf.read(file_path)
615
-
616
- # Ensure data is numpy array and handle shape properly
617
- data = np.asarray(data)
618
-
619
- # Convert to mono if stereo/multi-channel
620
- if data.ndim > 1:
621
- if data.shape[1] > 1: # Multi-channel
622
- data = np.mean(data, axis=1)
623
- else: # Single channel with extra dimension
624
- data = data.flatten()
625
-
626
- # Ensure final data is 1D
627
- data = np.asarray(data).flatten()
628
-
629
- audio_data_list.append(data)
630
- if sample_rate is None:
631
- sample_rate = sr
632
- max_length = max(max_length, len(data))
633
-
634
- if not audio_data_list:
635
- return None
636
-
637
- # Pad all audio to same length and mix
638
- mixed_audio = np.zeros(max_length)
639
- for data in audio_data_list:
640
- # Ensure data is 1D (flatten any remaining multi-dimensional arrays)
641
- data_flat = np.asarray(data).flatten()
642
-
643
- # Pad or truncate to match max_length
644
- if len(data_flat) < max_length:
645
- padded = np.pad(data_flat, (0, max_length - len(data_flat)), 'constant')
646
- else:
647
- padded = data_flat[:max_length]
648
-
649
- # Add to mix (normalize to prevent clipping)
650
- mixed_audio += padded / len(audio_data_list)
651
-
652
- # Create a unique identifier for this composition
653
- import hashlib
654
- composition_hash = hashlib.md5(''.join(sorted(audio_files)).encode()).hexdigest()[:8]
655
-
656
- # FILE SAVING ENABLED: Return first available audio file instead of creating mixed composition
657
- mixed_audio_path = os.path.join(self.sound_dir, f"mixed_composition_{composition_hash}.wav")
658
-
659
- # Since file saving is disabled, use the first available audio file from the list
660
- if audio_files:
661
- # Use the first audio file as the "mixed" composition
662
- first_audio_file = os.path.join(self.sound_dir, audio_files[0])
663
- if os.path.exists(first_audio_file):
664
- print(f"DEBUG: Using first audio file as mixed composition: {os.path.basename(first_audio_file)}")
665
- # Estimate BPM from the mixed composition audio file
666
- try:
667
- import librosa
668
- y, sr = librosa.load(first_audio_file, sr=None)
669
- tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
670
- self.bpm = int(tempo)
671
- print(f"Estimated BPM from track: {self.bpm}")
672
- except Exception as e:
673
- print(f"Could not estimate BPM: {e}")
674
- return first_audio_file
675
-
676
- # Fallback: if mixed composition file already exists, use it
677
- if os.path.exists(mixed_audio_path):
678
- print(f"DEBUG: Reusing existing mixed audio file: {os.path.basename(mixed_audio_path)}")
679
- return mixed_audio_path
680
-
681
- # Final fallback: return first available base audio file
682
- base_files = ["1_SoundHelix-Song-6_(Vocals).wav", "1_SoundHelix-Song-6_(Drums).wav", "1_SoundHelix-Song-6_(Bass).wav", "1_SoundHelix-Song-6_(Other).wav"]
683
- for base_file in base_files:
684
- base_path = os.path.join(self.sound_dir, base_file)
685
- if os.path.exists(base_path):
686
- print(f"DEBUG: Using base audio file as fallback: {base_file} (FILE SAVING DISABLED)")
687
- return base_path
688
-
689
- return mixed_audio_path
690
-
691
- def get_current_mixed_composition(self) -> str:
692
- """Get the current composition as a mixed audio file."""
693
- # Get all audio files from current composition layers
694
- audio_files = []
695
- for layer in self.composition_layers:
696
- movement = layer.get('movement')
697
- if movement and movement in self.current_sound_mapping:
698
- sound_file = self.current_sound_mapping[movement]
699
- audio_path = os.path.join(self.sound_dir, sound_file)
700
- if os.path.exists(audio_path):
701
- audio_files.append(audio_path)
702
-
703
- # Debug: print current composition state
704
- print(f"DEBUG: Current composition has {len(self.composition_layers)} layers: {[layer.get('movement') for layer in self.composition_layers]}")
705
- print(f"DEBUG: Audio files to mix: {[os.path.basename(f) for f in audio_files]}")
706
-
707
- return self._mix_audio_files(audio_files)
708
-
709
- def get_composition_info(self) -> Dict:
710
- """Get comprehensive information about current composition."""
711
- layers_by_cycle = {}
712
- for layer in self.composition_layers:
713
- cycle = layer['cycle']
714
- if cycle not in layers_by_cycle:
715
- layers_by_cycle[cycle] = []
716
- layers_by_cycle[cycle].append({
717
- 'step': layer['step'],
718
- 'movement': layer['movement'],
719
- 'sound_file': layer['sound_file'],
720
- 'confidence': layer['confidence']
721
- })
722
-
723
- # Also track completed movements without sounds (like neutral)
724
- completed_without_sound = [mov for mov in self.movements_completed
725
- if self.current_sound_mapping.get(mov) is None]
726
-
727
- return {
728
- 'total_cycles': self.current_cycle + (1 if self.composition_layers else 0),
729
- 'current_cycle': self.current_cycle + 1,
730
- 'current_step': len(self.movements_completed) + 1, # Current step in cycle
731
- 'target_movement': self.get_current_target_movement(),
732
- 'movements_completed': len(self.movements_completed),
733
- 'movements_remaining': len(self.current_movement_sequence) - len(self.movements_completed),
734
- 'cycle_complete': self.cycle_complete,
735
- 'total_layers': len(self.composition_layers),
736
- 'completed_movements': list(self.movements_completed),
737
- 'completed_without_sound': completed_without_sound,
738
- 'layers_by_cycle': layers_by_cycle,
739
- 'current_mapping': self.current_sound_mapping.copy(),
740
- 'success_rate': self.get_cycle_success_rate(),
741
- 'stats': self.cycle_stats.copy()
742
- }
743
-
744
- def get_sound_mapping_options(self) -> Dict:
745
- """Get available sound mapping options for user customization."""
746
- return {
747
- 'movements': self.all_movements, # All possible movements for mapping
748
- 'available_sounds': self.available_sounds,
749
- 'current_mapping': self.current_sound_mapping.copy()
750
- }
751
-
752
- def update_sound_mapping(self, new_mapping: Dict[str, str]) -> bool:
753
- """Update the sound mapping for future cycles."""
754
- try:
755
- # Validate that all sounds exist
756
- for movement, sound_file in new_mapping.items():
757
- if sound_file not in self.available_sounds:
758
- print(f"Warning: Sound file {sound_file} not available")
759
- return False
760
-
761
- self.current_sound_mapping.update(new_mapping)
762
- print("✓ Sound mapping updated successfully")
763
- return True
764
- except Exception as e:
765
- print(f"Error updating sound mapping: {e}")
766
- return False
767
-
768
- def reset_composition(self):
769
- """Reset the entire composition to start fresh."""
770
- self.composition_layers = []
771
- self.current_cycle = 0
772
- self.current_step = 0
773
- self.cycle_complete = False
774
- self.cycle_stats = {
775
- 'total_cycles': 0,
776
- 'successful_classifications': 0,
777
- 'total_attempts': 0
778
- }
779
- print("🔄 Composition reset - ready for new session")
780
-
781
- def save_composition(self, output_path: str = "composition.wav"):
782
- """Save the current composition as a mixed audio file."""
783
- if not self.composition_layers:
784
- print("No layers to save")
785
- return False
786
-
787
- try:
788
- # For simplicity, we'll just save the latest layer
789
- # In a real implementation, you'd mix multiple audio tracks
790
- latest_layer = self.composition_layers[-1]
791
- sound_data = latest_layer['sound_data']
792
-
793
- sf.write(output_path, sound_data['data'], sound_data['sample_rate'])
794
- print(f"Composition saved to {output_path}")
795
- return True
796
-
797
- except Exception as e:
798
- print(f"Error saving composition: {e}")
799
- return False
800
-
801
- def get_available_sounds(self) -> List[str]:
802
- """Get list of available sound classes."""
803
- return list(self.loaded_sounds.keys())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sound_manager.py ADDED
@@ -0,0 +1,237 @@
1
+ """
2
+ Sound Management System for EEG Motor Imagery Classification (Clean Transition Version)
3
+ -------------------------------------------------------------------------------
4
+ Handles sound mapping, layering, and music composition based on motor imagery predictions.
5
+ Supports seamless transition from building (layering) to DJ (effects) phase.
6
+ """
7
+
8
+ import numpy as np
9
+ import soundfile as sf
10
+ import os
11
+ from typing import Dict, Optional, List
12
+ from pathlib import Path
13
+
14
+ class AudioEffectsProcessor:
15
+ @staticmethod
16
+ def apply_high_pass_filter(data: np.ndarray, samplerate: int, cutoff: float = 800.0) -> np.ndarray:
17
+ from scipy import signal
18
+ nyquist = samplerate / 2
19
+ normalized_cutoff = cutoff / nyquist
20
+ b, a = signal.butter(4, normalized_cutoff, btype='high', analog=False)
21
+ return signal.filtfilt(b, a, data)
22
+
23
+ @staticmethod
24
+ def apply_low_pass_filter(data: np.ndarray, samplerate: int, cutoff: float = 1200.0) -> np.ndarray:
25
+ from scipy import signal
26
+ nyquist = samplerate / 2
27
+ normalized_cutoff = cutoff / nyquist
28
+ b, a = signal.butter(4, normalized_cutoff, btype='low', analog=False)
29
+ return signal.filtfilt(b, a, data)
30
+
31
+ @staticmethod
32
+ def apply_reverb(data: np.ndarray, samplerate: int, room_size: float = 0.5) -> np.ndarray:
33
+ delay_samples = int(0.08 * samplerate)
34
+ decay = 0.4 * room_size
35
+ reverb_data = np.copy(data)
36
+ for i in range(3):
37
+ delay = delay_samples * (i + 1)
38
+ if delay < len(data):
39
+ gain = decay ** (i + 1)
40
+ reverb_data[delay:] += data[:-delay] * gain
41
+ return 0.7 * data + 0.3 * reverb_data
42
+
43
+ @staticmethod
44
+ def apply_echo(data: np.ndarray, samplerate: int, delay_time: float = 0.3, feedback: float = 0.4) -> np.ndarray:
45
+ delay_samples = int(delay_time * samplerate)
46
+ echo_data = np.copy(data)
47
+ for i in range(delay_samples, len(data)):
48
+ echo_data[i] += feedback * echo_data[i - delay_samples]
49
+ return 0.7 * data + 0.3 * echo_data
50
+
51
+ @staticmethod
52
+ def apply_compressor(data: np.ndarray, samplerate: int, threshold: float = 0.2, ratio: float = 4.0) -> np.ndarray:
53
+ # Simple compressor: reduce gain above threshold
54
+ compressed = np.copy(data)
55
+ over_threshold = np.abs(compressed) > threshold
56
+ compressed[over_threshold] = np.sign(compressed[over_threshold]) * (threshold + (np.abs(compressed[over_threshold]) - threshold) / ratio)
57
+ return compressed
58
+
59
+ @staticmethod
60
+ def process_layer_with_effects(audio_data: np.ndarray, samplerate: int, movement: str, active_effects: Dict[str, bool]) -> np.ndarray:
61
+ processed_data = np.copy(audio_data)
62
+ effect_map = {
63
+ "left_hand": AudioEffectsProcessor.apply_echo, # Echo
64
+ "right_hand": AudioEffectsProcessor.apply_low_pass_filter, # Low Pass
65
+ "left_leg": AudioEffectsProcessor.apply_compressor, # Compressor
66
+ "right_leg": AudioEffectsProcessor.apply_high_pass_filter, # High Pass
67
+ }
68
+ effect_func = effect_map.get(movement)
69
+ if active_effects.get(movement, False) and effect_func:
70
+ processed_data = effect_func(processed_data, samplerate)
71
+ return processed_data
72
+
73
+ class SoundManager:
74
+ def __init__(self, sound_dir: str = "sounds"):
75
+ self.available_sounds = [
76
+ "SoundHelix-Song-4_bass.wav",
77
+ "SoundHelix-Song-4_drums.wav",
78
+ "SoundHelix-Song-4_instruments.wav",
79
+ "SoundHelix-Song-4_vocals.wav"
80
+ ]
81
+ self.sound_dir = Path(sound_dir)
82
+ self.current_cycle = 0
83
+ self.current_step = 0
84
+ self.cycle_complete = False
85
+ self.completed_cycles = 0
86
+ self.max_cycles = 2
87
+ self.composition_layers = {}
88
+ self.current_phase = "building"
89
+ self.active_effects = {m: False for m in ["left_hand", "right_hand", "left_leg", "right_leg"]}
90
+ self.active_movements = ["left_hand", "right_hand", "left_leg", "right_leg"]
91
+ self.current_movement_sequence = []
92
+ self.movements_completed = set()
93
+ self.active_layers: Dict[str, str] = {}
94
+ self.loaded_sounds = {}
95
+ self._generate_new_sequence()
96
+ self._load_sound_files()
97
+ # Provide mapping from movement to sound file name for compatibility
98
+ self.current_sound_mapping = {m: f for m, f in zip(self.active_movements, self.available_sounds)}
99
+ # Track DJ effect trigger counts for each movement
100
+ self.dj_effect_counters = {m: 0 for m in self.active_movements}
101
+ self.cycle_stats = {'total_cycles': 0, 'successful_classifications': 0, 'total_attempts': 0}
102
+
103
+ def _load_sound_files(self):
104
+ self.loaded_sounds = {}
105
+ for movement, filename in self.current_sound_mapping.items():
106
+ file_path = self.sound_dir / filename
107
+ if file_path.exists():
108
+ data, sample_rate = sf.read(str(file_path))
109
+ if len(data.shape) > 1:
110
+ data = np.mean(data, axis=1)
111
+ self.loaded_sounds[movement] = {'data': data, 'sample_rate': sample_rate, 'sound_file': str(file_path)}
112
+
113
+ def _generate_new_sequence(self):
114
+ # Fixed movement order and mapping
115
+ self.current_movement_sequence = ["left_hand", "right_hand", "left_leg", "right_leg"]
116
+ self.current_sound_mapping = {
117
+ "left_hand": "SoundHelix-Song-4_instruments.wav",
118
+ "right_hand": "SoundHelix-Song-4_bass.wav",
119
+ "left_leg": "SoundHelix-Song-4_drums.wav",
120
+ "right_leg": "SoundHelix-Song-4_vocals.wav"
121
+ }
122
+ print(f"DEBUG: Fixed sound mapping for this cycle: {self.current_sound_mapping}")
123
+ self.movements_completed = set()
124
+ self.current_step = 0
125
+ self._load_sound_files()
126
+
127
+ def get_current_target_movement(self) -> str:
128
+ # Randomly select a movement from those not yet completed
129
+ import random
130
+ incomplete = [m for m in self.active_movements if m not in self.movements_completed]
131
+ if not incomplete:
132
+ print("DEBUG: All movements completed, cycle complete.")
133
+ return "cycle_complete"
134
+ movement = random.choice(incomplete)
135
+ print(f"DEBUG: Next target is {movement}, completed: {self.movements_completed}")
136
+ return movement
137
+
138
+
139
+ def process_classification(self, predicted_class: str, confidence: float, threshold: float = 0.7, force_add: bool = False) -> Dict:
140
+ result = {'sound_added': False, 'cycle_complete': False, 'audio_file': None}
141
+ # If force_add is True, allow adding sound for any valid movement not already completed
142
+ if force_add:
143
+ if (
144
+ confidence >= threshold and
145
+ predicted_class in self.loaded_sounds and
146
+ predicted_class not in self.composition_layers
147
+ ):
148
+ print(f"DEBUG: [FORCE] Adding sound for {predicted_class}")
149
+ sound_info = dict(self.loaded_sounds[predicted_class])
150
+ sound_info['confidence'] = confidence
151
+ self.composition_layers[predicted_class] = sound_info
152
+ self.movements_completed.add(predicted_class)
153
+ result['sound_added'] = True
154
+ else:
155
+ print("DEBUG: [FORCE] Not adding sound. Condition failed.")
156
+ else:
157
+ current_target = self.get_current_target_movement()
158
+ print(f"DEBUG: process_classification: predicted={predicted_class}, target={current_target}, confidence={confidence}, completed={self.movements_completed}")
159
+ if (
160
+ predicted_class == current_target and
161
+ confidence >= threshold and
162
+ predicted_class in self.loaded_sounds and
163
+ predicted_class not in self.composition_layers
164
+ ):
165
+ print(f"DEBUG: Adding sound for {predicted_class} (target={current_target})")
166
+ sound_info = dict(self.loaded_sounds[predicted_class])
167
+ sound_info['confidence'] = confidence
168
+ self.composition_layers[predicted_class] = sound_info
169
+ self.movements_completed.add(predicted_class)
170
+ result['sound_added'] = True
171
+ else:
172
+ print("DEBUG: Not adding sound. Condition failed.")
173
+ if len(self.movements_completed) >= len(self.active_movements):
174
+ result['cycle_complete'] = True
175
+ self.current_phase = "dj_effects"
176
+ return result
177
+
178
+ def start_new_cycle(self):
179
+ self.current_cycle += 1
180
+ self.current_step = 0
181
+ self.cycle_complete = False
182
+ self.cycle_stats['total_cycles'] += 1
183
+ self._generate_new_sequence()
184
+ self.composition_layers = {} # Clear layers for new cycle
185
+ self.movements_completed = set()
186
+ self.current_phase = "building"
187
+ self.active_layers = {}
188
+
189
+ def transition_to_dj_phase(self):
190
+ if len(self.composition_layers) >= len(self.active_movements):
191
+ self.current_phase = "dj_effects"
192
+ return True
193
+ return False
194
+
195
+ def toggle_dj_effect(self, movement: str, brief: bool = True, duration: float = 1.0) -> dict:
196
+ import threading
197
+ if self.current_phase != "dj_effects":
198
+ return {"effect_applied": False, "message": "Not in DJ effects phase"}
199
+ if movement not in self.active_effects:
200
+ return {"effect_applied": False, "message": f"Unknown movement: {movement}"}
201
+ # Only toggle effect every 4th time this movement is detected
202
+ self.dj_effect_counters[movement] += 1
203
+ if self.dj_effect_counters[movement] % 4 != 0:
204
+ print(f"🎛️ {movement}: Skipped effect toggle (count={self.dj_effect_counters[movement]})")
205
+ return {"effect_applied": False, "message": f"Effect for {movement} only toggled every 4th time (count={self.dj_effect_counters[movement]})"}
206
+ # Toggle effect ON
207
+ self.active_effects[movement] = True
208
+ effect_status = "ON"
209
+ print(f"🎛️ {movement}: {effect_status} (brief={brief}) [count={self.dj_effect_counters[movement]}]")
210
+ # Schedule effect OFF after duration if brief
211
+ def turn_off_effect():
212
+ self.active_effects[movement] = False
213
+ print(f"🎛️ {movement}: OFF (auto)")
214
+ if brief:
215
+ timer = threading.Timer(duration, turn_off_effect)
216
+ timer.daemon = True
217
+ timer.start()
218
+ return {"effect_applied": True, "effect_name": movement, "effect_status": effect_status, "count": self.dj_effect_counters[movement]}
219
+
220
+ def get_composition_info(self) -> Dict:
221
+ layers_by_cycle = {0: []}
222
+ for movement, layer_info in self.composition_layers.items():
223
+ confidence = layer_info.get('confidence', 0) if isinstance(layer_info, dict) else 0
224
+ layers_by_cycle[0].append({'movement': movement, 'confidence': confidence})
225
+ # Add DJ effect status for each movement
226
+ dj_effects_status = {m: self.active_effects.get(m, False) for m in self.active_movements}
227
+ return {'layers_by_cycle': layers_by_cycle, 'dj_effects_status': dj_effects_status}
228
+
229
+ def get_sound_mapping_options(self) -> Dict:
230
+ return {
231
+ 'movements': self.active_movements,
232
+ 'available_sounds': self.available_sounds,
233
+ 'current_mapping': {m: self.loaded_sounds[m]['sound_file'] for m in self.loaded_sounds}
234
+ }
235
+
236
+ def get_all_layers(self):
237
+ return {m: info['sound_file'] for m, info in self.composition_layers.items() if 'sound_file' in info}
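
Taken together, the new SoundManager runs a session in two phases: "building" adds one stem per correctly imagined movement, and "dj_effects" briefly toggles a per-movement effect on every 4th detection. A minimal usage sketch, assuming the four SoundHelix stem files exist under sounds/ and that predictions come from the existing classifier (render_mix below is a hypothetical helper, not part of the module):

    import numpy as np
    import soundfile as sf
    from sound_manager import SoundManager, AudioEffectsProcessor

    manager = SoundManager(sound_dir="sounds")

    # Building phase: feed (class, confidence) pairs until all four layers are placed.
    for predicted, confidence in [("left_hand", 0.91), ("right_leg", 0.85),
                                  ("right_hand", 0.88), ("left_leg", 0.79)]:
        result = manager.process_classification(predicted, confidence,
                                                 threshold=0.7, force_add=True)
        if result['cycle_complete']:
            break  # SoundManager switches itself into the "dj_effects" phase

    # DJ phase: the mapped effect only fires on every 4th detection of a movement.
    manager.toggle_dj_effect("left_hand", brief=True, duration=1.0)

    def render_mix(manager):
        """Hypothetical helper: mix all layers with their current effect states."""
        layers = manager.composition_layers
        rate = next(iter(layers.values()))['sample_rate']   # assumes a shared sample rate
        longest = max(len(l['data']) for l in layers.values())
        mix = np.zeros(longest)
        for movement, layer in layers.items():
            data = AudioEffectsProcessor.process_layer_with_effects(
                layer['data'], rate, movement, manager.active_effects)
            mix[:len(data)] += data
        peak = np.max(np.abs(mix))
        return (mix / peak if peak > 0 else mix), rate

    mix, rate = render_mix(manager)
    sf.write("composition_preview.wav", mix, rate)
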
source/eeg_motor_imagery.py CHANGED
@@ -103,7 +103,7 @@ model = ShallowFBCSPNet(
103
  ).to(device)
104
 
105
  # Load pretrained weights
106
- state_dict = torch.load("shallow_weights_all.pth", map_location=device)
107
  model.load_state_dict(state_dict)
108
 
109
  # === 6. Training ===
 
103
  ).to(device)
104
 
105
  # Load pretrained weights
106
+ state_dict = torch.load("model.pth", map_location=device)
107
  model.load_state_dict(state_dict)
108
 
109
  # === 6. Training ===
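
The only change to source/eeg_motor_imagery.py is the pretrained-weights filename. A small defensive variant of the same load step, assuming device is the torch device defined earlier in the script, would check for the checkpoint before loading so a missing file fails with an explicit message:

    import os
    import torch

    weights_path = "model.pth"   # renamed from shallow_weights_all.pth in this commit
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"Pretrained weights not found: {weights_path}")

    # map_location lets the same checkpoint load on CPU-only machines.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)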