""" EEG Motor Imagery Music Composer - Clean Transition Version ========================================================= This version implements a clear separation between the building phase (layering sounds) and the DJ phase (effect control), with seamless playback of all layered sounds throughout both phases. """ # Set matplotlib backend to non-GUI for server/web use import matplotlib matplotlib.use('Agg') # Set backend BEFORE importing pyplot import matplotlib.pyplot as plt import os import gradio as gr import numpy as np from typing import Dict from sound_control import SoundManager from data_processor import EEGDataProcessor from classifier import MotorImageryClassifier from config import DEMO_DATA_PATHS, CONFIDENCE_THRESHOLD # --- Initialization --- app_state = { 'is_running': False, 'demo_data': None, 'demo_labels': None, 'composition_active': False, 'auto_mode': False, 'ch_names': None } sound_control= None data_processor = None classifier = None def lazy_init(): global sound_control, data_processor, classifier if sound_control is None: sound_control = SoundManager() if data_processor is None: data_processor = EEGDataProcessor() if classifier is None: classifier = MotorImageryClassifier() # Load demo data and model if not already loaded if app_state['demo_data'] is None or app_state['demo_labels'] is None or app_state['ch_names'] is None: existing_files = [f for f in DEMO_DATA_PATHS if os.path.exists(f)] if existing_files: app_state['demo_data'], app_state['demo_labels'], app_state['ch_names'] = data_processor.process_files(existing_files) else: app_state['demo_data'], app_state['demo_labels'], app_state['ch_names'] = None, None, None if app_state['demo_data'] is not None and classifier is not None and not hasattr(classifier, '_model_loaded'): classifier.load_model(n_chans=app_state['demo_data'].shape[1], n_times=app_state['demo_data'].shape[2]) classifier._model_loaded = True # --- Helper Functions --- def get_movement_sounds() -> Dict[str, str]: """Get the current sound files for each movement.""" sounds = {} # Add a static cache for audio file paths per movement and effect state if not hasattr(get_movement_sounds, 'audio_cache'): get_movement_sounds.audio_cache = {m: {False: None, True: None} for m in ['left_hand', 'right_hand', 'left_leg', 'right_leg']} get_movement_sounds.last_effect_state = {m: None for m in ['left_hand', 'right_hand', 'left_leg', 'right_leg']} # Add a static counter to track how many times each movement's audio is played if not hasattr(get_movement_sounds, 'play_counter'): get_movement_sounds.play_counter = {m: 0 for m in ['left_hand', 'right_hand', 'left_leg', 'right_leg']} get_movement_sounds.total_calls = 0 from sound_control import AudioEffectsProcessor import tempfile import soundfile as sf # If in DJ mode, use effect-processed file if effect is ON dj_mode = getattr(sound_control, 'current_phase', None) == 'dj_effects' for movement, sound_file in sound_control.current_sound_mapping.items(): if movement in ['left_hand', 'right_hand', 'left_leg', 'right_leg']: if sound_file is not None: sound_path = sound_control.sound_dir / sound_file if sound_path.exists(): # Sticky effect for all movements: if effect was ON, keep returning processed audio until next ON effect_on = dj_mode and sound_control.active_effects.get(movement, False) # If effect just turned ON, update sticky state if effect_on: get_movement_sounds.last_effect_state[movement] = True # If effect is OFF, but sticky is set, keep using processed audio elif get_movement_sounds.last_effect_state[movement]: effect_on = True else: effect_on = False # Check cache for this movement/effect state cached_path = get_movement_sounds.audio_cache[movement][effect_on] # Only regenerate if cache is empty or effect state just changed if cached_path is not None and get_movement_sounds.last_effect_state[movement] == effect_on: sounds[movement] = cached_path else: # Load audio data, sr = sf.read(str(sound_path)) if len(data.shape) > 1: data = np.mean(data, axis=1) # Fade-in: apply to all audio on restart (0.5s fade for more gradual effect) fade_duration = 10 # seconds fade_samples = int(fade_duration * sr) if fade_samples > 0 and fade_samples < len(data): fade_curve = np.linspace(0, 1, fade_samples) data[:fade_samples] = data[:fade_samples] * fade_curve if effect_on: # Apply effect processed = AudioEffectsProcessor.process_layer_with_effects( data, sr, movement, sound_control.active_effects ) # Save to temp file (persistent for this effect state) tmp = tempfile.NamedTemporaryFile(delete=False, suffix=f'_{movement}_effect.wav') sf.write(tmp.name, processed, sr) get_movement_sounds.audio_cache[movement][True] = tmp.name sounds[movement] = tmp.name else: get_movement_sounds.audio_cache[movement][False] = str(sound_path.resolve()) sounds[movement] = str(sound_path.resolve()) get_movement_sounds.last_effect_state[movement] = effect_on get_movement_sounds.play_counter[movement] += 1 get_movement_sounds.total_calls += 1 return sounds def create_eeg_plot(eeg_data: np.ndarray, target_movement: str, predicted_name: str, confidence: float, sound_added: bool, ch_names=None) -> plt.Figure: '''Create a plot of EEG data with annotations. Plots C3 and C4 channels by name.''' if ch_names is None: ch_names = ['C3', 'C4'] # Find indices for C3 and C4 idx_c3 = ch_names.index('C3') if 'C3' in ch_names else 0 idx_c4 = ch_names.index('C4') if 'C4' in ch_names else 1 fig, axes = plt.subplots(1, 2, figsize=(10, 4)) axes = axes.flatten() time_points = np.arange(eeg_data.shape[1]) / 200 for i, idx in enumerate([idx_c3, idx_c4]): color = 'green' if sound_added else 'blue' axes[i].plot(time_points, eeg_data[idx], color=color, linewidth=1) axes[i].set_title(f'{ch_names[idx] if idx < len(ch_names) else f"Channel {idx+1}"}') axes[i].set_xlabel('Time (s)') axes[i].set_ylabel('Amplitude (µV)') axes[i].grid(True, alpha=0.3) title = f"Target: {target_movement.replace('_', ' ').title()} | Predicted: {predicted_name.replace('_', ' ').title()} ({confidence:.2f})" fig.suptitle(title, fontsize=12, fontweight='bold') fig.tight_layout() plt.close(fig) return fig def format_composition_summary(composition_info: Dict) -> str: '''Format the composition summary for display. ''' if not composition_info.get('layers_by_cycle'): return "No composition layers yet" summary = [] for cycle, layers in composition_info['layers_by_cycle'].items(): summary.append(f"Cycle {cycle + 1}: {len(layers)} layers") for layer in layers: movement = layer.get('movement', 'unknown') confidence = layer.get('confidence', 0) summary.append(f" • {movement.replace('_', ' ').title()} ({confidence:.2f})") # DJ Effects Status removed from status tab as requested return "\n".join(summary) if summary else "No composition layers" # --- Main Logic --- def start_composition(): ''' Start the composition process. ''' global app_state lazy_init() if not app_state['composition_active']: app_state['composition_active'] = True sound_control.start_new_cycle() if app_state['demo_data'] is None: return "❌ No data", "❌ No data", "❌ No data", None, None, None, None, None, None, "No EEG data available" # Force first trial to always be left_hand/instrumental if len(sound_control.movements_completed) == 0: next_movement = 'left_hand' left_hand_label = [k for k, v in classifier.class_names.items() if v == 'left_hand'][0] import numpy as np matching_indices = np.where(app_state['demo_labels'] == left_hand_label)[0] chosen_idx = np.random.choice(matching_indices) epoch_data = app_state['demo_data'][chosen_idx] true_label = left_hand_label true_label_name = 'left_hand' else: epoch_data, true_label = data_processor.simulate_real_time_data(app_state['demo_data'], app_state['demo_labels'], mode="class_balanced") true_label_name = classifier.class_names[true_label] next_movement = sound_control.get_current_target_movement() if next_movement == "cycle_complete": return continue_dj_phase() predicted_class, confidence, probabilities = classifier.predict(epoch_data) predicted_name = classifier.class_names[predicted_class] # Only add sound if confidence > threshold, predicted == true label, and true label matches the prompt if confidence > CONFIDENCE_THRESHOLD and predicted_name == true_label_name: result = sound_control.process_classification(predicted_name, confidence, CONFIDENCE_THRESHOLD, force_add=True) else: result = {'sound_added': False} fig = create_eeg_plot(epoch_data, true_label_name, predicted_name, confidence, result['sound_added'], app_state.get('ch_names')) # Only play completed movement sounds (layered) sounds = get_movement_sounds() completed_movements = sound_control.movements_completed # Assign audio paths only for completed movements left_hand_audio = sounds.get('left_hand') if 'left_hand' in completed_movements else None right_hand_audio = sounds.get('right_hand') if 'right_hand' in completed_movements else None left_leg_audio = sounds.get('left_leg') if 'left_leg' in completed_movements else None right_leg_audio = sounds.get('right_leg') if 'right_leg' in completed_movements else None # 2. Movement Commands: show mapping for all movements movement_emojis = { "left_hand": "🫲", "right_hand": "🫱", "left_leg": "🦵", "right_leg": "🦵", } movement_command_lines = [] # Show 'Now Playing' for all completed movements (layers that are currently playing) completed_movements = sound_control.movements_completed for movement in ["left_hand", "right_hand", "left_leg", "right_leg"]: sound_file = sound_control.current_sound_mapping.get(movement, "") instrument_type = "" for key in ["bass", "drums", "instruments", "vocals"]: if key in sound_file.lower(): instrument_type = key if key != "instruments" else "instrument" break pretty_movement = movement.replace("_", " ").title() # Always use 'Instruments' (plural) for the left hand stem if movement == "left_hand" and instrument_type.lower() == "instrument": pretty_instrument = "Instruments" else: pretty_instrument = instrument_type.capitalize() if instrument_type else "--" emoji = movement_emojis.get(movement, "") # Add 'Now Playing' indicator for all completed movements if movement in completed_movements: movement_command_lines.append(f"{emoji} {pretty_movement}: {pretty_instrument} ▶️ Now Playing") else: movement_command_lines.append(f"{emoji} {pretty_movement}: {pretty_instrument}") movement_command_text = "🎼 Composition Mode - Movement to Stems Mapping\n" + "\n".join(movement_command_lines) # 3. Next Trial: always prompt user next_trial_text = "Imagine next movement" composition_info = sound_control.get_composition_info() status_text = format_composition_summary(composition_info) return ( movement_command_text, next_trial_text, fig, left_hand_audio, right_hand_audio, left_leg_audio, right_leg_audio, status_text ) def continue_dj_phase(): ''' Continue in DJ phase, applying effects and always playing all layered sounds. ''' global app_state if not app_state['composition_active']: return "❌ Not active", "❌ Not active", "❌ Not active", None, None, None, None, None, None, "Click 'Start Composing' first" if app_state['demo_data'] is None: return "❌ No data", "❌ No data", "❌ No data", None, None, None, None, None, None, "No EEG data available" # DJ phase: enforce strict DJ effect order epoch_data, true_label = data_processor.simulate_real_time_data(app_state['demo_data'], app_state['demo_labels'], mode="class_balanced") predicted_class, confidence, probabilities = classifier.predict(epoch_data) predicted_name = classifier.class_names[predicted_class] # Strict DJ order: right_hand, right_leg, left_leg, left_hand if not hasattr(continue_dj_phase, 'dj_order'): continue_dj_phase.dj_order = ["right_hand", "right_leg", "left_leg", "left_hand"] continue_dj_phase.dj_index = 0 # Find the next movement in the DJ order that hasn't been toggled yet (using effect counters) while continue_dj_phase.dj_index < 4: next_movement = continue_dj_phase.dj_order[continue_dj_phase.dj_index] # Only proceed if the predicted movement matches the next in order if predicted_name == next_movement: break else: # Ignore this prediction, do not apply effect next_trial_text = "Imagine next movement" # UI update: show which movement is expected # Always play all completed movement sounds (layered) sounds = get_movement_sounds() completed_movements = sound_control.movements_completed left_hand_audio = sounds.get('left_hand') if 'left_hand' in completed_movements else None right_hand_audio = sounds.get('right_hand') if 'right_hand' in completed_movements else None left_leg_audio = sounds.get('left_leg') if 'left_leg' in completed_movements else None right_leg_audio = sounds.get('right_leg') if 'right_leg' in completed_movements else None movement_map = { "left_hand": {"effect": "Fade In/Out", "instrument": "Instruments"}, "right_hand": {"effect": "Low Pass", "instrument": "Bass"}, "left_leg": {"effect": "Compressor", "instrument": "Drums"}, "right_leg": {"effect": "Echo", "instrument": "Vocals"}, } emoji_map = {"left_hand": "🫲", "right_hand": "🫱", "left_leg": "🦵", "right_leg": "🦵"} movement_command_lines = [] for m in ["left_hand", "right_hand", "left_leg", "right_leg"]: status = "ON" if sound_control.active_effects.get(m, False) else "off" movement_command_lines.append(f"{emoji_map[m]} {m.replace('_', ' ').title()}: {movement_map[m]['effect']} [{'ON' if status == 'ON' else 'off'}] → {movement_map[m]['instrument']}") target_text = "🎧 DJ Mode - Movement to Effect Mapping\n" + "\n".join(movement_command_lines) composition_info = sound_control.get_composition_info() status_text = format_composition_summary(composition_info) fig = create_eeg_plot(epoch_data, classifier.class_names[true_label], predicted_name, confidence, False, app_state.get('ch_names')) return ( target_text, # Movement Commands (textbox) next_trial_text, # Next Trial (textbox) fig, # EEG Plot (plot) left_hand_audio, # Left Hand (audio) right_hand_audio, # Right Hand (audio) left_leg_audio, # Left Leg (audio) right_leg_audio, # Right Leg (audio) status_text, # Composition Status (textbox) gr.update(), # Timer (update object) gr.update() # Continue DJ Button (update object) ) # If correct movement, apply effect and advance order effect_applied = False if confidence > CONFIDENCE_THRESHOLD and predicted_name == continue_dj_phase.dj_order[continue_dj_phase.dj_index]: result = sound_control.toggle_dj_effect(predicted_name, brief=True, duration=1.0) effect_applied = result.get("effect_applied", False) continue_dj_phase.dj_index += 1 else: result = None fig = create_eeg_plot(epoch_data, classifier.class_names[true_label], predicted_name, confidence, effect_applied, app_state.get('ch_names')) # Always play all completed movement sounds (layered) sounds = get_movement_sounds() completed_movements = sound_control.movements_completed left_hand_audio = sounds.get('left_hand') if 'left_hand' in completed_movements else None right_hand_audio = sounds.get('right_hand') if 'right_hand' in completed_movements else None left_leg_audio = sounds.get('left_leg') if 'left_leg' in completed_movements else None right_leg_audio = sounds.get('right_leg') if 'right_leg' in completed_movements else None # Show DJ effect mapping for each movement with ON/OFF status and correct instrument mapping movement_map = { "left_hand": {"effect": "Fade In/Out", "instrument": "Instruments"}, "right_hand": {"effect": "Low Pass", "instrument": "Bass"}, "left_leg": {"effect": "Compressor", "instrument": "Drums"}, "right_leg": {"effect": "Echo", "instrument": "Vocals"}, } emoji_map = {"left_hand": "🫲", "right_hand": "🫱", "left_leg": "🦵", "right_leg": "🦵"} # Get effect ON/OFF status from sound_control.active_effects movement_command_lines = [] for m in ["left_hand", "right_hand", "left_leg", "right_leg"]: # Show [ON] only if effect is currently active (True), otherwise [off] status = "ON" if sound_control.active_effects.get(m, False) else "off" movement_command_lines.append(f"{emoji_map[m]} {m.replace('_', ' ').title()}: {movement_map[m]['effect']} [{'ON' if status == 'ON' else 'off'}] → {movement_map[m]['instrument']}") target_text = "🎧 DJ Mode - Movement to Effect Mapping\n" + "\n".join(movement_command_lines) # In DJ mode, Next Trial should only show the prompt, not the predicted/target movement predicted_text = "Imagine next movement" composition_info = sound_control.get_composition_info() status_text = format_composition_summary(composition_info) # Ensure exactly 10 outputs: [textbox, textbox, plot, audio, audio, audio, audio, textbox, timer, button] # Use fig for the plot, and fill all outputs with correct types return ( target_text, # Movement Commands (textbox) predicted_text, # Next Trial (textbox) fig, # EEG Plot (plot) left_hand_audio, # Left Hand (audio) right_hand_audio, # Right Hand (audio) left_leg_audio, # Left Leg (audio) right_leg_audio, # Right Leg (audio) status_text, # Composition Status (textbox) gr.update(), # Timer (update object) gr.update() # Continue DJ Button (update object) ) # --- Gradio UI --- def create_interface(): ''' Create the Gradio interface. ''' with gr.Blocks(title="EEG Motor Imagery Music Composer", theme=gr.themes.Citrus()) as demo: with gr.Tabs(): with gr.TabItem("🎵 Automatic Music Composer"): gr.Markdown("# 🧠 NeuroMusic Studio: An accessible, easy to use motor rehabilitation device.") gr.Markdown(""" **How it works:** 1. **Compose:** Imagine moving your left hand, right hand, left leg, or right leg to add musical layers. Each correct, high-confidence prediction adds a sound. Just follow the prompts. 2. **DJ Mode:** After all four layers are added, you can apply effects and remix your composition using new brain commands. > **Tip:** In DJ mode, each effect is triggered only every 4th time you repeat a movement, to keep playback smooth. Commands and controls update as you progress. Just follow the on-screen instructions! """) with gr.Row(): with gr.Column(scale=2): start_btn = gr.Button("🎵 Start Composing", variant="primary", size="lg") stop_btn = gr.Button("🛑 Stop", variant="stop", size="md") continue_btn = gr.Button("⏭️ Continue DJ Phase", variant="primary", size="lg", visible=False) timer = gr.Timer(value=1.0, active=False) # 4 second intervals predicted_display = gr.Textbox(label="🧠 Movement Commands", interactive=False, value="--", lines=4) timer_display = gr.Textbox(label="⏱️ Next Trial", interactive=False, value="--") eeg_plot = gr.Plot(label="EEG Data Visualization") with gr.Column(scale=1): left_hand_sound = gr.Audio(label="🫲 Left Hand", interactive=False, autoplay=True, visible=True) right_hand_sound = gr.Audio(label="🫱 Right Hand", interactive=False, autoplay=True, visible=True) left_leg_sound = gr.Audio(label="🦵 Left Leg", interactive=False, autoplay=True, visible=True) right_leg_sound = gr.Audio(label="🦵 Right Leg", interactive=False, autoplay=True, visible=True) composition_status = gr.Textbox(label="Composition Status", interactive=False, lines=5) def start_and_activate_timer(): ''' Start composing and activate timer for trials. ''' result = start_composition() last_trial_result[:] = result # Initialize with first trial result if "DJ Mode" not in result[0]: return (*result, gr.update(active=True), gr.update(visible=False)) else: return (*result, gr.update(active=False), gr.update(visible=True)) # ITI logic: 3s blank, 1s prompt, then trial timer_counter = {"count": 0} last_trial_result = [None] * 9 # Adjust length to match your outputs def timer_tick(): ''' Timer tick handler for ITI and trials. ''' # 0,1,2: blank, 3: prompt, 4: trial if timer_counter["count"] < 3: timer_counter["count"] += 1 # Show blank prompt, keep last outputs if len(last_trial_result) == 8: return (*last_trial_result, gr.update(active=True), gr.update(visible=False)) elif len(last_trial_result) == 10: # DJ mode: blank prompt result = list(last_trial_result) result[1] = "" return tuple(result) else: raise ValueError(f"Unexpected last_trial_result length: {len(last_trial_result)}") elif timer_counter["count"] == 3: timer_counter["count"] += 1 # Show prompt result = list(last_trial_result) result[1] = "Imagine next movement" if len(result) == 8: return (*result, gr.update(active=True), gr.update(visible=False)) elif len(result) == 10: return tuple(result) else: raise ValueError(f"Unexpected result length in prompt: {len(result)}") else: timer_counter["count"] = 0 # Run trial result = list(start_composition()) last_trial_result[:] = result # Save for next blanks/prompts if len(result) == 8: # Pre-DJ mode: add timer and button updates if any(isinstance(x, str) and "DJ Mode" in x for x in result): return (*result, gr.update(active=False), gr.update(visible=True)) else: return (*result, gr.update(active=True), gr.update(visible=False)) elif len(result) == 10: return tuple(result) else: raise ValueError(f"Unexpected result length in timer_tick: {len(result)}") def continue_dj(): ''' Continue DJ phase from button click. ''' result = continue_dj_phase() if len(result) == 8: return (*result, gr.update(active=False), gr.update(visible=True)) elif len(result) == 10: return result else: raise ValueError(f"Unexpected result length in continue_dj: {len(result)}") start_btn.click( fn=start_and_activate_timer, outputs=[predicted_display, timer_display, eeg_plot, left_hand_sound, right_hand_sound, left_leg_sound, right_leg_sound, composition_status, timer, continue_btn] ) timer_event = timer.tick( fn=timer_tick, outputs=[predicted_display, timer_display, eeg_plot, left_hand_sound, right_hand_sound, left_leg_sound, right_leg_sound, composition_status, timer, continue_btn] ) def stop_composing(): ''' Stop composing and reset state (works in both building and DJ mode). ''' timer_counter["count"] = 0 app_state['composition_active'] = False # Ensure new cycle on next start # Reset sound_control state for new session sound_control.current_phase = "building" sound_control.composition_layers = {} sound_control.movements_completed = set() sound_control.active_effects = {m: False for m in ["left_hand", "right_hand", "left_leg", "right_leg"]} # Clear static audio cache in get_movement_sounds if hasattr(get_movement_sounds, 'audio_cache'): for m in get_movement_sounds.audio_cache: get_movement_sounds.audio_cache[m][True] = None get_movement_sounds.audio_cache[m][False] = None if hasattr(get_movement_sounds, 'last_effect_state'): for m in get_movement_sounds.last_effect_state: get_movement_sounds.last_effect_state[m] = None if hasattr(get_movement_sounds, 'play_counter'): for m in get_movement_sounds.play_counter: get_movement_sounds.play_counter[m] = 0 get_movement_sounds.total_calls = 0 # Clear UI and deactivate timer, hide continue button, clear all audio last_trial_result[:] = ["--", "Stopped", None, None, None, None, None, "Stopped"] return ("--", "Stopped", None, None, None, None, None, "Stopped", gr.update(active=False), gr.update(visible=False)) stop_btn.click( fn=stop_composing, outputs=[predicted_display, timer_display, eeg_plot, left_hand_sound, right_hand_sound, left_leg_sound, right_leg_sound, composition_status, timer, continue_btn], cancels=[timer_event] ) continue_btn.click( fn=continue_dj, outputs=[predicted_display, timer_display, eeg_plot, left_hand_sound, right_hand_sound, left_leg_sound, right_leg_sound, composition_status, timer, continue_btn] ) with gr.TabItem("📝 Manual Classifier"): gr.Markdown("# Manual Classifier") gr.Markdown("Select a movement and run the classifier manually on a random epoch for that movement. Results will be accumulated below.") movement_dropdown = gr.Dropdown(choices=["left_hand", "right_hand", "left_leg", "right_leg"], label="Select Movement") manual_btn = gr.Button("Run Classifier", variant="primary") manual_predicted = gr.Textbox(label="Predicted Class", interactive=False) manual_confidence = gr.Textbox(label="Confidence", interactive=False) manual_plot = gr.Plot(label="EEG Data Visualization") manual_probs = gr.Plot(label="Class Probabilities") manual_confmat = gr.Plot(label="Confusion Matrix (Session)") # Session state for confusion matrix from collections import defaultdict session_confmat = defaultdict(lambda: defaultdict(int)) def manual_classify(selected_movement): ''' Manually classify a random epoch for the selected movement. ''' import matplotlib.pyplot as plt import numpy as np if app_state['demo_data'] is None or app_state['demo_labels'] is None: return "No data", "No data", None, None, None label_idx = [k for k, v in classifier.class_names.items() if v == selected_movement][0] matching_indices = np.where(app_state['demo_labels'] == label_idx)[0] if len(matching_indices) == 0: return "No data for this movement", "", None, None, None chosen_idx = np.random.choice(matching_indices) epoch_data = app_state['demo_data'][chosen_idx] predicted_class, confidence, probs = classifier.predict(epoch_data) predicted_name = classifier.class_names[predicted_class] # Update confusion matrix session_confmat[selected_movement][predicted_name] += 1 # Plot confusion matrix classes = ["left_hand", "right_hand", "left_leg", "right_leg"] confmat = np.zeros((4, 4), dtype=int) for i, true_m in enumerate(classes): for j, pred_m in enumerate(classes): confmat[i, j] = session_confmat[true_m][pred_m] fig_confmat, ax = plt.subplots(figsize=(4, 4)) ax.imshow(confmat, cmap="Blues") ax.set_xticks(np.arange(4)) ax.set_yticks(np.arange(4)) ax.set_xticklabels(classes, rotation=45, ha="right") ax.set_yticklabels(classes) ax.set_xlabel("Predicted") ax.set_ylabel("True") for i in range(4): for j in range(4): ax.text(j, i, str(confmat[i, j]), ha="center", va="center", color="black") fig_confmat.tight_layout() # Plot class probabilities if isinstance(probs, dict): probs_list = [probs.get(cls, 0.0) for cls in classes] else: probs_list = list(probs) fig_probs, ax_probs = plt.subplots(figsize=(4, 2)) ax_probs.bar(classes, probs_list) ax_probs.set_ylabel("Probability") ax_probs.set_ylim(0, 1) fig_probs.tight_layout() # EEG plot fig = create_eeg_plot(epoch_data, selected_movement, predicted_name, confidence, False, app_state.get('ch_names')) # Close all open figures to avoid warnings plt.close(fig_confmat) plt.close(fig_probs) plt.close(fig) return predicted_name, f"{confidence:.2f}", fig, fig_probs, fig_confmat manual_btn.click( fn=manual_classify, inputs=[movement_dropdown], outputs=[manual_predicted, manual_confidence, manual_plot, manual_probs, manual_confmat] ) return demo if __name__ == "__main__": print("[DEBUG] Starting app main block...") demo = create_interface() print("[DEBUG] Gradio interface created. Launching...") demo.launch(server_name="0.0.0.0", server_port=7860)