Spaces:
Sleeping
Sleeping
Commit
Β·
fa96cf5
0
Parent(s):
Initial commit: EEG Motor Imagery Music Composer
Browse files- Gradio-based web application for EEG motor imagery music composition
- Real-time EEG signal processing and classification
- Interactive music building with 5 movement types (left/right hand, left/right leg, tongue)
- DJ effects phase with audio processing
- File saving currently disabled for testing
- Includes pre-trained shallow neural network model
- Audio state management to prevent Gradio component restarts
- .gitattributes +3 -0
- .gitignore +59 -0
- README.md +78 -0
- app.py +1168 -0
- classifier.py +179 -0
- config.py +93 -0
- data_processor.py +257 -0
- demo.py +96 -0
- enhanced_utils.py +236 -0
- requirements.txt +14 -0
- sound_library.py +783 -0
- source/eeg_motor_imagery.py +142 -0
- utils.py +226 -0
.gitattributes
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data/*.mat filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
sounds/*.wav filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
model.pth filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
env/
|
| 8 |
+
build/
|
| 9 |
+
develop-eggs/
|
| 10 |
+
dist/
|
| 11 |
+
downloads/
|
| 12 |
+
eggs/
|
| 13 |
+
.eggs/
|
| 14 |
+
lib/
|
| 15 |
+
lib64/
|
| 16 |
+
parts/
|
| 17 |
+
sdist/
|
| 18 |
+
var/
|
| 19 |
+
wheels/
|
| 20 |
+
*.egg-info/
|
| 21 |
+
.installed.cfg
|
| 22 |
+
*.egg
|
| 23 |
+
|
| 24 |
+
# Virtual Environment
|
| 25 |
+
.venv/
|
| 26 |
+
venv/
|
| 27 |
+
ENV/
|
| 28 |
+
env/
|
| 29 |
+
|
| 30 |
+
# IDE
|
| 31 |
+
.vscode/
|
| 32 |
+
.idea/
|
| 33 |
+
*.swp
|
| 34 |
+
*.swo
|
| 35 |
+
*~
|
| 36 |
+
|
| 37 |
+
# OS
|
| 38 |
+
.DS_Store
|
| 39 |
+
.DS_Store?
|
| 40 |
+
._*
|
| 41 |
+
.Spotlight-V100
|
| 42 |
+
.Trashes
|
| 43 |
+
ehthumbs.db
|
| 44 |
+
Thumbs.db
|
| 45 |
+
|
| 46 |
+
# Project specific
|
| 47 |
+
*.wav
|
| 48 |
+
*.mp3
|
| 49 |
+
*.pth
|
| 50 |
+
mixed_composition_*.wav
|
| 51 |
+
*_fx_*.wav
|
| 52 |
+
app_old.py
|
| 53 |
+
app_new.py
|
| 54 |
+
test_mixed.wav
|
| 55 |
+
silent.wav
|
| 56 |
+
|
| 57 |
+
# Data files
|
| 58 |
+
data/
|
| 59 |
+
*.mat
|
README.md
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# π§ EEG Motor Imagery Music Composer
|
| 2 |
+
|
| 3 |
+
A sophisticated machine learning application that transforms brain signals into music compositions using motor imagery classification. This system uses a trained ShallowFBCSPNet model to classify different motor imagery tasks from EEG data and creates layered musical compositions based on the classification results.
|
| 4 |
+
|
| 5 |
+
## π― Features
|
| 6 |
+
|
| 7 |
+
- **Real-time EEG Classification**: Uses ShallowFBCSPNet architecture for motor imagery classification
|
| 8 |
+
- **Music Composition**: Automatically creates layered music compositions from classification results
|
| 9 |
+
- **Interactive Gradio Interface**: User-friendly web interface for real-time interaction
|
| 10 |
+
- **Six Motor Imagery Classes**: Left/right hand, left/right leg, tongue, and neutral states
|
| 11 |
+
- **Sound Mapping**: Each motor imagery class is mapped to different musical instruments
|
| 12 |
+
- **Composition Management**: Save, clear, and manage your musical creations
|
| 13 |
+
|
| 14 |
+
## ποΈ Architecture
|
| 15 |
+
|
| 16 |
+
### Project Structure
|
| 17 |
+
```
|
| 18 |
+
βββ app.py # Main Gradio application
|
| 19 |
+
βββ classifier.py # Motor imagery classifier with ShallowFBCSPNet
|
| 20 |
+
βββ data_processor.py # EEG data loading and preprocessing
|
| 21 |
+
βββ sound_library.py # Sound management and composition system
|
| 22 |
+
βββ config.py # Configuration settings
|
| 23 |
+
βββ requirements.txt # Python dependencies
|
| 24 |
+
βββ SoundHelix-Song-1/ # Audio files for different instruments
|
| 25 |
+
β βββ bass.wav
|
| 26 |
+
β βββ drums.wav
|
| 27 |
+
β βββ other.wav
|
| 28 |
+
β βββ vocals.wav
|
| 29 |
+
βββ src/ # Additional source files
|
| 30 |
+
βββ model.py
|
| 31 |
+
βββ preprocessing.py
|
| 32 |
+
βββ train.py
|
| 33 |
+
βββ visualize.py
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
### System Components
|
| 37 |
+
|
| 38 |
+
1. **EEGDataProcessor** (`data_processor.py`)
|
| 39 |
+
- Loads and processes .mat EEG files
|
| 40 |
+
- Handles epoching and preprocessing
|
| 41 |
+
- Simulates real-time data for demo purposes
|
| 42 |
+
|
| 43 |
+
2. **MotorImageryClassifier** (`classifier.py`)
|
| 44 |
+
- Implements ShallowFBCSPNet model
|
| 45 |
+
- Performs real-time classification
|
| 46 |
+
- Provides confidence scores and probability distributions
|
| 47 |
+
|
| 48 |
+
3. **SoundManager** (`sound_library.py`)
|
| 49 |
+
- Maps classifications to audio files
|
| 50 |
+
- Manages composition layers
|
| 51 |
+
- Handles audio file loading and playback
|
| 52 |
+
|
| 53 |
+
4. **Gradio Interface** (`app.py`)
|
| 54 |
+
- Web-based user interface
|
| 55 |
+
- Real-time visualization
|
| 56 |
+
- Composition management tools
|
| 57 |
+
|
| 58 |
+
## π Quick Start
|
| 59 |
+
|
| 60 |
+
### Requirements
|
| 61 |
+
|
| 62 |
+
Python 3.9β3.11 recommended. Install dependencies:
|
| 63 |
+
|
| 64 |
+
```bash
|
| 65 |
+
python -m pip install -r requirements.txt
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
### How to run (Gradio)
|
| 69 |
+
|
| 70 |
+
Local launch:
|
| 71 |
+
|
| 72 |
+
```bash
|
| 73 |
+
python app.py
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
This starts a server on `http://127.0.0.1:7860` by default.
|
| 77 |
+
|
| 78 |
+
#
|
app.py
ADDED
|
@@ -0,0 +1,1168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
EEG Motor Imagery Music Composer - Redesigned Interface
|
| 3 |
+
=======================================================
|
| 4 |
+
Brain-Computer Interface that creates music compositions based on imagined movements.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import gradio as gr
|
| 8 |
+
import numpy as np
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
import time
|
| 11 |
+
import threading
|
| 12 |
+
from typing import Dict, Tuple, Any, List
|
| 13 |
+
|
| 14 |
+
# Import our custom modules
|
| 15 |
+
from sound_library import SoundManager
|
| 16 |
+
from data_processor import EEGDataProcessor
|
| 17 |
+
from classifier import MotorImageryClassifier
|
| 18 |
+
from config import DEMO_DATA_PATHS, CLASS_NAMES, CONFIDENCE_THRESHOLD
|
| 19 |
+
|
| 20 |
+
def validate_data_setup() -> str:
|
| 21 |
+
"""Validate that required data files are available."""
|
| 22 |
+
missing_files = []
|
| 23 |
+
|
| 24 |
+
for subject_id, path in DEMO_DATA_PATHS.items():
|
| 25 |
+
try:
|
| 26 |
+
import os
|
| 27 |
+
if not os.path.exists(path):
|
| 28 |
+
missing_files.append(f"Subject {subject_id}: {path}")
|
| 29 |
+
except Exception as e:
|
| 30 |
+
missing_files.append(f"Subject {subject_id}: Error checking {path}")
|
| 31 |
+
|
| 32 |
+
if missing_files:
|
| 33 |
+
return f"β Missing data files:\n" + "\n".join(missing_files)
|
| 34 |
+
return "β
All data files found"
|
| 35 |
+
|
| 36 |
+
# Global app state
|
| 37 |
+
app_state = {
|
| 38 |
+
'is_running': False,
|
| 39 |
+
'demo_data': None,
|
| 40 |
+
'demo_labels': None,
|
| 41 |
+
'classification_history': [],
|
| 42 |
+
'composition_active': False,
|
| 43 |
+
'auto_mode': False,
|
| 44 |
+
'last_audio_state': {
|
| 45 |
+
'left_hand_audio': None,
|
| 46 |
+
'right_hand_audio': None,
|
| 47 |
+
'left_leg_audio': None,
|
| 48 |
+
'right_leg_audio': None,
|
| 49 |
+
'tongue_audio': None
|
| 50 |
+
}
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
# Initialize components
|
| 54 |
+
print("π§ EEG Motor Imagery Music Composer")
|
| 55 |
+
print("=" * 50)
|
| 56 |
+
print("Starting Gradio application...")
|
| 57 |
+
|
| 58 |
+
try:
|
| 59 |
+
sound_manager = SoundManager()
|
| 60 |
+
data_processor = EEGDataProcessor()
|
| 61 |
+
classifier = MotorImageryClassifier()
|
| 62 |
+
|
| 63 |
+
# Load demo data
|
| 64 |
+
import os
|
| 65 |
+
existing_files = [f for f in DEMO_DATA_PATHS if os.path.exists(f)]
|
| 66 |
+
if existing_files:
|
| 67 |
+
app_state['demo_data'], app_state['demo_labels'] = data_processor.process_files(existing_files)
|
| 68 |
+
else:
|
| 69 |
+
app_state['demo_data'], app_state['demo_labels'] = None, None
|
| 70 |
+
|
| 71 |
+
if app_state['demo_data'] is not None:
|
| 72 |
+
# Initialize classifier with proper dimensions
|
| 73 |
+
classifier.load_model(n_chans=app_state['demo_data'].shape[1], n_times=app_state['demo_data'].shape[2])
|
| 74 |
+
print(f"β
Pre-trained model loaded successfully from {classifier.model_path}")
|
| 75 |
+
print(f"Pre-trained Demo: {len(app_state['demo_data'])} samples from {len(existing_files)} subjects")
|
| 76 |
+
else:
|
| 77 |
+
print("β οΈ No demo data loaded - check your .mat files")
|
| 78 |
+
|
| 79 |
+
print(f"Available sound classes: {list(sound_manager.current_sound_mapping.keys())}")
|
| 80 |
+
|
| 81 |
+
except Exception as e:
|
| 82 |
+
print(f"β Error during initialization: {e}")
|
| 83 |
+
raise RuntimeError(
|
| 84 |
+
"Cannot initialize app without real EEG data. "
|
| 85 |
+
"Please check your data files and paths."
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
def get_movement_sounds() -> Dict[str, str]:
|
| 89 |
+
"""Get the current sound files for each movement."""
|
| 90 |
+
sounds = {}
|
| 91 |
+
for movement, sound_file in sound_manager.current_sound_mapping.items():
|
| 92 |
+
if movement in ['left_hand', 'right_hand', 'left_leg', "right_leg", 'tongue']: # Only show main movements
|
| 93 |
+
if sound_file is not None: # Check if sound_file is not None
|
| 94 |
+
sound_path = sound_manager.sound_dir / sound_file
|
| 95 |
+
if sound_path.exists():
|
| 96 |
+
# Convert to absolute path for Gradio audio components
|
| 97 |
+
sounds[movement] = str(sound_path.resolve())
|
| 98 |
+
return sounds
|
| 99 |
+
|
| 100 |
+
def start_composition():
|
| 101 |
+
"""Start the composition process and perform initial classification."""
|
| 102 |
+
global app_state
|
| 103 |
+
|
| 104 |
+
# Only start new cycle if not already active
|
| 105 |
+
if not app_state['composition_active']:
|
| 106 |
+
app_state['composition_active'] = True
|
| 107 |
+
sound_manager.start_new_cycle() # Reset composition only when starting fresh
|
| 108 |
+
|
| 109 |
+
if app_state['demo_data'] is None:
|
| 110 |
+
return "β No data", "β No data", "β No data", None, None, None, None, None, None, "No EEG data available"
|
| 111 |
+
|
| 112 |
+
# Get current target
|
| 113 |
+
target_movement = sound_manager.get_current_target_movement()
|
| 114 |
+
print(f"DEBUG start_composition: current target = {target_movement}")
|
| 115 |
+
|
| 116 |
+
# Check if cycle is complete
|
| 117 |
+
if target_movement == "cycle_complete":
|
| 118 |
+
return "π΅ Cycle Complete!", "π΅ Complete", "Remap sounds to continue", None, None, None, None, None, None, "Cycle complete - remap sounds to continue"
|
| 119 |
+
|
| 120 |
+
# Perform initial EEG classification
|
| 121 |
+
epoch_data, true_label = data_processor.simulate_real_time_data(
|
| 122 |
+
app_state['demo_data'], app_state['demo_labels'], mode="class_balanced"
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# Classify the epoch
|
| 126 |
+
predicted_class, confidence, probabilities = classifier.predict(epoch_data)
|
| 127 |
+
predicted_name = classifier.class_names[predicted_class]
|
| 128 |
+
|
| 129 |
+
# Process classification
|
| 130 |
+
result = sound_manager.process_classification(predicted_name, confidence, CONFIDENCE_THRESHOLD)
|
| 131 |
+
|
| 132 |
+
# Create visualization
|
| 133 |
+
fig = create_eeg_plot(epoch_data, target_movement, predicted_name, confidence, result['sound_added'])
|
| 134 |
+
|
| 135 |
+
# Initialize all audio components to silent (to clear any previous sounds)
|
| 136 |
+
silent_file = "silent.wav"
|
| 137 |
+
left_hand_audio = silent_file
|
| 138 |
+
right_hand_audio = silent_file
|
| 139 |
+
left_leg_audio = silent_file
|
| 140 |
+
right_leg_audio = silent_file
|
| 141 |
+
tongue_audio = silent_file
|
| 142 |
+
|
| 143 |
+
# Debug: Print classification result
|
| 144 |
+
print(f"DEBUG start_composition: predicted={predicted_name}, confidence={confidence:.3f}, sound_added={result['sound_added']}")
|
| 145 |
+
|
| 146 |
+
# Only play the sound if it was just added and matches the prediction
|
| 147 |
+
if result['sound_added']:
|
| 148 |
+
sounds = get_movement_sounds()
|
| 149 |
+
print(f"DEBUG: Available sounds: {list(sounds.keys())}")
|
| 150 |
+
if predicted_name == 'left_hand' and 'left_hand' in sounds:
|
| 151 |
+
left_hand_audio = sounds['left_hand']
|
| 152 |
+
print(f"DEBUG: Setting left_hand_audio to {sounds['left_hand']}")
|
| 153 |
+
elif predicted_name == 'right_hand' and 'right_hand' in sounds:
|
| 154 |
+
right_hand_audio = sounds['right_hand']
|
| 155 |
+
print(f"DEBUG: Setting right_hand_audio to {sounds['right_hand']}")
|
| 156 |
+
elif predicted_name == 'left_leg' and 'left_leg' in sounds:
|
| 157 |
+
left_leg_audio = sounds['left_leg']
|
| 158 |
+
print(f"DEBUG: Setting left_leg_audio to {sounds['left_leg']}")
|
| 159 |
+
elif predicted_name == 'right_leg' and 'right_leg' in sounds:
|
| 160 |
+
right_leg_audio = sounds['right_leg']
|
| 161 |
+
print(f"DEBUG: Setting right_leg_audio to {sounds['right_leg']}")
|
| 162 |
+
elif predicted_name == 'tongue' and 'tongue' in sounds:
|
| 163 |
+
tongue_audio = sounds['tongue']
|
| 164 |
+
print(f"DEBUG: Setting tongue_audio to {sounds['tongue']}")
|
| 165 |
+
else:
|
| 166 |
+
print("DEBUG: No sound added - confidence too low or other issue")
|
| 167 |
+
|
| 168 |
+
# Format next target with progress information
|
| 169 |
+
next_target = sound_manager.get_current_target_movement()
|
| 170 |
+
completed_count = len(sound_manager.movements_completed)
|
| 171 |
+
total_count = len(sound_manager.current_movement_sequence)
|
| 172 |
+
|
| 173 |
+
if next_target == "cycle_complete":
|
| 174 |
+
target_text = "π΅ Cycle Complete!"
|
| 175 |
+
else:
|
| 176 |
+
target_text = f"π― Any Movement ({completed_count}/{total_count} complete) - Use 'Classify Epoch' button to continue"
|
| 177 |
+
|
| 178 |
+
predicted_text = f"π§ Predicted: {predicted_name.replace('_', ' ').title()} ({confidence:.2f})"
|
| 179 |
+
|
| 180 |
+
# Get composition info
|
| 181 |
+
composition_info = sound_manager.get_composition_info()
|
| 182 |
+
status_text = format_composition_summary(composition_info)
|
| 183 |
+
|
| 184 |
+
return (
|
| 185 |
+
target_text,
|
| 186 |
+
predicted_text,
|
| 187 |
+
"2-3 seconds",
|
| 188 |
+
fig,
|
| 189 |
+
left_hand_audio,
|
| 190 |
+
right_hand_audio,
|
| 191 |
+
left_leg_audio,
|
| 192 |
+
right_leg_audio,
|
| 193 |
+
tongue_audio,
|
| 194 |
+
status_text
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
def stop_composition():
|
| 198 |
+
"""Stop the composition process."""
|
| 199 |
+
global app_state
|
| 200 |
+
app_state['composition_active'] = False
|
| 201 |
+
app_state['auto_mode'] = False
|
| 202 |
+
return (
|
| 203 |
+
"Composition stopped. Click 'Start Composing' to begin again",
|
| 204 |
+
"--",
|
| 205 |
+
"--",
|
| 206 |
+
"Stopped - click Start to resume"
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
def start_automatic_composition():
|
| 210 |
+
"""Start automatic composition with continuous classification."""
|
| 211 |
+
global app_state
|
| 212 |
+
|
| 213 |
+
# Only start new cycle if not already active
|
| 214 |
+
if not app_state['composition_active']:
|
| 215 |
+
app_state['composition_active'] = True
|
| 216 |
+
app_state['auto_mode'] = True
|
| 217 |
+
sound_manager.start_new_cycle() # Reset composition only when starting fresh
|
| 218 |
+
|
| 219 |
+
if app_state['demo_data'] is None:
|
| 220 |
+
return "β No data", "β No data", "β No data", "β No data", None, None, None, None, None, None, "No EEG data available"
|
| 221 |
+
|
| 222 |
+
# Get current target
|
| 223 |
+
target_movement = sound_manager.get_current_target_movement()
|
| 224 |
+
print(f"DEBUG start_automatic_composition: current target = {target_movement}")
|
| 225 |
+
|
| 226 |
+
# Check if cycle is complete
|
| 227 |
+
if target_movement == "cycle_complete":
|
| 228 |
+
# Mark current cycle as complete
|
| 229 |
+
sound_manager.complete_current_cycle()
|
| 230 |
+
|
| 231 |
+
# Check if rehabilitation session should end
|
| 232 |
+
if sound_manager.should_end_session():
|
| 233 |
+
app_state['auto_mode'] = False # Stop automatic mode
|
| 234 |
+
return (
|
| 235 |
+
"π Session Complete!",
|
| 236 |
+
"π Amazing Progress!",
|
| 237 |
+
"Rehabilitation session finished!",
|
| 238 |
+
"π Congratulations! You've created 2 unique brain-music compositions!\n\n" +
|
| 239 |
+
"πͺ Your motor imagery skills are improving!\n\n" +
|
| 240 |
+
"π΅ You can review your compositions above, or start a new session anytime.\n\n" +
|
| 241 |
+
"Would you like to continue with more cycles, or take a well-deserved break?",
|
| 242 |
+
None, None, None, None, None, None,
|
| 243 |
+
f"β
Session Complete: {sound_manager.completed_cycles}/{sound_manager.max_cycles} compositions finished!"
|
| 244 |
+
)
|
| 245 |
+
else:
|
| 246 |
+
# Start next cycle automatically
|
| 247 |
+
sound_manager.start_new_cycle()
|
| 248 |
+
print("π Cycle completed! Starting new cycle automatically...")
|
| 249 |
+
target_movement = sound_manager.get_current_target_movement() # Get new target
|
| 250 |
+
|
| 251 |
+
# Show user prompt - encouraging start message
|
| 252 |
+
cycle_num = sound_manager.current_cycle
|
| 253 |
+
if cycle_num == 1:
|
| 254 |
+
prompt_text = "π Welcome to your rehabilitation session! Let's start with any movement you can imagine..."
|
| 255 |
+
elif cycle_num == 2:
|
| 256 |
+
prompt_text = "πͺ Excellent work on your first composition! Ready for composition #2?"
|
| 257 |
+
else:
|
| 258 |
+
prompt_text = "π§ Let's continue - imagine any movement now..."
|
| 259 |
+
|
| 260 |
+
# Perform initial EEG classification
|
| 261 |
+
epoch_data, true_label = data_processor.simulate_real_time_data(
|
| 262 |
+
app_state['demo_data'], app_state['demo_labels'], mode="class_balanced"
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
# Classify the epoch
|
| 266 |
+
predicted_class, confidence, probabilities = classifier.predict(epoch_data)
|
| 267 |
+
predicted_name = classifier.class_names[predicted_class]
|
| 268 |
+
|
| 269 |
+
# Handle DJ effects or building phase
|
| 270 |
+
if sound_manager.current_phase == "dj_effects" and confidence > CONFIDENCE_THRESHOLD:
|
| 271 |
+
# DJ Effects Mode - toggle effects instead of adding sounds
|
| 272 |
+
dj_result = sound_manager.toggle_dj_effect(predicted_name)
|
| 273 |
+
result = {
|
| 274 |
+
'sound_added': dj_result['effect_applied'],
|
| 275 |
+
'mixed_composition': dj_result.get('mixed_composition'),
|
| 276 |
+
'effect_name': dj_result.get('effect_name', ''),
|
| 277 |
+
'effect_status': dj_result.get('effect_status', '')
|
| 278 |
+
}
|
| 279 |
+
else:
|
| 280 |
+
# Building Mode - process classification normally
|
| 281 |
+
result = sound_manager.process_classification(predicted_name, confidence, CONFIDENCE_THRESHOLD)
|
| 282 |
+
|
| 283 |
+
# Create visualization
|
| 284 |
+
fig = create_eeg_plot(epoch_data, target_movement, predicted_name, confidence, result['sound_added'])
|
| 285 |
+
|
| 286 |
+
# Initialize all audio components to silent by default
|
| 287 |
+
silent_file = "silent.wav"
|
| 288 |
+
left_hand_audio = silent_file
|
| 289 |
+
right_hand_audio = silent_file
|
| 290 |
+
left_leg_audio = silent_file
|
| 291 |
+
right_leg_audio = silent_file
|
| 292 |
+
tongue_audio = silent_file
|
| 293 |
+
|
| 294 |
+
# Debug: Print classification result
|
| 295 |
+
print(f"DEBUG start_automatic_composition: predicted={predicted_name}, confidence={confidence:.3f}, sound_added={result['sound_added']}")
|
| 296 |
+
|
| 297 |
+
# Handle audio display based on current phase
|
| 298 |
+
if sound_manager.current_phase == "dj_effects":
|
| 299 |
+
# DJ Effects Phase - show mixed composition with effects
|
| 300 |
+
if result.get('mixed_composition'):
|
| 301 |
+
# Show the mixed composition (all sounds combined) in one player
|
| 302 |
+
left_hand_audio = result['mixed_composition'] # Use first player for mixed audio
|
| 303 |
+
print(f"DEBUG DJ: Playing mixed composition with effects: {result['mixed_composition']}")
|
| 304 |
+
else:
|
| 305 |
+
# Fallback to showing accumulated sounds
|
| 306 |
+
sounds = get_movement_sounds()
|
| 307 |
+
completed_movements = sound_manager.movements_completed
|
| 308 |
+
if 'left_hand' in completed_movements and 'left_hand' in sounds:
|
| 309 |
+
left_hand_audio = sounds['left_hand']
|
| 310 |
+
else:
|
| 311 |
+
# Building Phase - show ALL accumulated sounds (layered composition)
|
| 312 |
+
sounds = get_movement_sounds()
|
| 313 |
+
completed_movements = sound_manager.movements_completed
|
| 314 |
+
print(f"DEBUG: Available sounds: {list(sounds.keys())}")
|
| 315 |
+
print(f"DEBUG: Completed movements: {completed_movements}")
|
| 316 |
+
|
| 317 |
+
# Display all completed movement sounds (cumulative layering)
|
| 318 |
+
if 'left_hand' in completed_movements and 'left_hand' in sounds:
|
| 319 |
+
left_hand_audio = sounds['left_hand']
|
| 320 |
+
print(f"DEBUG: Showing accumulated left_hand_audio: {sounds['left_hand']}")
|
| 321 |
+
if 'right_hand' in completed_movements and 'right_hand' in sounds:
|
| 322 |
+
right_hand_audio = sounds['right_hand']
|
| 323 |
+
print(f"DEBUG: Showing accumulated right_hand_audio: {sounds['right_hand']}")
|
| 324 |
+
if 'left_leg' in completed_movements and 'left_leg' in sounds:
|
| 325 |
+
left_leg_audio = sounds['left_leg']
|
| 326 |
+
print(f"DEBUG: Showing accumulated left_leg_audio: {sounds['left_leg']}")
|
| 327 |
+
if 'right_leg' in completed_movements and 'right_leg' in sounds:
|
| 328 |
+
right_leg_audio = sounds['right_leg']
|
| 329 |
+
print(f"DEBUG: Showing accumulated right_leg_audio: {sounds['right_leg']}")
|
| 330 |
+
if 'tongue' in completed_movements and 'tongue' in sounds:
|
| 331 |
+
tongue_audio = sounds['tongue']
|
| 332 |
+
print(f"DEBUG: Showing accumulated tongue_audio: {sounds['tongue']}")
|
| 333 |
+
|
| 334 |
+
# If a sound was just added, make sure it's included immediately
|
| 335 |
+
if result['sound_added'] and predicted_name in sounds:
|
| 336 |
+
if predicted_name == 'left_hand':
|
| 337 |
+
left_hand_audio = sounds['left_hand']
|
| 338 |
+
elif predicted_name == 'right_hand':
|
| 339 |
+
right_hand_audio = sounds['right_hand']
|
| 340 |
+
elif predicted_name == 'left_leg':
|
| 341 |
+
left_leg_audio = sounds['left_leg']
|
| 342 |
+
elif predicted_name == 'right_leg':
|
| 343 |
+
right_leg_audio = sounds['right_leg']
|
| 344 |
+
elif predicted_name == 'tongue':
|
| 345 |
+
tongue_audio = sounds['tongue']
|
| 346 |
+
print(f"DEBUG: Just added {predicted_name} sound: {sounds[predicted_name]}")
|
| 347 |
+
|
| 348 |
+
# Check for phase transition to DJ effects
|
| 349 |
+
completed_count = len(sound_manager.movements_completed)
|
| 350 |
+
total_count = len(sound_manager.current_movement_sequence)
|
| 351 |
+
|
| 352 |
+
# Transition to DJ effects if all movements completed but still in building phase
|
| 353 |
+
if completed_count >= 5 and sound_manager.current_phase == "building":
|
| 354 |
+
sound_manager.transition_to_dj_phase()
|
| 355 |
+
|
| 356 |
+
# Format display based on current phase
|
| 357 |
+
if sound_manager.current_phase == "dj_effects":
|
| 358 |
+
target_text = "π§ DJ Mode Active - Use movements to control effects!"
|
| 359 |
+
else:
|
| 360 |
+
next_target = sound_manager.get_current_target_movement()
|
| 361 |
+
if next_target == "cycle_complete":
|
| 362 |
+
target_text = "π΅ Composition Complete!"
|
| 363 |
+
else:
|
| 364 |
+
target_text = f"π― Building Composition ({completed_count}/{total_count} layers)"
|
| 365 |
+
|
| 366 |
+
# Update display text based on phase
|
| 367 |
+
if sound_manager.current_phase == "dj_effects":
|
| 368 |
+
if result.get('effect_name') and result.get('effect_status'):
|
| 369 |
+
predicted_text = f"ποΈ {result['effect_name']}: {result['effect_status']}"
|
| 370 |
+
else:
|
| 371 |
+
predicted_text = f"π§ Detected: {predicted_name.replace('_', ' ').title()} ({confidence:.2f})"
|
| 372 |
+
timer_text = "π§ DJ Mode - Effects updating every 3 seconds..."
|
| 373 |
+
else:
|
| 374 |
+
predicted_text = f"π§ Predicted: {predicted_name.replace('_', ' ').title()} ({confidence:.2f})"
|
| 375 |
+
timer_text = "β±οΈ Next trial in 2-3 seconds..."
|
| 376 |
+
|
| 377 |
+
# Get composition info
|
| 378 |
+
composition_info = sound_manager.get_composition_info()
|
| 379 |
+
status_text = format_composition_summary(composition_info)
|
| 380 |
+
|
| 381 |
+
# Phase-based instruction visibility
|
| 382 |
+
building_visible = sound_manager.current_phase == "building"
|
| 383 |
+
dj_visible = sound_manager.current_phase == "dj_effects"
|
| 384 |
+
|
| 385 |
+
return (
|
| 386 |
+
target_text,
|
| 387 |
+
predicted_text,
|
| 388 |
+
timer_text,
|
| 389 |
+
prompt_text,
|
| 390 |
+
fig,
|
| 391 |
+
left_hand_audio,
|
| 392 |
+
right_hand_audio,
|
| 393 |
+
left_leg_audio,
|
| 394 |
+
right_leg_audio,
|
| 395 |
+
tongue_audio,
|
| 396 |
+
status_text,
|
| 397 |
+
gr.update(visible=building_visible), # building_instructions
|
| 398 |
+
gr.update(visible=dj_visible) # dj_instructions
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
def manual_classify():
|
| 402 |
+
"""Manual classification for testing purposes."""
|
| 403 |
+
global app_state
|
| 404 |
+
|
| 405 |
+
if app_state['demo_data'] is None:
|
| 406 |
+
return "β No data", "β No data", "Manual mode", None, "No EEG data available", None, None, None, None, None
|
| 407 |
+
|
| 408 |
+
# Get EEG data sample
|
| 409 |
+
epoch_data, true_label = data_processor.simulate_real_time_data(
|
| 410 |
+
app_state['demo_data'], app_state['demo_labels'], mode="class_balanced"
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
# Classify the epoch
|
| 414 |
+
predicted_class, confidence, probabilities = classifier.predict(epoch_data)
|
| 415 |
+
predicted_name = classifier.class_names[predicted_class]
|
| 416 |
+
|
| 417 |
+
# Create visualization (without composition context)
|
| 418 |
+
fig = create_eeg_plot(epoch_data, "manual_test", predicted_name, confidence, False)
|
| 419 |
+
|
| 420 |
+
# Format results
|
| 421 |
+
target_text = "π― Manual Test Mode"
|
| 422 |
+
predicted_text = f"π§ {predicted_name.replace('_', ' ').title()} ({confidence:.2f})"
|
| 423 |
+
|
| 424 |
+
# Update results log
|
| 425 |
+
import time
|
| 426 |
+
timestamp = time.strftime("%H:%M:%S")
|
| 427 |
+
result_entry = f"[{timestamp}] Predicted: {predicted_name.replace('_', ' ').title()} (confidence: {confidence:.3f})"
|
| 428 |
+
|
| 429 |
+
# Get sound files for preview (no autoplay)
|
| 430 |
+
sounds = get_movement_sounds()
|
| 431 |
+
left_hand_audio = sounds.get('left_hand', None)
|
| 432 |
+
right_hand_audio = sounds.get('right_hand', None)
|
| 433 |
+
left_leg_audio = sounds.get('left_leg', None)
|
| 434 |
+
right_leg_audio = sounds.get('right_leg', None)
|
| 435 |
+
tongue_audio = sounds.get('tongue', None)
|
| 436 |
+
|
| 437 |
+
return (
|
| 438 |
+
target_text,
|
| 439 |
+
predicted_text,
|
| 440 |
+
"Manual mode - click button to classify",
|
| 441 |
+
fig,
|
| 442 |
+
result_entry,
|
| 443 |
+
left_hand_audio,
|
| 444 |
+
right_hand_audio,
|
| 445 |
+
left_leg_audio,
|
| 446 |
+
right_leg_audio,
|
| 447 |
+
tongue_audio
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
def clear_manual():
|
| 451 |
+
"""Clear manual testing results."""
|
| 452 |
+
return (
|
| 453 |
+
"π― Manual Test Mode",
|
| 454 |
+
"--",
|
| 455 |
+
"Manual mode",
|
| 456 |
+
None,
|
| 457 |
+
"Manual classification results cleared...",
|
| 458 |
+
None, None, None, None, None
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
def continue_automatic_composition():
|
| 462 |
+
"""Continue automatic composition - called for subsequent trials."""
|
| 463 |
+
global app_state
|
| 464 |
+
|
| 465 |
+
if not app_state['composition_active'] or not app_state['auto_mode']:
|
| 466 |
+
return "π Stopped", "--", "--", "Automatic composition stopped", None, None, None, None, None, None, "Stopped", gr.update(visible=True), gr.update(visible=False)
|
| 467 |
+
|
| 468 |
+
if app_state['demo_data'] is None:
|
| 469 |
+
return "β No data", "β No data", "β No data", "β No data", None, None, None, None, None, None, "No EEG data available", gr.update(visible=True), gr.update(visible=False)
|
| 470 |
+
|
| 471 |
+
# Get current target
|
| 472 |
+
target_movement = sound_manager.get_current_target_movement()
|
| 473 |
+
print(f"DEBUG continue_automatic_composition: current target = {target_movement}")
|
| 474 |
+
|
| 475 |
+
# Check if cycle is complete
|
| 476 |
+
if target_movement == "cycle_complete":
|
| 477 |
+
# Mark current cycle as complete
|
| 478 |
+
sound_manager.complete_current_cycle()
|
| 479 |
+
|
| 480 |
+
# Check if rehabilitation session should end
|
| 481 |
+
if sound_manager.should_end_session():
|
| 482 |
+
app_state['auto_mode'] = False # Stop automatic mode
|
| 483 |
+
return (
|
| 484 |
+
"π Session Complete!",
|
| 485 |
+
"π Amazing Progress!",
|
| 486 |
+
"Rehabilitation session finished!",
|
| 487 |
+
"π Congratulations! You've created 2 unique brain-music compositions!\n\n" +
|
| 488 |
+
"πͺ Your motor imagery skills are improving!\n\n" +
|
| 489 |
+
"π΅ You can review your compositions above, or start a new session anytime.\n\n" +
|
| 490 |
+
"Would you like to continue with more cycles, or take a well-deserved break?",
|
| 491 |
+
None, None, None, None, None, None,
|
| 492 |
+
f"β
Session Complete: {sound_manager.completed_cycles}/{sound_manager.max_cycles} compositions finished!",
|
| 493 |
+
gr.update(visible=True), gr.update(visible=False)
|
| 494 |
+
)
|
| 495 |
+
else:
|
| 496 |
+
# Start next cycle automatically
|
| 497 |
+
sound_manager.start_new_cycle()
|
| 498 |
+
print("π Cycle completed! Starting new cycle automatically...")
|
| 499 |
+
target_movement = sound_manager.get_current_target_movement() # Get new target
|
| 500 |
+
|
| 501 |
+
# Show next user prompt - rehabilitation-focused messaging
|
| 502 |
+
prompts = [
|
| 503 |
+
"πͺ Great work! Imagine your next movement...",
|
| 504 |
+
"π― You're doing amazing! Focus and imagine any movement...",
|
| 505 |
+
"β¨ Excellent progress! Ready for the next movement?",
|
| 506 |
+
"π Keep it up! Concentrate and imagine now...",
|
| 507 |
+
"π Fantastic! Next trial - imagine any movement..."
|
| 508 |
+
]
|
| 509 |
+
import random
|
| 510 |
+
prompt_text = random.choice(prompts)
|
| 511 |
+
|
| 512 |
+
# Add progress encouragement
|
| 513 |
+
completed_count = len(sound_manager.movements_completed)
|
| 514 |
+
total_count = len(sound_manager.current_movement_sequence)
|
| 515 |
+
if completed_count > 0:
|
| 516 |
+
prompt_text += f" ({completed_count}/{total_count} movements completed this cycle)"
|
| 517 |
+
|
| 518 |
+
# Perform EEG classification
|
| 519 |
+
epoch_data, true_label = data_processor.simulate_real_time_data(
|
| 520 |
+
app_state['demo_data'], app_state['demo_labels'], mode="class_balanced"
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
# Classify the epoch
|
| 524 |
+
predicted_class, confidence, probabilities = classifier.predict(epoch_data)
|
| 525 |
+
predicted_name = classifier.class_names[predicted_class]
|
| 526 |
+
|
| 527 |
+
# Handle DJ effects or building phase
|
| 528 |
+
if sound_manager.current_phase == "dj_effects" and confidence > CONFIDENCE_THRESHOLD:
|
| 529 |
+
# DJ Effects Mode - toggle effects instead of adding sounds
|
| 530 |
+
dj_result = sound_manager.toggle_dj_effect(predicted_name)
|
| 531 |
+
result = {
|
| 532 |
+
'sound_added': dj_result['effect_applied'],
|
| 533 |
+
'mixed_composition': dj_result.get('mixed_composition'),
|
| 534 |
+
'effect_name': dj_result.get('effect_name', ''),
|
| 535 |
+
'effect_status': dj_result.get('effect_status', '')
|
| 536 |
+
}
|
| 537 |
+
else:
|
| 538 |
+
# Building Mode - process classification normally
|
| 539 |
+
result = sound_manager.process_classification(predicted_name, confidence, CONFIDENCE_THRESHOLD)
|
| 540 |
+
|
| 541 |
+
# Check if we should transition to DJ phase
|
| 542 |
+
completed_count = len(sound_manager.movements_completed)
|
| 543 |
+
if completed_count >= 5 and sound_manager.current_phase == "building":
|
| 544 |
+
if sound_manager.transition_to_dj_phase():
|
| 545 |
+
print(f"DEBUG: Successfully transitioned to DJ phase with {completed_count} completed movements")
|
| 546 |
+
|
| 547 |
+
# Create visualization
|
| 548 |
+
fig = create_eeg_plot(epoch_data, target_movement, predicted_name, confidence, result['sound_added'])
|
| 549 |
+
|
| 550 |
+
# Initialize all audio components to silent by default
|
| 551 |
+
silent_file = "silent.wav"
|
| 552 |
+
left_hand_audio = silent_file
|
| 553 |
+
right_hand_audio = silent_file
|
| 554 |
+
left_leg_audio = silent_file
|
| 555 |
+
right_leg_audio = silent_file
|
| 556 |
+
tongue_audio = silent_file
|
| 557 |
+
|
| 558 |
+
# Handle audio differently based on phase
|
| 559 |
+
if sound_manager.current_phase == "dj_effects":
|
| 560 |
+
# DJ Mode: Play mixed composition with effects in the center position (tongue)
|
| 561 |
+
# Keep individual tracks silent to avoid overlapping audio
|
| 562 |
+
# Only update audio if the mixed composition file has actually changed
|
| 563 |
+
if sound_manager.mixed_composition_file and os.path.exists(sound_manager.mixed_composition_file):
|
| 564 |
+
# Only update if the file path has changed from last time
|
| 565 |
+
if app_state['last_audio_state']['tongue_audio'] != sound_manager.mixed_composition_file:
|
| 566 |
+
tongue_audio = sound_manager.mixed_composition_file
|
| 567 |
+
app_state['last_audio_state']['tongue_audio'] = sound_manager.mixed_composition_file
|
| 568 |
+
print(f"DEBUG continue: DJ mode - NEW mixed composition loaded: {sound_manager.mixed_composition_file}")
|
| 569 |
+
elif app_state['last_audio_state']['tongue_audio'] is not None:
|
| 570 |
+
# Use the same file as before to prevent Gradio from restarting audio
|
| 571 |
+
tongue_audio = app_state['last_audio_state']['tongue_audio']
|
| 572 |
+
# No debug print to reduce spam
|
| 573 |
+
else:
|
| 574 |
+
# Handle case where cached state is None - use the current file
|
| 575 |
+
tongue_audio = sound_manager.mixed_composition_file
|
| 576 |
+
app_state['last_audio_state']['tongue_audio'] = sound_manager.mixed_composition_file
|
| 577 |
+
print(f"DEBUG continue: DJ mode - Loading mixed composition (state was None): {sound_manager.mixed_composition_file}")
|
| 578 |
+
# Note: We don't update with processed effects files to prevent audio restarts
|
| 579 |
+
else:
|
| 580 |
+
# Building Mode: Show ALL accumulated sounds (layered composition)
|
| 581 |
+
sounds = get_movement_sounds()
|
| 582 |
+
completed_movements = sound_manager.movements_completed
|
| 583 |
+
print(f"DEBUG continue: Available sounds: {list(sounds.keys())}")
|
| 584 |
+
print(f"DEBUG continue: Completed movements: {completed_movements}")
|
| 585 |
+
|
| 586 |
+
# Display all completed movement sounds (cumulative layering) - only update when changed
|
| 587 |
+
if 'left_hand' in completed_movements and 'left_hand' in sounds:
|
| 588 |
+
new_left_hand = sounds['left_hand']
|
| 589 |
+
if app_state['last_audio_state']['left_hand_audio'] != new_left_hand:
|
| 590 |
+
left_hand_audio = new_left_hand
|
| 591 |
+
app_state['last_audio_state']['left_hand_audio'] = new_left_hand
|
| 592 |
+
print(f"DEBUG continue: NEW left_hand_audio: {new_left_hand}")
|
| 593 |
+
elif app_state['last_audio_state']['left_hand_audio'] is not None:
|
| 594 |
+
left_hand_audio = app_state['last_audio_state']['left_hand_audio']
|
| 595 |
+
else:
|
| 596 |
+
# Handle case where cached state is None - use the current file
|
| 597 |
+
left_hand_audio = new_left_hand
|
| 598 |
+
app_state['last_audio_state']['left_hand_audio'] = new_left_hand
|
| 599 |
+
|
| 600 |
+
if 'right_hand' in completed_movements and 'right_hand' in sounds:
|
| 601 |
+
new_right_hand = sounds['right_hand']
|
| 602 |
+
# TEMP FIX: Always update right_hand to test audio issue
|
| 603 |
+
right_hand_audio = new_right_hand
|
| 604 |
+
app_state['last_audio_state']['right_hand_audio'] = new_right_hand
|
| 605 |
+
print(f"DEBUG continue: ALWAYS UPDATE right_hand_audio: {new_right_hand}")
|
| 606 |
+
|
| 607 |
+
if 'left_leg' in completed_movements and 'left_leg' in sounds:
|
| 608 |
+
new_left_leg = sounds['left_leg']
|
| 609 |
+
if app_state['last_audio_state']['left_leg_audio'] != new_left_leg:
|
| 610 |
+
left_leg_audio = new_left_leg
|
| 611 |
+
app_state['last_audio_state']['left_leg_audio'] = new_left_leg
|
| 612 |
+
print(f"DEBUG continue: NEW left_leg_audio: {new_left_leg}")
|
| 613 |
+
elif app_state['last_audio_state']['left_leg_audio'] is not None:
|
| 614 |
+
left_leg_audio = app_state['last_audio_state']['left_leg_audio']
|
| 615 |
+
else:
|
| 616 |
+
# Handle case where cached state is None - use the current file
|
| 617 |
+
left_leg_audio = new_left_leg
|
| 618 |
+
app_state['last_audio_state']['left_leg_audio'] = new_left_leg
|
| 619 |
+
|
| 620 |
+
if 'right_leg' in completed_movements and 'right_leg' in sounds:
|
| 621 |
+
new_right_leg = sounds['right_leg']
|
| 622 |
+
if app_state['last_audio_state']['right_leg_audio'] != new_right_leg:
|
| 623 |
+
right_leg_audio = new_right_leg
|
| 624 |
+
app_state['last_audio_state']['right_leg_audio'] = new_right_leg
|
| 625 |
+
print(f"DEBUG continue: NEW right_leg_audio: {new_right_leg}")
|
| 626 |
+
elif app_state['last_audio_state']['right_leg_audio'] is not None:
|
| 627 |
+
right_leg_audio = app_state['last_audio_state']['right_leg_audio']
|
| 628 |
+
else:
|
| 629 |
+
# Handle case where cached state is None - use the current file
|
| 630 |
+
right_leg_audio = new_right_leg
|
| 631 |
+
app_state['last_audio_state']['right_leg_audio'] = new_right_leg
|
| 632 |
+
|
| 633 |
+
if 'tongue' in completed_movements and 'tongue' in sounds:
|
| 634 |
+
new_tongue = sounds['tongue']
|
| 635 |
+
# Note: Don't update tongue audio in building mode if we're about to transition to DJ mode
|
| 636 |
+
if sound_manager.current_phase == "building":
|
| 637 |
+
if app_state['last_audio_state']['tongue_audio'] != new_tongue:
|
| 638 |
+
tongue_audio = new_tongue
|
| 639 |
+
app_state['last_audio_state']['tongue_audio'] = new_tongue
|
| 640 |
+
print(f"DEBUG continue: NEW tongue_audio: {new_tongue}")
|
| 641 |
+
elif app_state['last_audio_state']['tongue_audio'] is not None:
|
| 642 |
+
tongue_audio = app_state['last_audio_state']['tongue_audio']
|
| 643 |
+
else:
|
| 644 |
+
# Handle case where cached state is None - use the current file
|
| 645 |
+
tongue_audio = new_tongue
|
| 646 |
+
app_state['last_audio_state']['tongue_audio'] = new_tongue
|
| 647 |
+
|
| 648 |
+
# If a sound was just added, make sure it gets updated immediately (override state check)
|
| 649 |
+
if result['sound_added'] and predicted_name in sounds:
|
| 650 |
+
new_sound = sounds[predicted_name]
|
| 651 |
+
if predicted_name == 'left_hand':
|
| 652 |
+
left_hand_audio = new_sound
|
| 653 |
+
app_state['last_audio_state']['left_hand_audio'] = new_sound
|
| 654 |
+
elif predicted_name == 'right_hand':
|
| 655 |
+
right_hand_audio = new_sound
|
| 656 |
+
app_state['last_audio_state']['right_hand_audio'] = new_sound
|
| 657 |
+
elif predicted_name == 'left_leg':
|
| 658 |
+
left_leg_audio = new_sound
|
| 659 |
+
app_state['last_audio_state']['left_leg_audio'] = new_sound
|
| 660 |
+
elif predicted_name == 'right_leg':
|
| 661 |
+
right_leg_audio = new_sound
|
| 662 |
+
app_state['last_audio_state']['right_leg_audio'] = new_sound
|
| 663 |
+
elif predicted_name == 'tongue':
|
| 664 |
+
tongue_audio = new_sound
|
| 665 |
+
app_state['last_audio_state']['tongue_audio'] = new_sound
|
| 666 |
+
print(f"DEBUG continue: Force update - just added {predicted_name} sound: {new_sound}")
|
| 667 |
+
|
| 668 |
+
# Format display with progress information
|
| 669 |
+
completed_count = len(sound_manager.movements_completed)
|
| 670 |
+
total_count = len(sound_manager.current_movement_sequence)
|
| 671 |
+
|
| 672 |
+
if sound_manager.current_phase == "dj_effects":
|
| 673 |
+
target_text = f"π§ DJ Mode - Control Effects with Movements"
|
| 674 |
+
predicted_text = f"π§ Predicted: {predicted_name.replace('_', ' ').title()} ({confidence:.2f})"
|
| 675 |
+
if result.get('effect_applied'):
|
| 676 |
+
effect_name = result.get('effect_name', '')
|
| 677 |
+
effect_status = result.get('effect_status', '')
|
| 678 |
+
timer_text = f"οΏ½οΈ {effect_name}: {effect_status}"
|
| 679 |
+
else:
|
| 680 |
+
timer_text = "π΅ Move to control effects..."
|
| 681 |
+
else:
|
| 682 |
+
target_text = f"οΏ½π― Any Movement ({completed_count}/{total_count} complete)"
|
| 683 |
+
predicted_text = f"π§ Predicted: {predicted_name.replace('_', ' ').title()} ({confidence:.2f})"
|
| 684 |
+
timer_text = "β±οΈ Next trial in 2-3 seconds..." if app_state['auto_mode'] else "Stopped"
|
| 685 |
+
|
| 686 |
+
# Get composition info
|
| 687 |
+
composition_info = sound_manager.get_composition_info()
|
| 688 |
+
status_text = format_composition_summary(composition_info)
|
| 689 |
+
|
| 690 |
+
# Phase-based instruction visibility
|
| 691 |
+
building_visible = sound_manager.current_phase == "building"
|
| 692 |
+
dj_visible = sound_manager.current_phase == "dj_effects"
|
| 693 |
+
|
| 694 |
+
return (
|
| 695 |
+
target_text,
|
| 696 |
+
predicted_text,
|
| 697 |
+
timer_text,
|
| 698 |
+
prompt_text,
|
| 699 |
+
fig,
|
| 700 |
+
left_hand_audio,
|
| 701 |
+
right_hand_audio,
|
| 702 |
+
left_leg_audio,
|
| 703 |
+
right_leg_audio,
|
| 704 |
+
tongue_audio,
|
| 705 |
+
status_text,
|
| 706 |
+
gr.update(visible=building_visible), # building_instructions
|
| 707 |
+
gr.update(visible=dj_visible) # dj_instructions
|
| 708 |
+
)
|
| 709 |
+
|
| 710 |
+
def classify_epoch():
|
| 711 |
+
"""Classify a single EEG epoch and update composition."""
|
| 712 |
+
global app_state
|
| 713 |
+
|
| 714 |
+
if not app_state['composition_active']:
|
| 715 |
+
return "β Not active", "β Not active", "β Not active", None, None, None, None, None, None, "Click 'Start Composing' first"
|
| 716 |
+
|
| 717 |
+
if app_state['demo_data'] is None:
|
| 718 |
+
return "β No data", "β No data", "β No data", None, None, None, None, None, None, "No EEG data available"
|
| 719 |
+
|
| 720 |
+
# Get current target
|
| 721 |
+
target_movement = sound_manager.get_current_target_movement()
|
| 722 |
+
print(f"DEBUG classify_epoch: current target = {target_movement}")
|
| 723 |
+
|
| 724 |
+
if target_movement == "cycle_complete":
|
| 725 |
+
return "π΅ Cycle Complete!", "π΅ Complete", "Remap sounds to continue", None, None, None, None, None, None, "Cycle complete - remap sounds to continue"
|
| 726 |
+
|
| 727 |
+
# Get EEG data sample
|
| 728 |
+
epoch_data, true_label = data_processor.simulate_real_time_data(
|
| 729 |
+
app_state['demo_data'], app_state['demo_labels'], mode="class_balanced"
|
| 730 |
+
)
|
| 731 |
+
|
| 732 |
+
# Classify the epoch
|
| 733 |
+
predicted_class, confidence, probabilities = classifier.predict(epoch_data)
|
| 734 |
+
predicted_name = classifier.class_names[predicted_class]
|
| 735 |
+
|
| 736 |
+
# Process classification
|
| 737 |
+
result = sound_manager.process_classification(predicted_name, confidence, CONFIDENCE_THRESHOLD)
|
| 738 |
+
|
| 739 |
+
# Check if we should transition to DJ phase
|
| 740 |
+
completed_count = len(sound_manager.movements_completed)
|
| 741 |
+
if completed_count >= 5 and sound_manager.current_phase == "building":
|
| 742 |
+
if sound_manager.transition_to_dj_phase():
|
| 743 |
+
print(f"DEBUG: Successfully transitioned to DJ phase with {completed_count} completed movements")
|
| 744 |
+
|
| 745 |
+
# Create visualization
|
| 746 |
+
fig = create_eeg_plot(epoch_data, target_movement, predicted_name, confidence, result['sound_added'])
|
| 747 |
+
|
| 748 |
+
# Initialize all audio components to silent (to clear any previous sounds)
|
| 749 |
+
silent_file = "silent.wav"
|
| 750 |
+
left_hand_audio = silent_file
|
| 751 |
+
right_hand_audio = silent_file
|
| 752 |
+
left_leg_audio = silent_file
|
| 753 |
+
right_leg_audio = silent_file
|
| 754 |
+
tongue_audio = silent_file
|
| 755 |
+
|
| 756 |
+
# Only play the sound if it was just added and matches the prediction
|
| 757 |
+
if result['sound_added']:
|
| 758 |
+
sounds = get_movement_sounds()
|
| 759 |
+
if predicted_name == 'left_hand' and 'left_hand' in sounds:
|
| 760 |
+
left_hand_audio = sounds['left_hand']
|
| 761 |
+
elif predicted_name == 'right_hand' and 'right_hand' in sounds:
|
| 762 |
+
right_hand_audio = sounds['right_hand']
|
| 763 |
+
elif predicted_name == 'left_leg' and 'left_leg' in sounds:
|
| 764 |
+
left_leg_audio = sounds['left_leg']
|
| 765 |
+
elif predicted_name == 'right_leg' and 'right_leg' in sounds:
|
| 766 |
+
right_leg_audio = sounds['right_leg']
|
| 767 |
+
elif predicted_name == 'tongue' and 'tongue' in sounds:
|
| 768 |
+
tongue_audio = sounds['tongue']
|
| 769 |
+
|
| 770 |
+
# Format next target
|
| 771 |
+
next_target = sound_manager.get_current_target_movement()
|
| 772 |
+
target_text = f"π― Target: {next_target.replace('_', ' ').title()}" if next_target != "cycle_complete" else "π΅ Cycle Complete!"
|
| 773 |
+
|
| 774 |
+
predicted_text = f"π§ Predicted: {predicted_name.replace('_', ' ').title()} ({confidence:.2f})"
|
| 775 |
+
|
| 776 |
+
# Get composition info
|
| 777 |
+
composition_info = sound_manager.get_composition_info()
|
| 778 |
+
status_text = format_composition_summary(composition_info)
|
| 779 |
+
|
| 780 |
+
return (
|
| 781 |
+
target_text,
|
| 782 |
+
predicted_text,
|
| 783 |
+
"2-3 seconds",
|
| 784 |
+
fig,
|
| 785 |
+
left_hand_audio,
|
| 786 |
+
right_hand_audio,
|
| 787 |
+
left_leg_audio,
|
| 788 |
+
right_leg_audio,
|
| 789 |
+
tongue_audio,
|
| 790 |
+
status_text
|
| 791 |
+
)
|
| 792 |
+
|
| 793 |
+
def create_eeg_plot(eeg_data: np.ndarray, target_movement: str, predicted_name: str, confidence: float, sound_added: bool) -> plt.Figure:
|
| 794 |
+
"""Create EEG plot with target movement and classification result."""
|
| 795 |
+
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
|
| 796 |
+
axes = axes.flatten()
|
| 797 |
+
|
| 798 |
+
# Plot 4 channels
|
| 799 |
+
time_points = np.arange(eeg_data.shape[1]) / 200 # 200 Hz sampling rate
|
| 800 |
+
channel_names = ['C3', 'C4', 'T3', 'T4'] # Motor cortex channels
|
| 801 |
+
|
| 802 |
+
for i in range(min(4, eeg_data.shape[0])):
|
| 803 |
+
color = 'green' if sound_added else 'blue'
|
| 804 |
+
axes[i].plot(time_points, eeg_data[i], color=color, linewidth=1)
|
| 805 |
+
|
| 806 |
+
if i < len(channel_names):
|
| 807 |
+
axes[i].set_title(f'{channel_names[i]} (Ch {i+1})')
|
| 808 |
+
else:
|
| 809 |
+
axes[i].set_title(f'Channel {i+1}')
|
| 810 |
+
|
| 811 |
+
axes[i].set_xlabel('Time (s)')
|
| 812 |
+
axes[i].set_ylabel('Amplitude (Β΅V)')
|
| 813 |
+
axes[i].grid(True, alpha=0.3)
|
| 814 |
+
|
| 815 |
+
# Add overall title with status
|
| 816 |
+
status = "β SOUND ADDED" if sound_added else "β No sound"
|
| 817 |
+
title = f"Target: {target_movement.replace('_', ' ').title()} | Predicted: {predicted_name.replace('_', ' ').title()} ({confidence:.2f}) | {status}"
|
| 818 |
+
fig.suptitle(title, fontsize=12, fontweight='bold')
|
| 819 |
+
fig.tight_layout()
|
| 820 |
+
return fig
|
| 821 |
+
|
| 822 |
+
def format_composition_summary(composition_info: Dict) -> str:
|
| 823 |
+
"""Format composition information for display."""
|
| 824 |
+
if not composition_info.get('layers_by_cycle'):
|
| 825 |
+
return "No composition layers yet"
|
| 826 |
+
|
| 827 |
+
summary = []
|
| 828 |
+
for cycle, layers in composition_info['layers_by_cycle'].items():
|
| 829 |
+
summary.append(f"Cycle {cycle + 1}: {len(layers)} layers")
|
| 830 |
+
for layer in layers:
|
| 831 |
+
movement = layer.get('movement', 'unknown')
|
| 832 |
+
confidence = layer.get('confidence', 0)
|
| 833 |
+
summary.append(f" β’ {movement.replace('_', ' ').title()} ({confidence:.2f})")
|
| 834 |
+
|
| 835 |
+
return "\n".join(summary) if summary else "No composition layers"
|
| 836 |
+
|
| 837 |
+
# Create Gradio interface
|
| 838 |
+
def create_interface():
|
| 839 |
+
with gr.Blocks(title="EEG Motor Imagery Music Composer", theme=gr.themes.Soft()) as demo:
|
| 840 |
+
gr.Markdown("# π§ π΅ EEG Motor Imagery Rehabilitation Composer")
|
| 841 |
+
gr.Markdown("**Therapeutic Brain-Computer Interface for Motor Recovery**\n\nCreate beautiful music compositions using your brain signals! This rehabilitation tool helps strengthen motor imagery skills while creating personalized musical pieces.")
|
| 842 |
+
|
| 843 |
+
with gr.Tabs() as tabs:
|
| 844 |
+
# Main Composition Tab
|
| 845 |
+
with gr.TabItem("π΅ Automatic Composition"):
|
| 846 |
+
with gr.Row():
|
| 847 |
+
# Left side - Task and EEG information
|
| 848 |
+
with gr.Column(scale=2):
|
| 849 |
+
# Task instructions - Building Phase
|
| 850 |
+
with gr.Group() as building_instructions:
|
| 851 |
+
gr.Markdown("### π― Rehabilitation Session Instructions")
|
| 852 |
+
gr.Markdown("""
|
| 853 |
+
**Motor Imagery Training:**
|
| 854 |
+
- **Imagine** opening or closing your **right or left hand**
|
| 855 |
+
- **Visualize** briefly moving your **right or left leg or foot**
|
| 856 |
+
- **Think about** pronouncing **"L"** with your tongue
|
| 857 |
+
- **Rest state** (no movement imagination)
|
| 858 |
+
|
| 859 |
+
*π Each successful imagination creates a musical layer!*
|
| 860 |
+
|
| 861 |
+
**Session Structure:** Build composition, then control DJ effects
|
| 862 |
+
*Press Start to begin your personalized rehabilitation session*
|
| 863 |
+
""")
|
| 864 |
+
|
| 865 |
+
# DJ Instructions - Effects Phase (initially hidden)
|
| 866 |
+
with gr.Group(visible=False) as dj_instructions:
|
| 867 |
+
gr.Markdown("### π§ DJ Controller Mode")
|
| 868 |
+
gr.Markdown("""
|
| 869 |
+
**π Composition Complete! You are now the DJ!**
|
| 870 |
+
|
| 871 |
+
**Use the same movements to control audio effects:**
|
| 872 |
+
- π **Left Hand**: Volume Fade On/Off
|
| 873 |
+
- π **Right Hand**: High Pass Filter On/Off
|
| 874 |
+
- 𦡠**Left Leg**: Reverb Effect On/Off
|
| 875 |
+
- 𦡠**Right Leg**: Low Pass Filter On/Off
|
| 876 |
+
- π
**Tongue**: Bass Boost On/Off
|
| 877 |
+
|
| 878 |
+
*ποΈ Each movement toggles an effect - Mix your creation!*
|
| 879 |
+
""")
|
| 880 |
+
|
| 881 |
+
# Start button
|
| 882 |
+
with gr.Row():
|
| 883 |
+
start_btn = gr.Button("π΅ Start Composing", variant="primary", size="lg")
|
| 884 |
+
continue_btn = gr.Button("βοΈ Continue", variant="primary", size="lg", visible=False)
|
| 885 |
+
stop_btn = gr.Button("π Stop", variant="secondary", size="lg")
|
| 886 |
+
|
| 887 |
+
# Session completion options (shown after 2 cycles)
|
| 888 |
+
with gr.Row(visible=False) as session_complete_row:
|
| 889 |
+
new_session_btn = gr.Button("π Start New Session", variant="primary", size="lg")
|
| 890 |
+
extend_session_btn = gr.Button("β Continue Session", variant="secondary", size="lg")
|
| 891 |
+
|
| 892 |
+
# Timer for automatic progression (hidden from user)
|
| 893 |
+
timer = gr.Timer(value=3.0, active=False) # 3 second intervals
|
| 894 |
+
|
| 895 |
+
# User prompt display
|
| 896 |
+
user_prompt = gr.Textbox(label="π User Prompt", interactive=False, value="Click 'Start Composing' to begin",
|
| 897 |
+
elem_classes=["prompt-display"])
|
| 898 |
+
|
| 899 |
+
# Current status
|
| 900 |
+
with gr.Row():
|
| 901 |
+
target_display = gr.Textbox(label="π― Current Target", interactive=False, value="Ready to start")
|
| 902 |
+
predicted_display = gr.Textbox(label="π§ Predicted", interactive=False, value="--")
|
| 903 |
+
|
| 904 |
+
timer_display = gr.Textbox(label="β±οΈ Next Trial In", interactive=False, value="--")
|
| 905 |
+
|
| 906 |
+
eeg_plot = gr.Plot(label="EEG Data Visualization")
|
| 907 |
+
|
| 908 |
+
# Right side - Compositional layers
|
| 909 |
+
with gr.Column(scale=1):
|
| 910 |
+
gr.Markdown("### π΅ Compositional Layers")
|
| 911 |
+
|
| 912 |
+
# Show 5 movement sounds
|
| 913 |
+
left_hand_sound = gr.Audio(label="π Left Hand", interactive=False, autoplay=True, visible=True)
|
| 914 |
+
right_hand_sound = gr.Audio(label="π Right Hand", interactive=False, autoplay=True, visible=True)
|
| 915 |
+
left_leg_sound = gr.Audio(label="𦡠Left Leg", interactive=False, autoplay=True, visible=True)
|
| 916 |
+
right_leg_sound = gr.Audio(label="𦡠Right Leg", interactive=False, autoplay=True, visible=True)
|
| 917 |
+
tongue_sound = gr.Audio(label="π
Tongue", interactive=False, autoplay=True, visible=True)
|
| 918 |
+
|
| 919 |
+
# Composition status
|
| 920 |
+
composition_status = gr.Textbox(label="Composition Status", interactive=False, lines=5)
|
| 921 |
+
|
| 922 |
+
# Manual Testing Tab
|
| 923 |
+
with gr.TabItem("π§ Manual Testing"):
|
| 924 |
+
with gr.Row():
|
| 925 |
+
with gr.Column(scale=2):
|
| 926 |
+
gr.Markdown("### π¬ Manual EEG Classification Testing")
|
| 927 |
+
gr.Markdown("Use this tab to manually test the EEG classifier without the composition system.")
|
| 928 |
+
|
| 929 |
+
with gr.Row():
|
| 930 |
+
classify_btn = gr.Button("π§ Classify Single Epoch", variant="primary")
|
| 931 |
+
clear_btn = gr.Button("οΏ½οΈ Clear", variant="secondary")
|
| 932 |
+
|
| 933 |
+
# Manual status displays
|
| 934 |
+
manual_target_display = gr.Textbox(label="π― Current Target", interactive=False, value="Ready")
|
| 935 |
+
manual_predicted_display = gr.Textbox(label="π§ Predicted", interactive=False, value="--")
|
| 936 |
+
manual_timer_display = gr.Textbox(label="β±οΈ Status", interactive=False, value="Manual mode")
|
| 937 |
+
|
| 938 |
+
manual_eeg_plot = gr.Plot(label="EEG Data Visualization")
|
| 939 |
+
|
| 940 |
+
with gr.Column(scale=1):
|
| 941 |
+
gr.Markdown("### π Classification Results")
|
| 942 |
+
manual_results = gr.Textbox(label="Results Log", interactive=False, lines=10, value="Manual classification results will appear here...")
|
| 943 |
+
|
| 944 |
+
# Individual sound previews (no autoplay in manual mode)
|
| 945 |
+
gr.Markdown("### π Sound Preview")
|
| 946 |
+
manual_left_hand_sound = gr.Audio(label="π Left Hand", interactive=False, autoplay=False, visible=True)
|
| 947 |
+
manual_right_hand_sound = gr.Audio(label="π Right Hand", interactive=False, autoplay=False, visible=True)
|
| 948 |
+
manual_left_leg_sound = gr.Audio(label="𦡠Left Leg", interactive=False, autoplay=False, visible=True)
|
| 949 |
+
manual_right_leg_sound = gr.Audio(label="𦡠Right Leg", interactive=False, autoplay=False, visible=True)
|
| 950 |
+
manual_tongue_sound = gr.Audio(label="π
Tongue", interactive=False, autoplay=False, visible=True)
|
| 951 |
+
|
| 952 |
+
# Session management functions
|
| 953 |
+
def start_new_session():
|
| 954 |
+
"""Reset everything and start a completely new rehabilitation session"""
|
| 955 |
+
global sound_manager
|
| 956 |
+
sound_manager.completed_cycles = 0
|
| 957 |
+
sound_manager.current_cycle = 0
|
| 958 |
+
sound_manager.movements_completed = set()
|
| 959 |
+
sound_manager.composition_layers = []
|
| 960 |
+
|
| 961 |
+
# Start fresh session
|
| 962 |
+
result = start_automatic_composition()
|
| 963 |
+
return (
|
| 964 |
+
result[0], # target_display
|
| 965 |
+
result[1], # predicted_display
|
| 966 |
+
result[2], # timer_display
|
| 967 |
+
result[3], # user_prompt
|
| 968 |
+
result[4], # eeg_plot
|
| 969 |
+
result[5], # left_hand_sound
|
| 970 |
+
result[6], # right_hand_sound
|
| 971 |
+
result[7], # left_leg_sound
|
| 972 |
+
result[8], # right_leg_sound
|
| 973 |
+
result[9], # tongue_sound
|
| 974 |
+
result[10], # composition_status
|
| 975 |
+
result[11], # building_instructions
|
| 976 |
+
result[12], # dj_instructions
|
| 977 |
+
gr.update(visible=True), # continue_btn - show it
|
| 978 |
+
gr.update(active=True), # timer - activate it
|
| 979 |
+
gr.update(visible=False) # session_complete_row - hide it
|
| 980 |
+
)
|
| 981 |
+
|
| 982 |
+
def extend_current_session():
|
| 983 |
+
"""Continue current session beyond the 2-cycle limit"""
|
| 984 |
+
sound_manager.max_cycles += 2 # Add 2 more cycles
|
| 985 |
+
|
| 986 |
+
# Continue with current session
|
| 987 |
+
result = continue_automatic_composition()
|
| 988 |
+
return (
|
| 989 |
+
result[0], # target_display
|
| 990 |
+
result[1], # predicted_display
|
| 991 |
+
result[2], # timer_display
|
| 992 |
+
result[3], # user_prompt
|
| 993 |
+
result[4], # eeg_plot
|
| 994 |
+
result[5], # left_hand_sound
|
| 995 |
+
result[6], # right_hand_sound
|
| 996 |
+
result[7], # left_leg_sound
|
| 997 |
+
result[8], # right_leg_sound
|
| 998 |
+
result[9], # tongue_sound
|
| 999 |
+
result[10], # composition_status
|
| 1000 |
+
result[11], # building_instructions
|
| 1001 |
+
result[12], # dj_instructions
|
| 1002 |
+
gr.update(visible=True), # continue_btn - show it
|
| 1003 |
+
gr.update(active=True), # timer - activate it
|
| 1004 |
+
gr.update(visible=False) # session_complete_row - hide it
|
| 1005 |
+
)
|
| 1006 |
+
|
| 1007 |
+
# Wrapper functions for timer control
|
| 1008 |
+
def start_with_timer():
|
| 1009 |
+
"""Start composition and activate automatic timer"""
|
| 1010 |
+
result = start_automatic_composition()
|
| 1011 |
+
# Show continue button and activate timer
|
| 1012 |
+
return (
|
| 1013 |
+
result[0], # target_display
|
| 1014 |
+
result[1], # predicted_display
|
| 1015 |
+
result[2], # timer_display
|
| 1016 |
+
result[3], # user_prompt
|
| 1017 |
+
result[4], # eeg_plot
|
| 1018 |
+
result[5], # left_hand_sound
|
| 1019 |
+
result[6], # right_hand_sound
|
| 1020 |
+
result[7], # left_leg_sound
|
| 1021 |
+
result[8], # right_leg_sound
|
| 1022 |
+
result[9], # tongue_sound
|
| 1023 |
+
result[10], # composition_status
|
| 1024 |
+
result[11], # building_instructions
|
| 1025 |
+
result[12], # dj_instructions
|
| 1026 |
+
gr.update(visible=True), # continue_btn - show it
|
| 1027 |
+
gr.update(active=True) # timer - activate it
|
| 1028 |
+
)
|
| 1029 |
+
|
| 1030 |
+
def continue_with_timer():
|
| 1031 |
+
"""Continue composition and manage timer state"""
|
| 1032 |
+
result = continue_automatic_composition()
|
| 1033 |
+
|
| 1034 |
+
# Check if session is complete (rehabilitation session finished)
|
| 1035 |
+
if "π Session Complete!" in result[0]:
|
| 1036 |
+
# Show session completion options
|
| 1037 |
+
return (
|
| 1038 |
+
result[0], # target_display
|
| 1039 |
+
result[1], # predicted_display
|
| 1040 |
+
result[2], # timer_display
|
| 1041 |
+
result[3], # user_prompt
|
| 1042 |
+
result[4], # eeg_plot
|
| 1043 |
+
result[5], # left_hand_sound
|
| 1044 |
+
result[6], # right_hand_sound
|
| 1045 |
+
result[7], # left_leg_sound
|
| 1046 |
+
result[8], # right_leg_sound
|
| 1047 |
+
result[9], # tongue_sound
|
| 1048 |
+
result[10], # composition_status
|
| 1049 |
+
result[11], # building_instructions
|
| 1050 |
+
result[12], # dj_instructions
|
| 1051 |
+
gr.update(active=False), # timer - deactivate it
|
| 1052 |
+
gr.update(visible=True) # session_complete_row - show options
|
| 1053 |
+
)
|
| 1054 |
+
# Check if composition is complete (old logic for other cases)
|
| 1055 |
+
elif "π΅ Cycle Complete!" in result[0]:
|
| 1056 |
+
# Stop the timer when composition is complete
|
| 1057 |
+
return (
|
| 1058 |
+
result[0], # target_display
|
| 1059 |
+
result[1], # predicted_display
|
| 1060 |
+
result[2], # timer_display
|
| 1061 |
+
result[3], # user_prompt
|
| 1062 |
+
result[4], # eeg_plot
|
| 1063 |
+
result[5], # left_hand_sound
|
| 1064 |
+
result[6], # right_hand_sound
|
| 1065 |
+
result[7], # left_leg_sound
|
| 1066 |
+
result[8], # right_leg_sound
|
| 1067 |
+
result[9], # tongue_sound
|
| 1068 |
+
result[10], # composition_status
|
| 1069 |
+
result[11], # building_instructions
|
| 1070 |
+
result[12], # dj_instructions
|
| 1071 |
+
gr.update(active=False), # timer - deactivate it
|
| 1072 |
+
gr.update(visible=False) # session_complete_row - keep hidden
|
| 1073 |
+
)
|
| 1074 |
+
else:
|
| 1075 |
+
# Keep timer active for next iteration
|
| 1076 |
+
return (
|
| 1077 |
+
result[0], # target_display
|
| 1078 |
+
result[1], # predicted_display
|
| 1079 |
+
result[2], # timer_display
|
| 1080 |
+
result[3], # user_prompt
|
| 1081 |
+
result[4], # eeg_plot
|
| 1082 |
+
result[5], # left_hand_sound
|
| 1083 |
+
result[6], # right_hand_sound
|
| 1084 |
+
result[7], # left_leg_sound
|
| 1085 |
+
result[8], # right_leg_sound
|
| 1086 |
+
result[9], # tongue_sound
|
| 1087 |
+
result[10], # composition_status
|
| 1088 |
+
result[11], # building_instructions
|
| 1089 |
+
result[12], # dj_instructions
|
| 1090 |
+
gr.update(active=True), # timer - keep active
|
| 1091 |
+
gr.update(visible=False) # session_complete_row - keep hidden
|
| 1092 |
+
)
|
| 1093 |
+
|
| 1094 |
+
# Event handlers for automatic composition tab
|
| 1095 |
+
start_btn.click(
|
| 1096 |
+
fn=start_with_timer,
|
| 1097 |
+
outputs=[target_display, predicted_display, timer_display, user_prompt, eeg_plot,
|
| 1098 |
+
left_hand_sound, right_hand_sound, left_leg_sound, right_leg_sound, tongue_sound, composition_status,
|
| 1099 |
+
building_instructions, dj_instructions, continue_btn, timer]
|
| 1100 |
+
)
|
| 1101 |
+
|
| 1102 |
+
continue_btn.click(
|
| 1103 |
+
fn=continue_with_timer,
|
| 1104 |
+
outputs=[target_display, predicted_display, timer_display, user_prompt, eeg_plot,
|
| 1105 |
+
left_hand_sound, right_hand_sound, left_leg_sound, right_leg_sound, tongue_sound, composition_status,
|
| 1106 |
+
building_instructions, dj_instructions, timer, session_complete_row]
|
| 1107 |
+
)
|
| 1108 |
+
|
| 1109 |
+
# Timer automatically triggers continuation
|
| 1110 |
+
timer.tick(
|
| 1111 |
+
fn=continue_with_timer,
|
| 1112 |
+
outputs=[target_display, predicted_display, timer_display, user_prompt, eeg_plot,
|
| 1113 |
+
left_hand_sound, right_hand_sound, left_leg_sound, right_leg_sound, tongue_sound, composition_status,
|
| 1114 |
+
building_instructions, dj_instructions, timer, session_complete_row]
|
| 1115 |
+
)
|
| 1116 |
+
|
| 1117 |
+
# Session completion event handlers
|
| 1118 |
+
new_session_btn.click(
|
| 1119 |
+
fn=start_new_session,
|
| 1120 |
+
outputs=[target_display, predicted_display, timer_display, user_prompt, eeg_plot,
|
| 1121 |
+
left_hand_sound, right_hand_sound, left_leg_sound, right_leg_sound, tongue_sound, composition_status,
|
| 1122 |
+
building_instructions, dj_instructions, continue_btn, timer, session_complete_row]
|
| 1123 |
+
)
|
| 1124 |
+
|
| 1125 |
+
extend_session_btn.click(
|
| 1126 |
+
fn=extend_current_session,
|
| 1127 |
+
outputs=[target_display, predicted_display, timer_display, user_prompt, eeg_plot,
|
| 1128 |
+
left_hand_sound, right_hand_sound, left_leg_sound, right_leg_sound, tongue_sound, composition_status,
|
| 1129 |
+
building_instructions, dj_instructions, continue_btn, timer, session_complete_row]
|
| 1130 |
+
)
|
| 1131 |
+
|
| 1132 |
+
def stop_with_timer():
|
| 1133 |
+
"""Stop composition and deactivate timer"""
|
| 1134 |
+
result = stop_composition()
|
| 1135 |
+
return (
|
| 1136 |
+
result[0], # target_display
|
| 1137 |
+
result[1], # predicted_display
|
| 1138 |
+
result[2], # timer_display
|
| 1139 |
+
result[3], # user_prompt
|
| 1140 |
+
gr.update(visible=False), # continue_btn - hide it
|
| 1141 |
+
gr.update(active=False) # timer - deactivate it
|
| 1142 |
+
)
|
| 1143 |
+
|
| 1144 |
+
stop_btn.click(
|
| 1145 |
+
fn=stop_with_timer,
|
| 1146 |
+
outputs=[target_display, predicted_display, timer_display, user_prompt, continue_btn, timer]
|
| 1147 |
+
)
|
| 1148 |
+
|
| 1149 |
+
# Event handlers for manual testing tab
|
| 1150 |
+
classify_btn.click(
|
| 1151 |
+
fn=manual_classify,
|
| 1152 |
+
outputs=[manual_target_display, manual_predicted_display, manual_timer_display, manual_eeg_plot, manual_results,
|
| 1153 |
+
manual_left_hand_sound, manual_right_hand_sound, manual_left_leg_sound, manual_right_leg_sound, manual_tongue_sound]
|
| 1154 |
+
)
|
| 1155 |
+
|
| 1156 |
+
clear_btn.click(
|
| 1157 |
+
fn=clear_manual,
|
| 1158 |
+
outputs=[manual_target_display, manual_predicted_display, manual_timer_display, manual_eeg_plot, manual_results,
|
| 1159 |
+
manual_left_hand_sound, manual_right_hand_sound, manual_left_leg_sound, manual_right_leg_sound, manual_tongue_sound]
|
| 1160 |
+
)
|
| 1161 |
+
|
| 1162 |
+
# Note: No auto-loading of sounds to prevent playing all sounds on startup
|
| 1163 |
+
|
| 1164 |
+
return demo
|
| 1165 |
+
|
| 1166 |
+
if __name__ == "__main__":
|
| 1167 |
+
demo = create_interface()
|
| 1168 |
+
demo.launch(server_name="0.0.0.0", server_port=7867)
|
classifier.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
EEG Motor Imagery Classifier Module
|
| 3 |
+
----------------------------------
|
| 4 |
+
Handles model loading, inference, and real-time prediction for motor imagery classification.
|
| 5 |
+
Based on the ShallowFBCSPNet architecture from the original eeg_motor_imagery.py script.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import numpy as np
|
| 11 |
+
from braindecode.models import ShallowFBCSPNet
|
| 12 |
+
from typing import Dict, Tuple, Optional
|
| 13 |
+
import os
|
| 14 |
+
from sklearn.metrics import accuracy_score
|
| 15 |
+
from data_processor import EEGDataProcessor
|
| 16 |
+
from config import DEMO_DATA_PATHS
|
| 17 |
+
|
| 18 |
+
class MotorImageryClassifier:
|
| 19 |
+
"""
|
| 20 |
+
Motor imagery classifier using ShallowFBCSPNet model.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
def __init__(self, model_path: str = "shallow_weights_all.pth"):
|
| 24 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 25 |
+
self.model = None
|
| 26 |
+
self.model_path = model_path
|
| 27 |
+
self.class_names = {
|
| 28 |
+
0: "left_hand",
|
| 29 |
+
1: "right_hand",
|
| 30 |
+
2: "neutral",
|
| 31 |
+
3: "left_leg",
|
| 32 |
+
4: "tongue",
|
| 33 |
+
5: "right_leg"
|
| 34 |
+
}
|
| 35 |
+
self.is_loaded = False
|
| 36 |
+
|
| 37 |
+
def load_model(self, n_chans: int, n_times: int, n_outputs: int = 6):
|
| 38 |
+
"""Load the pre-trained ShallowFBCSPNet model."""
|
| 39 |
+
try:
|
| 40 |
+
self.model = ShallowFBCSPNet(
|
| 41 |
+
n_chans=n_chans,
|
| 42 |
+
n_outputs=n_outputs,
|
| 43 |
+
n_times=n_times,
|
| 44 |
+
final_conv_length="auto"
|
| 45 |
+
).to(self.device)
|
| 46 |
+
|
| 47 |
+
if os.path.exists(self.model_path):
|
| 48 |
+
try:
|
| 49 |
+
state_dict = torch.load(self.model_path, map_location=self.device)
|
| 50 |
+
self.model.load_state_dict(state_dict)
|
| 51 |
+
self.model.eval()
|
| 52 |
+
self.is_loaded = True
|
| 53 |
+
print(f"β
Pre-trained model loaded successfully from {self.model_path}")
|
| 54 |
+
except Exception as model_error:
|
| 55 |
+
print(f"β οΈ Pre-trained model found but incompatible: {model_error}")
|
| 56 |
+
print("π Starting LOSO training with available EEG data...")
|
| 57 |
+
self.is_loaded = False
|
| 58 |
+
else:
|
| 59 |
+
print(f"β Pre-trained model weights not found at {self.model_path}")
|
| 60 |
+
print("π Starting LOSO training with available EEG data...")
|
| 61 |
+
self.is_loaded = False
|
| 62 |
+
|
| 63 |
+
except Exception as e:
|
| 64 |
+
print(f"β Error loading model: {e}")
|
| 65 |
+
print("π Starting LOSO training with available EEG data...")
|
| 66 |
+
self.is_loaded = False
|
| 67 |
+
|
| 68 |
+
def get_model_status(self) -> str:
|
| 69 |
+
"""Get current model status for user interface."""
|
| 70 |
+
if self.is_loaded:
|
| 71 |
+
return "β
Pre-trained model loaded and ready"
|
| 72 |
+
else:
|
| 73 |
+
return "π Using LOSO training (training new model from EEG data)"
|
| 74 |
+
|
| 75 |
+
def predict(self, eeg_data: np.ndarray) -> Tuple[int, float, Dict[str, float]]:
|
| 76 |
+
"""
|
| 77 |
+
Predict motor imagery class from EEG data.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
eeg_data: EEG data array of shape (n_channels, n_times)
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
predicted_class: Predicted class index
|
| 84 |
+
confidence: Confidence score
|
| 85 |
+
probabilities: Dictionary of class probabilities
|
| 86 |
+
"""
|
| 87 |
+
if not self.is_loaded:
|
| 88 |
+
return self._fallback_loso_classification(eeg_data)
|
| 89 |
+
|
| 90 |
+
# Ensure input is the right shape: (batch, channels, time)
|
| 91 |
+
if eeg_data.ndim == 2:
|
| 92 |
+
eeg_data = eeg_data[np.newaxis, ...]
|
| 93 |
+
|
| 94 |
+
# Convert to tensor
|
| 95 |
+
x = torch.from_numpy(eeg_data.astype(np.float32)).to(self.device)
|
| 96 |
+
|
| 97 |
+
with torch.no_grad():
|
| 98 |
+
output = self.model(x)
|
| 99 |
+
probabilities = torch.softmax(output, dim=1)
|
| 100 |
+
predicted_class = torch.argmax(probabilities, dim=1).cpu().numpy()[0]
|
| 101 |
+
confidence = probabilities.max().cpu().numpy()
|
| 102 |
+
|
| 103 |
+
# Convert to dictionary
|
| 104 |
+
prob_dict = {
|
| 105 |
+
self.class_names[i]: probabilities[0, i].cpu().numpy()
|
| 106 |
+
for i in range(len(self.class_names))
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
return predicted_class, confidence, prob_dict
|
| 110 |
+
|
| 111 |
+
def _fallback_loso_classification(self, eeg_data: np.ndarray) -> Tuple[int, float, Dict[str, float]]:
|
| 112 |
+
"""
|
| 113 |
+
Fallback classification using LOSO (Leave-One-Session-Out) training.
|
| 114 |
+
Trains a model on available data when pre-trained model isn't available.
|
| 115 |
+
"""
|
| 116 |
+
try:
|
| 117 |
+
print("π No pre-trained model available. Training new model using LOSO method...")
|
| 118 |
+
print("β³ This may take a moment - training on real EEG data...")
|
| 119 |
+
|
| 120 |
+
# Initialize data processor
|
| 121 |
+
processor = EEGDataProcessor()
|
| 122 |
+
|
| 123 |
+
# Check if demo data files exist
|
| 124 |
+
available_files = [f for f in DEMO_DATA_PATHS if os.path.exists(f)]
|
| 125 |
+
if len(available_files) < 2:
|
| 126 |
+
raise ValueError(f"Not enough data files for LOSO training. Need at least 2 files, found {len(available_files)}. "
|
| 127 |
+
f"Available files: {available_files}")
|
| 128 |
+
|
| 129 |
+
# Perform LOSO split (using first session as test)
|
| 130 |
+
X_train, y_train, X_test, y_test, session_info = processor.prepare_loso_split(
|
| 131 |
+
available_files, test_session_idx=0
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
# Get data dimensions
|
| 135 |
+
n_chans = X_train.shape[1]
|
| 136 |
+
n_times = X_train.shape[2]
|
| 137 |
+
|
| 138 |
+
# Create and train model
|
| 139 |
+
self.model = ShallowFBCSPNet(
|
| 140 |
+
n_chans=n_chans,
|
| 141 |
+
n_outputs=6,
|
| 142 |
+
n_times=n_times,
|
| 143 |
+
final_conv_length="auto"
|
| 144 |
+
).to(self.device)
|
| 145 |
+
|
| 146 |
+
# Simple training loop
|
| 147 |
+
optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
|
| 148 |
+
criterion = nn.CrossEntropyLoss()
|
| 149 |
+
|
| 150 |
+
# Convert training data to tensors
|
| 151 |
+
X_train_tensor = torch.from_numpy(X_train).float().to(self.device)
|
| 152 |
+
y_train_tensor = torch.from_numpy(y_train).long().to(self.device)
|
| 153 |
+
|
| 154 |
+
# Quick training (just a few epochs for demo)
|
| 155 |
+
self.model.train()
|
| 156 |
+
for epoch in range(50):
|
| 157 |
+
optimizer.zero_grad()
|
| 158 |
+
outputs = self.model(X_train_tensor)
|
| 159 |
+
loss = criterion(outputs, y_train_tensor)
|
| 160 |
+
loss.backward()
|
| 161 |
+
optimizer.step()
|
| 162 |
+
|
| 163 |
+
if epoch % 5 == 0:
|
| 164 |
+
print(f"LOSO Training - Epoch {epoch}, Loss: {loss.item():.4f}")
|
| 165 |
+
|
| 166 |
+
# Switch to evaluation mode
|
| 167 |
+
self.model.eval()
|
| 168 |
+
self.is_loaded = True
|
| 169 |
+
|
| 170 |
+
print("β
LOSO model trained successfully! Ready for classification.")
|
| 171 |
+
|
| 172 |
+
# Now make prediction with the trained model
|
| 173 |
+
return self.predict(eeg_data)
|
| 174 |
+
|
| 175 |
+
except Exception as e:
|
| 176 |
+
print(f"Error in LOSO training: {e}")
|
| 177 |
+
raise RuntimeError(f"Failed to initialize classifier. Neither pre-trained model nor LOSO training succeeded: {e}")
|
| 178 |
+
|
| 179 |
+
|
config.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration settings for the EEG Motor Imagery Music Composer
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
# Application settings
|
| 9 |
+
APP_NAME = "EEG Motor Imagery Music Composer"
|
| 10 |
+
VERSION = "1.0.0"
|
| 11 |
+
|
| 12 |
+
# Data paths
|
| 13 |
+
BASE_DIR = Path(__file__).parent
|
| 14 |
+
DATA_DIR = BASE_DIR / "data"
|
| 15 |
+
SOUND_DIR = BASE_DIR / "sounds"
|
| 16 |
+
MODEL_DIR = BASE_DIR
|
| 17 |
+
|
| 18 |
+
# Model settings
|
| 19 |
+
MODEL_PATH = MODEL_DIR / "shallow_weights_all.pth"
|
| 20 |
+
# Model architecture: Always uses ShallowFBCSPNet from braindecode
|
| 21 |
+
# If pre-trained weights not found, will train using LOSO on available data
|
| 22 |
+
|
| 23 |
+
# EEG Data settings
|
| 24 |
+
SAMPLING_RATE = 200 # Hz
|
| 25 |
+
EPOCH_DURATION = 1.5 # seconds
|
| 26 |
+
N_CHANNELS = 19 # without ground and reference 19 electrodes
|
| 27 |
+
N_CLASSES = 6 # or 4
|
| 28 |
+
|
| 29 |
+
# Classification settings
|
| 30 |
+
CONFIDENCE_THRESHOLD = 0.3 # Minimum confidence to add sound layer (lowered for testing)
|
| 31 |
+
MAX_COMPOSITION_LAYERS = 6 # Maximum layers in composition
|
| 32 |
+
|
| 33 |
+
# Sound settings
|
| 34 |
+
SOUND_MAPPING = {
|
| 35 |
+
"left_hand": "1_SoundHelix-Song-6_(Bass).wav",
|
| 36 |
+
"right_hand": "1_SoundHelix-Song-6_(Drums).wav",
|
| 37 |
+
"neutral": None, # No sound for neutral/rest state
|
| 38 |
+
"left_leg": "1_SoundHelix-Song-6_(Other).wav",
|
| 39 |
+
"tongue": "1_SoundHelix-Song-6_(Vocals).wav",
|
| 40 |
+
"right_leg": "1_SoundHelix-Song-6_(Bass).wav" # Can be remapped by user
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
# Motor imagery class names
|
| 44 |
+
CLASS_NAMES = {
|
| 45 |
+
0: "left_hand",
|
| 46 |
+
1: "right_hand",
|
| 47 |
+
2: "neutral",
|
| 48 |
+
3: "left_leg",
|
| 49 |
+
4: "tongue",
|
| 50 |
+
5: "right_leg"
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
CLASS_DESCRIPTIONS = {
|
| 54 |
+
"left_hand": "π€ Left Hand Movement",
|
| 55 |
+
"right_hand": "π€ Right Hand Movement",
|
| 56 |
+
"neutral": "π Neutral/Rest State",
|
| 57 |
+
"left_leg": "𦡠Left Leg Movement",
|
| 58 |
+
"tongue": "π
Tongue Movement",
|
| 59 |
+
"right_leg": "𦡠Right Leg Movement"
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
# Demo data paths (optional) - updated with available files
|
| 63 |
+
DEMO_DATA_PATHS = [
|
| 64 |
+
"data/raw_mat/HaLTSubjectA1602236StLRHandLegTongue.mat",
|
| 65 |
+
"data/raw_mat/HaLTSubjectA1603086StLRHandLegTongue.mat",
|
| 66 |
+
"data/raw_mat/HaLTSubjectA1603106StLRHandLegTongue.mat",
|
| 67 |
+
]
|
| 68 |
+
|
| 69 |
+
# Gradio settings
|
| 70 |
+
GRADIO_PORT = 7860
|
| 71 |
+
GRADIO_HOST = "0.0.0.0"
|
| 72 |
+
GRADIO_SHARE = False # Set to True to create public links
|
| 73 |
+
|
| 74 |
+
# Logging settings
|
| 75 |
+
LOG_LEVEL = "INFO"
|
| 76 |
+
LOG_FILE = BASE_DIR / "logs" / "app.log"
|
| 77 |
+
|
| 78 |
+
# Create necessary directories
|
| 79 |
+
def create_directories():
|
| 80 |
+
"""Create necessary directories if they don't exist."""
|
| 81 |
+
directories = [
|
| 82 |
+
DATA_DIR,
|
| 83 |
+
SOUND_DIR,
|
| 84 |
+
MODEL_DIR,
|
| 85 |
+
LOG_FILE.parent
|
| 86 |
+
]
|
| 87 |
+
|
| 88 |
+
for directory in directories:
|
| 89 |
+
directory.mkdir(parents=True, exist_ok=True)
|
| 90 |
+
|
| 91 |
+
if __name__ == "__main__":
|
| 92 |
+
create_directories()
|
| 93 |
+
print("Configuration directories created successfully!")
|
data_processor.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
EEG Data Processing Module
|
| 3 |
+
-------------------------
|
| 4 |
+
Handles EEG data loading, preprocessing, and epoching for real-time classification.
|
| 5 |
+
Adapted from the original eeg_motor_imagery.py script.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import scipy.io
|
| 9 |
+
import numpy as np
|
| 10 |
+
import mne
|
| 11 |
+
import torch
|
| 12 |
+
import pandas as pd
|
| 13 |
+
from typing import List, Tuple, Dict, Optional
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from scipy.signal import butter, lfilter
|
| 16 |
+
|
| 17 |
+
class EEGDataProcessor:
|
| 18 |
+
"""
|
| 19 |
+
Processes EEG data from .mat files for motor imagery classification.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def __init__(self):
|
| 23 |
+
self.fs = None
|
| 24 |
+
self.ch_names = None
|
| 25 |
+
self.event_id = {
|
| 26 |
+
"left_hand": 1,
|
| 27 |
+
"right_hand": 2,
|
| 28 |
+
"neutral": 3,
|
| 29 |
+
"left_leg": 4,
|
| 30 |
+
"tongue": 5,
|
| 31 |
+
"right_leg": 6,
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
def load_mat_file(self, file_path: str) -> Tuple[np.ndarray, np.ndarray, List[str], int]:
|
| 35 |
+
"""Load and parse a single .mat EEG file."""
|
| 36 |
+
mat = scipy.io.loadmat(file_path)
|
| 37 |
+
content = mat['o'][0, 0]
|
| 38 |
+
|
| 39 |
+
labels = content[4].flatten()
|
| 40 |
+
signals = content[5]
|
| 41 |
+
chan_names_raw = content[6]
|
| 42 |
+
channels = [ch[0][0] for ch in chan_names_raw]
|
| 43 |
+
fs = int(content[2][0, 0])
|
| 44 |
+
|
| 45 |
+
return signals, labels, channels, fs
|
| 46 |
+
|
| 47 |
+
def create_raw_object(self, signals: np.ndarray, channels: List[str], fs: int,
|
| 48 |
+
drop_ground_electrodes: bool = True) -> mne.io.RawArray:
|
| 49 |
+
"""Create MNE Raw object from signal data."""
|
| 50 |
+
df = pd.DataFrame(signals, columns=channels)
|
| 51 |
+
|
| 52 |
+
if drop_ground_electrodes:
|
| 53 |
+
# Drop auxiliary channels that should be excluded
|
| 54 |
+
aux_exclude = ('X3', 'X5')
|
| 55 |
+
columns_to_drop = [ch for ch in channels if ch in aux_exclude]
|
| 56 |
+
|
| 57 |
+
df = df.drop(columns=columns_to_drop, errors="ignore")
|
| 58 |
+
print(f"Dropped auxiliary channels {columns_to_drop}. Remaining channels: {len(df.columns)}")
|
| 59 |
+
|
| 60 |
+
eeg = df.values.T
|
| 61 |
+
ch_names = df.columns.tolist()
|
| 62 |
+
|
| 63 |
+
self.ch_names = ch_names
|
| 64 |
+
self.fs = fs
|
| 65 |
+
|
| 66 |
+
info = mne.create_info(ch_names=ch_names, sfreq=fs, ch_types="eeg")
|
| 67 |
+
raw = mne.io.RawArray(eeg, info)
|
| 68 |
+
|
| 69 |
+
return raw
|
| 70 |
+
|
| 71 |
+
def extract_events(self, labels: np.ndarray) -> np.ndarray:
|
| 72 |
+
"""Extract events from label array."""
|
| 73 |
+
onsets = np.where((labels[1:] != 0) & (labels[:-1] == 0))[0] + 1
|
| 74 |
+
event_codes = labels[onsets].astype(int)
|
| 75 |
+
events = np.c_[onsets, np.zeros_like(onsets), event_codes]
|
| 76 |
+
|
| 77 |
+
# Keep only relevant events
|
| 78 |
+
mask = np.isin(events[:, 2], np.arange(1, 7))
|
| 79 |
+
events = events[mask]
|
| 80 |
+
|
| 81 |
+
return events
|
| 82 |
+
|
| 83 |
+
def create_epochs(self, raw: mne.io.RawArray, events: np.ndarray,
|
| 84 |
+
tmin: float = 0, tmax: float = 1.5) -> mne.Epochs:
|
| 85 |
+
"""Create epochs from raw data and events."""
|
| 86 |
+
epochs = mne.Epochs(
|
| 87 |
+
raw,
|
| 88 |
+
events=events,
|
| 89 |
+
event_id=self.event_id,
|
| 90 |
+
tmin=tmin,
|
| 91 |
+
tmax=tmax,
|
| 92 |
+
baseline=None,
|
| 93 |
+
preload=True,
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
return epochs
|
| 97 |
+
|
| 98 |
+
def process_files(self, file_paths: List[str]) -> Tuple[np.ndarray, np.ndarray]:
|
| 99 |
+
"""Process multiple EEG files and return combined data."""
|
| 100 |
+
all_epochs = []
|
| 101 |
+
|
| 102 |
+
for file_path in file_paths:
|
| 103 |
+
signals, labels, channels, fs = self.load_mat_file(file_path)
|
| 104 |
+
raw = self.create_raw_object(signals, channels, fs, drop_ground_electrodes=True)
|
| 105 |
+
events = self.extract_events(labels)
|
| 106 |
+
epochs = self.create_epochs(raw, events)
|
| 107 |
+
all_epochs.append(epochs)
|
| 108 |
+
|
| 109 |
+
if len(all_epochs) > 1:
|
| 110 |
+
epochs_combined = mne.concatenate_epochs(all_epochs)
|
| 111 |
+
else:
|
| 112 |
+
epochs_combined = all_epochs[0]
|
| 113 |
+
|
| 114 |
+
# Convert to arrays for model input
|
| 115 |
+
X = epochs_combined.get_data().astype("float32")
|
| 116 |
+
y = (epochs_combined.events[:, -1] - 1).astype("int64") # classes 0..5
|
| 117 |
+
|
| 118 |
+
return X, y
|
| 119 |
+
|
| 120 |
+
def load_continuous_data(self, file_paths: List[str]) -> Tuple[np.ndarray, int]:
|
| 121 |
+
"""
|
| 122 |
+
Load continuous raw EEG data without epoching.
|
| 123 |
+
|
| 124 |
+
Args:
|
| 125 |
+
file_paths: List of .mat file paths
|
| 126 |
+
|
| 127 |
+
Returns:
|
| 128 |
+
raw_data: Continuous EEG data [n_channels, n_timepoints]
|
| 129 |
+
fs: Sampling frequency
|
| 130 |
+
"""
|
| 131 |
+
all_raw_data = []
|
| 132 |
+
|
| 133 |
+
for file_path in file_paths:
|
| 134 |
+
signals, labels, channels, fs = self.load_mat_file(file_path)
|
| 135 |
+
raw = self.create_raw_object(signals, channels, fs, drop_ground_electrodes=True)
|
| 136 |
+
|
| 137 |
+
# Extract continuous data (no epoching)
|
| 138 |
+
continuous_data = raw.get_data() # [n_channels, n_timepoints]
|
| 139 |
+
all_raw_data.append(continuous_data)
|
| 140 |
+
|
| 141 |
+
# Concatenate all continuous data along time axis
|
| 142 |
+
if len(all_raw_data) > 1:
|
| 143 |
+
combined_raw = np.concatenate(all_raw_data, axis=1)
|
| 144 |
+
else:
|
| 145 |
+
combined_raw = all_raw_data[0]
|
| 146 |
+
|
| 147 |
+
return combined_raw, fs
|
| 148 |
+
|
| 149 |
+
def prepare_loso_split(self, file_paths: List[str], test_subject_idx: int = 0) -> Tuple:
|
| 150 |
+
"""
|
| 151 |
+
Prepare Leave-One-Subject-Out (LOSO) split for EEG data.
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
file_paths: List of .mat file paths (one per subject)
|
| 155 |
+
test_subject_idx: Index of subject to use for testing
|
| 156 |
+
|
| 157 |
+
Returns:
|
| 158 |
+
X_train, y_train, X_test, y_test, subject_info
|
| 159 |
+
"""
|
| 160 |
+
all_subjects_data = []
|
| 161 |
+
subject_info = []
|
| 162 |
+
|
| 163 |
+
# Load each subject separately
|
| 164 |
+
for i, file_path in enumerate(file_paths):
|
| 165 |
+
signals, labels, channels, fs = self.load_mat_file(file_path)
|
| 166 |
+
raw = self.create_raw_object(signals, channels, fs, drop_ground_electrodes=True)
|
| 167 |
+
events = self.extract_events(labels)
|
| 168 |
+
epochs = self.create_epochs(raw, events)
|
| 169 |
+
|
| 170 |
+
# Convert to arrays
|
| 171 |
+
X_subject = epochs.get_data().astype("float32")
|
| 172 |
+
y_subject = (epochs.events[:, -1] - 1).astype("int64")
|
| 173 |
+
|
| 174 |
+
all_subjects_data.append((X_subject, y_subject))
|
| 175 |
+
subject_info.append({
|
| 176 |
+
'file_path': file_path,
|
| 177 |
+
'subject_id': f"Subject_{i+1}",
|
| 178 |
+
'n_epochs': len(X_subject),
|
| 179 |
+
'channels': channels,
|
| 180 |
+
'fs': fs
|
| 181 |
+
})
|
| 182 |
+
|
| 183 |
+
# LOSO split: one subject for test, others for train
|
| 184 |
+
test_subject = all_subjects_data[test_subject_idx]
|
| 185 |
+
train_subjects = [all_subjects_data[i] for i in range(len(all_subjects_data)) if i != test_subject_idx]
|
| 186 |
+
|
| 187 |
+
# Combine training subjects
|
| 188 |
+
if len(train_subjects) > 1:
|
| 189 |
+
X_train = np.concatenate([subj[0] for subj in train_subjects], axis=0)
|
| 190 |
+
y_train = np.concatenate([subj[1] for subj in train_subjects], axis=0)
|
| 191 |
+
else:
|
| 192 |
+
X_train, y_train = train_subjects[0]
|
| 193 |
+
|
| 194 |
+
X_test, y_test = test_subject
|
| 195 |
+
|
| 196 |
+
print("LOSO Split:")
|
| 197 |
+
print(f" Test Subject: {subject_info[test_subject_idx]['subject_id']} ({len(X_test)} epochs)")
|
| 198 |
+
print(f" Train Subjects: {len(train_subjects)} subjects ({len(X_train)} epochs)")
|
| 199 |
+
|
| 200 |
+
return X_train, y_train, X_test, y_test, subject_info
|
| 201 |
+
|
| 202 |
+
def simulate_real_time_data(self, X: np.ndarray, y: np.ndarray, mode: str = "random") -> Tuple[np.ndarray, int]:
|
| 203 |
+
"""
|
| 204 |
+
Simulate real-time EEG data for demo purposes.
|
| 205 |
+
|
| 206 |
+
Args:
|
| 207 |
+
X: EEG data array (currently epoched data)
|
| 208 |
+
y: Labels array
|
| 209 |
+
mode: "random", "sequential", or "class_balanced"
|
| 210 |
+
|
| 211 |
+
Returns:
|
| 212 |
+
Single epoch and its true label
|
| 213 |
+
"""
|
| 214 |
+
if mode == "random":
|
| 215 |
+
idx = np.random.randint(0, len(X))
|
| 216 |
+
elif mode == "sequential":
|
| 217 |
+
# Use a counter for sequential sampling (would need to store state)
|
| 218 |
+
idx = np.random.randint(0, len(X)) # Simplified for now
|
| 219 |
+
elif mode == "class_balanced":
|
| 220 |
+
# Sample ensuring we get different classes
|
| 221 |
+
available_classes = np.unique(y)
|
| 222 |
+
target_class = np.random.choice(available_classes)
|
| 223 |
+
class_indices = np.where(y == target_class)[0]
|
| 224 |
+
idx = np.random.choice(class_indices)
|
| 225 |
+
else:
|
| 226 |
+
idx = np.random.randint(0, len(X))
|
| 227 |
+
|
| 228 |
+
return X[idx], y[idx]
|
| 229 |
+
|
| 230 |
+
def simulate_continuous_stream(self, raw_data: np.ndarray, fs: int, window_size: float = 1.5) -> np.ndarray:
|
| 231 |
+
"""
|
| 232 |
+
Simulate continuous EEG stream by extracting sliding windows from raw data.
|
| 233 |
+
|
| 234 |
+
Args:
|
| 235 |
+
raw_data: Continuous EEG data [n_channels, n_timepoints]
|
| 236 |
+
fs: Sampling frequency
|
| 237 |
+
window_size: Window size in seconds
|
| 238 |
+
|
| 239 |
+
Returns:
|
| 240 |
+
Single window of EEG data [n_channels, window_samples]
|
| 241 |
+
"""
|
| 242 |
+
window_samples = int(window_size * fs) # e.g., 1.5s * 200Hz = 300 samples
|
| 243 |
+
|
| 244 |
+
# Ensure we don't go beyond the data
|
| 245 |
+
max_start = raw_data.shape[1] - window_samples
|
| 246 |
+
if max_start <= 0:
|
| 247 |
+
return raw_data # Return full data if too short
|
| 248 |
+
|
| 249 |
+
# Random starting point in the continuous stream
|
| 250 |
+
start_idx = np.random.randint(0, max_start)
|
| 251 |
+
end_idx = start_idx + window_samples
|
| 252 |
+
|
| 253 |
+
# Extract window
|
| 254 |
+
window = raw_data[:, start_idx:end_idx]
|
| 255 |
+
|
| 256 |
+
return window
|
| 257 |
+
|
demo.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Demo script for the EEG Motor Imagery Music Composer
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import sys
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
# Add the current directory to Python path
|
| 11 |
+
current_dir = Path(__file__).parent
|
| 12 |
+
sys.path.insert(0, str(current_dir))
|
| 13 |
+
|
| 14 |
+
from data_processor import EEGDataProcessor
|
| 15 |
+
from classifier import MotorImageryClassifier
|
| 16 |
+
from sound_library import SoundManager
|
| 17 |
+
from utils import setup_logging, create_classification_summary
|
| 18 |
+
import numpy as np
|
| 19 |
+
|
| 20 |
+
def run_demo():
|
| 21 |
+
"""Run a simple demo of the system components."""
|
| 22 |
+
print("π§ EEG Motor Imagery Music Composer - Demo")
|
| 23 |
+
print("=" * 50)
|
| 24 |
+
|
| 25 |
+
# Initialize components
|
| 26 |
+
print("Initializing components...")
|
| 27 |
+
data_processor = EEGDataProcessor()
|
| 28 |
+
classifier = MotorImageryClassifier()
|
| 29 |
+
sound_manager = SoundManager()
|
| 30 |
+
|
| 31 |
+
# Create mock data for demo
|
| 32 |
+
print("Creating mock EEG data...")
|
| 33 |
+
mock_data = np.random.randn(10, 32, 384).astype(np.float32) # 10 samples, 32 channels, 384 time points
|
| 34 |
+
mock_labels = np.random.randint(0, 6, 10)
|
| 35 |
+
|
| 36 |
+
# Initialize classifier
|
| 37 |
+
print("Loading classifier...")
|
| 38 |
+
classifier.load_model(n_chans=32, n_times=384)
|
| 39 |
+
|
| 40 |
+
print(f"Available sounds: {sound_manager.get_available_sounds()}")
|
| 41 |
+
print()
|
| 42 |
+
|
| 43 |
+
# Run classification demo
|
| 44 |
+
print("Running classification demo...")
|
| 45 |
+
print("-" * 30)
|
| 46 |
+
|
| 47 |
+
for i in range(5):
|
| 48 |
+
# Get random sample
|
| 49 |
+
sample_idx = np.random.randint(0, len(mock_data))
|
| 50 |
+
eeg_sample = mock_data[sample_idx]
|
| 51 |
+
true_label = mock_labels[sample_idx]
|
| 52 |
+
|
| 53 |
+
# Classify
|
| 54 |
+
predicted_class, confidence, probabilities = classifier.predict(eeg_sample)
|
| 55 |
+
predicted_name = classifier.class_names[predicted_class]
|
| 56 |
+
true_name = classifier.class_names[true_label]
|
| 57 |
+
|
| 58 |
+
# Create summary
|
| 59 |
+
summary = create_classification_summary(predicted_name, confidence, probabilities)
|
| 60 |
+
|
| 61 |
+
print(f"Sample {i+1}:")
|
| 62 |
+
print(f" True class: {true_name}")
|
| 63 |
+
print(f" Predicted: {summary['emoji']} {predicted_name} ({summary['confidence_percent']})")
|
| 64 |
+
|
| 65 |
+
# Add to composition if confidence is high
|
| 66 |
+
if confidence > 0.7 and predicted_name != 'neutral':
|
| 67 |
+
sound_manager.add_layer(predicted_name, confidence)
|
| 68 |
+
print(f" βͺ Added to composition: {predicted_name}")
|
| 69 |
+
else:
|
| 70 |
+
print(f" - Not added (low confidence or neutral)")
|
| 71 |
+
|
| 72 |
+
print()
|
| 73 |
+
|
| 74 |
+
# Show composition summary
|
| 75 |
+
composition_info = sound_manager.get_composition_info()
|
| 76 |
+
print("Final Composition:")
|
| 77 |
+
print("-" * 20)
|
| 78 |
+
if composition_info:
|
| 79 |
+
for i, layer in enumerate(composition_info, 1):
|
| 80 |
+
print(f"{i}. {layer['class']} (confidence: {layer['confidence']:.2f})")
|
| 81 |
+
else:
|
| 82 |
+
print("No composition layers created")
|
| 83 |
+
|
| 84 |
+
print()
|
| 85 |
+
print("Demo completed! π΅")
|
| 86 |
+
print("To run the full Gradio interface, execute: python app.py")
|
| 87 |
+
|
| 88 |
+
if __name__ == "__main__":
|
| 89 |
+
try:
|
| 90 |
+
run_demo()
|
| 91 |
+
except KeyboardInterrupt:
|
| 92 |
+
print("\nDemo interrupted by user")
|
| 93 |
+
except Exception as e:
|
| 94 |
+
print(f"Error running demo: {e}")
|
| 95 |
+
import traceback
|
| 96 |
+
traceback.print_exc()
|
enhanced_utils.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Enhanced utilities incorporating useful functions from the original src/ templates
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import seaborn as sns
|
| 9 |
+
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
|
| 10 |
+
from typing import Dict, List, Tuple, Optional
|
| 11 |
+
import logging
|
| 12 |
+
|
| 13 |
+
def evaluate_model_performance(model, dataloader, device, class_names: List[str]) -> Dict:
|
| 14 |
+
"""
|
| 15 |
+
Comprehensive model evaluation with metrics and visualizations.
|
| 16 |
+
Enhanced version of src/evaluate.py
|
| 17 |
+
"""
|
| 18 |
+
model.eval()
|
| 19 |
+
all_preds = []
|
| 20 |
+
all_labels = []
|
| 21 |
+
all_probs = []
|
| 22 |
+
|
| 23 |
+
with torch.no_grad():
|
| 24 |
+
for inputs, labels in dataloader:
|
| 25 |
+
inputs = inputs.to(device)
|
| 26 |
+
labels = labels.to(device)
|
| 27 |
+
|
| 28 |
+
outputs = model(inputs)
|
| 29 |
+
probs = torch.softmax(outputs, dim=1)
|
| 30 |
+
_, preds = torch.max(outputs, 1)
|
| 31 |
+
|
| 32 |
+
all_preds.extend(preds.cpu().numpy())
|
| 33 |
+
all_labels.extend(labels.cpu().numpy())
|
| 34 |
+
all_probs.extend(probs.cpu().numpy())
|
| 35 |
+
|
| 36 |
+
# Calculate metrics
|
| 37 |
+
accuracy = accuracy_score(all_labels, all_preds)
|
| 38 |
+
cm = confusion_matrix(all_labels, all_preds)
|
| 39 |
+
report = classification_report(all_labels, all_preds, target_names=class_names, output_dict=True)
|
| 40 |
+
|
| 41 |
+
return {
|
| 42 |
+
'accuracy': accuracy,
|
| 43 |
+
'predictions': all_preds,
|
| 44 |
+
'labels': all_labels,
|
| 45 |
+
'probabilities': all_probs,
|
| 46 |
+
'confusion_matrix': cm,
|
| 47 |
+
'classification_report': report
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
def plot_confusion_matrix(cm: np.ndarray, class_names: List[str], title: str = "Confusion Matrix") -> plt.Figure:
|
| 51 |
+
"""
|
| 52 |
+
Plot confusion matrix with proper formatting.
|
| 53 |
+
Enhanced version from src/visualize.py
|
| 54 |
+
"""
|
| 55 |
+
fig, ax = plt.subplots(figsize=(8, 6))
|
| 56 |
+
|
| 57 |
+
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
|
| 58 |
+
xticklabels=class_names, yticklabels=class_names, ax=ax)
|
| 59 |
+
|
| 60 |
+
ax.set_title(title)
|
| 61 |
+
ax.set_ylabel('True Label')
|
| 62 |
+
ax.set_xlabel('Predicted Label')
|
| 63 |
+
|
| 64 |
+
plt.tight_layout()
|
| 65 |
+
return fig
|
| 66 |
+
|
| 67 |
+
def plot_classification_probabilities(probabilities: np.ndarray, class_names: List[str],
|
| 68 |
+
sample_indices: Optional[List[int]] = None) -> plt.Figure:
|
| 69 |
+
"""
|
| 70 |
+
Plot classification probabilities for selected samples.
|
| 71 |
+
"""
|
| 72 |
+
if sample_indices is None:
|
| 73 |
+
sample_indices = list(range(min(10, len(probabilities))))
|
| 74 |
+
|
| 75 |
+
fig, ax = plt.subplots(figsize=(12, 6))
|
| 76 |
+
|
| 77 |
+
x = np.arange(len(class_names))
|
| 78 |
+
width = 0.8 / len(sample_indices)
|
| 79 |
+
|
| 80 |
+
for i, sample_idx in enumerate(sample_indices):
|
| 81 |
+
offset = (i - len(sample_indices)/2) * width
|
| 82 |
+
ax.bar(x + offset, probabilities[sample_idx], width,
|
| 83 |
+
label=f'Sample {sample_idx}', alpha=0.8)
|
| 84 |
+
|
| 85 |
+
ax.set_xlabel('Motor Imagery Classes')
|
| 86 |
+
ax.set_ylabel('Probability')
|
| 87 |
+
ax.set_title('Classification Probabilities')
|
| 88 |
+
ax.set_xticks(x)
|
| 89 |
+
ax.set_xticklabels(class_names, rotation=45)
|
| 90 |
+
ax.legend()
|
| 91 |
+
ax.grid(True, alpha=0.3)
|
| 92 |
+
|
| 93 |
+
plt.tight_layout()
|
| 94 |
+
return fig
|
| 95 |
+
|
| 96 |
+
def plot_training_history(history: Dict[str, List[float]]) -> plt.Figure:
|
| 97 |
+
"""
|
| 98 |
+
Plot training history (loss and accuracy).
|
| 99 |
+
Enhanced version from src/visualize.py
|
| 100 |
+
"""
|
| 101 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
|
| 102 |
+
|
| 103 |
+
# Plot accuracy
|
| 104 |
+
if 'train_accuracy' in history and 'val_accuracy' in history:
|
| 105 |
+
ax1.plot(history['train_accuracy'], label='Train', linewidth=2)
|
| 106 |
+
ax1.plot(history['val_accuracy'], label='Validation', linewidth=2)
|
| 107 |
+
ax1.set_title('Model Accuracy')
|
| 108 |
+
ax1.set_xlabel('Epoch')
|
| 109 |
+
ax1.set_ylabel('Accuracy')
|
| 110 |
+
ax1.legend()
|
| 111 |
+
ax1.grid(True, alpha=0.3)
|
| 112 |
+
|
| 113 |
+
# Plot loss
|
| 114 |
+
if 'train_loss' in history and 'val_loss' in history:
|
| 115 |
+
ax2.plot(history['train_loss'], label='Train', linewidth=2)
|
| 116 |
+
ax2.plot(history['val_loss'], label='Validation', linewidth=2)
|
| 117 |
+
ax2.set_title('Model Loss')
|
| 118 |
+
ax2.set_xlabel('Epoch')
|
| 119 |
+
ax2.set_ylabel('Loss')
|
| 120 |
+
ax2.legend()
|
| 121 |
+
ax2.grid(True, alpha=0.3)
|
| 122 |
+
|
| 123 |
+
plt.tight_layout()
|
| 124 |
+
return fig
|
| 125 |
+
|
| 126 |
+
def plot_eeg_channels(eeg_data: np.ndarray, channel_names: Optional[List[str]] = None,
|
| 127 |
+
sample_rate: int = 256, title: str = "EEG Channels") -> plt.Figure:
|
| 128 |
+
"""
|
| 129 |
+
Plot multiple EEG channels.
|
| 130 |
+
Enhanced visualization for EEG data.
|
| 131 |
+
"""
|
| 132 |
+
n_channels, n_samples = eeg_data.shape
|
| 133 |
+
time_axis = np.arange(n_samples) / sample_rate
|
| 134 |
+
|
| 135 |
+
# Determine subplot layout
|
| 136 |
+
n_rows = int(np.ceil(np.sqrt(n_channels)))
|
| 137 |
+
n_cols = int(np.ceil(n_channels / n_rows))
|
| 138 |
+
|
| 139 |
+
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 10))
|
| 140 |
+
if n_channels == 1:
|
| 141 |
+
axes = [axes]
|
| 142 |
+
else:
|
| 143 |
+
axes = axes.flatten()
|
| 144 |
+
|
| 145 |
+
for i in range(n_channels):
|
| 146 |
+
ax = axes[i]
|
| 147 |
+
ax.plot(time_axis, eeg_data[i], 'b-', linewidth=1)
|
| 148 |
+
|
| 149 |
+
channel_name = channel_names[i] if channel_names else f'Channel {i+1}'
|
| 150 |
+
ax.set_title(channel_name)
|
| 151 |
+
ax.set_xlabel('Time (s)')
|
| 152 |
+
ax.set_ylabel('Amplitude')
|
| 153 |
+
ax.grid(True, alpha=0.3)
|
| 154 |
+
|
| 155 |
+
# Hide unused subplots
|
| 156 |
+
for i in range(n_channels, len(axes)):
|
| 157 |
+
axes[i].set_visible(False)
|
| 158 |
+
|
| 159 |
+
plt.suptitle(title)
|
| 160 |
+
plt.tight_layout()
|
| 161 |
+
return fig
|
| 162 |
+
|
| 163 |
+
class EarlyStopping:
|
| 164 |
+
"""
|
| 165 |
+
Early stopping utility from src/types/index.py
|
| 166 |
+
"""
|
| 167 |
+
def __init__(self, patience=7, min_delta=0, restore_best_weights=True, verbose=False):
|
| 168 |
+
self.patience = patience
|
| 169 |
+
self.min_delta = min_delta
|
| 170 |
+
self.restore_best_weights = restore_best_weights
|
| 171 |
+
self.verbose = verbose
|
| 172 |
+
self.best_loss = None
|
| 173 |
+
self.counter = 0
|
| 174 |
+
self.best_weights = None
|
| 175 |
+
|
| 176 |
+
def __call__(self, val_loss, model):
|
| 177 |
+
if self.best_loss is None:
|
| 178 |
+
self.best_loss = val_loss
|
| 179 |
+
self.save_checkpoint(model)
|
| 180 |
+
elif val_loss < self.best_loss - self.min_delta:
|
| 181 |
+
self.best_loss = val_loss
|
| 182 |
+
self.counter = 0
|
| 183 |
+
self.save_checkpoint(model)
|
| 184 |
+
else:
|
| 185 |
+
self.counter += 1
|
| 186 |
+
|
| 187 |
+
if self.counter >= self.patience:
|
| 188 |
+
if self.verbose:
|
| 189 |
+
print(f'Early stopping triggered after {self.counter} epochs of no improvement')
|
| 190 |
+
if self.restore_best_weights:
|
| 191 |
+
model.load_state_dict(self.best_weights)
|
| 192 |
+
return True
|
| 193 |
+
return False
|
| 194 |
+
|
| 195 |
+
def save_checkpoint(self, model):
|
| 196 |
+
"""Save model when validation loss decreases."""
|
| 197 |
+
if self.restore_best_weights:
|
| 198 |
+
self.best_weights = model.state_dict().copy()
|
| 199 |
+
|
| 200 |
+
def create_enhanced_evaluation_report(model, test_loader, class_names: List[str],
|
| 201 |
+
device, save_plots: bool = True) -> Dict:
|
| 202 |
+
"""
|
| 203 |
+
Create a comprehensive evaluation report with plots and metrics.
|
| 204 |
+
"""
|
| 205 |
+
# Get evaluation results
|
| 206 |
+
results = evaluate_model_performance(model, test_loader, device, class_names)
|
| 207 |
+
|
| 208 |
+
# Create visualizations
|
| 209 |
+
plots = {}
|
| 210 |
+
|
| 211 |
+
# Confusion Matrix
|
| 212 |
+
plots['confusion_matrix'] = plot_confusion_matrix(
|
| 213 |
+
results['confusion_matrix'], class_names,
|
| 214 |
+
title="Motor Imagery Classification - Confusion Matrix"
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
# Classification Probabilities (sample)
|
| 218 |
+
plots['probabilities'] = plot_classification_probabilities(
|
| 219 |
+
np.array(results['probabilities']), class_names,
|
| 220 |
+
sample_indices=list(range(min(5, len(results['probabilities']))))
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
if save_plots:
|
| 224 |
+
for plot_name, fig in plots.items():
|
| 225 |
+
fig.savefig(f'{plot_name}.png', dpi=300, bbox_inches='tight')
|
| 226 |
+
|
| 227 |
+
return {
|
| 228 |
+
'metrics': results,
|
| 229 |
+
'plots': plots,
|
| 230 |
+
'summary': {
|
| 231 |
+
'accuracy': results['accuracy'],
|
| 232 |
+
'n_samples': len(results['labels']),
|
| 233 |
+
'n_classes': len(class_names),
|
| 234 |
+
'class_names': class_names
|
| 235 |
+
}
|
| 236 |
+
}
|
requirements.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
pandas
|
| 3 |
+
numpy
|
| 4 |
+
scipy
|
| 5 |
+
matplotlib
|
| 6 |
+
torch
|
| 7 |
+
torchvision
|
| 8 |
+
mne
|
| 9 |
+
braindecode
|
| 10 |
+
scikit-learn
|
| 11 |
+
pydub
|
| 12 |
+
soundfile
|
| 13 |
+
librosa
|
| 14 |
+
threading
|
sound_library.py
ADDED
|
@@ -0,0 +1,783 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sound Management System for EEG Motor Imagery Classification
|
| 3 |
+
-----------------------------------------------------------
|
| 4 |
+
Handles sound mapping, layering, and music composition based on motor imagery predictions.
|
| 5 |
+
Uses local SoundHelix audio files for different motor imagery classes.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import soundfile as sf
|
| 10 |
+
import os
|
| 11 |
+
from typing import Dict, List, Optional, Tuple
|
| 12 |
+
import threading
|
| 13 |
+
import time
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
import tempfile
|
| 16 |
+
|
| 17 |
+
# Professional audio effects processor using scipy and librosa
|
| 18 |
+
from scipy import signal
|
| 19 |
+
import librosa
|
| 20 |
+
|
| 21 |
+
class AudioEffectsProcessor:
|
| 22 |
+
"""Professional audio effects for DJ mode using scipy and librosa."""
|
| 23 |
+
|
| 24 |
+
@staticmethod
|
| 25 |
+
def apply_volume_fade(data: np.ndarray, fade_type: str = "out", fade_length: float = 0.5) -> np.ndarray:
|
| 26 |
+
"""Apply volume fade effect with linear fade in/out."""
|
| 27 |
+
try:
|
| 28 |
+
samples = len(data)
|
| 29 |
+
fade_samples = int(fade_length * samples)
|
| 30 |
+
|
| 31 |
+
if fade_type == "out":
|
| 32 |
+
# Fade out: linear decrease from 1.0 to 0.3
|
| 33 |
+
fade_curve = np.linspace(1.0, 0.3, fade_samples)
|
| 34 |
+
data[-fade_samples:] *= fade_curve
|
| 35 |
+
elif fade_type == "in":
|
| 36 |
+
# Fade in: linear increase from 0.3 to 1.0
|
| 37 |
+
fade_curve = np.linspace(0.3, 1.0, fade_samples)
|
| 38 |
+
data[:fade_samples] *= fade_curve
|
| 39 |
+
|
| 40 |
+
return data
|
| 41 |
+
except Exception as e:
|
| 42 |
+
print(f"Volume fade effect failed: {e}")
|
| 43 |
+
return data
|
| 44 |
+
|
| 45 |
+
@staticmethod
|
| 46 |
+
def apply_high_pass_filter(data: np.ndarray, samplerate: int, cutoff: float = 800.0) -> np.ndarray:
|
| 47 |
+
"""Apply high-pass filter to emphasize highs and cut lows."""
|
| 48 |
+
try:
|
| 49 |
+
# Design butterworth high-pass filter
|
| 50 |
+
nyquist = samplerate / 2
|
| 51 |
+
normalized_cutoff = cutoff / nyquist
|
| 52 |
+
b, a = signal.butter(4, normalized_cutoff, btype='high', analog=False)
|
| 53 |
+
|
| 54 |
+
# Apply filter
|
| 55 |
+
filtered_data = signal.filtfilt(b, a, data)
|
| 56 |
+
return filtered_data
|
| 57 |
+
except Exception as e:
|
| 58 |
+
print(f"High-pass filter failed: {e}")
|
| 59 |
+
return data
|
| 60 |
+
|
| 61 |
+
@staticmethod
|
| 62 |
+
def apply_low_pass_filter(data: np.ndarray, samplerate: int, cutoff: float = 1200.0) -> np.ndarray:
|
| 63 |
+
"""Apply low-pass filter to emphasize lows and cut highs."""
|
| 64 |
+
try:
|
| 65 |
+
# Design butterworth low-pass filter
|
| 66 |
+
nyquist = samplerate / 2
|
| 67 |
+
normalized_cutoff = cutoff / nyquist
|
| 68 |
+
b, a = signal.butter(4, normalized_cutoff, btype='low', analog=False)
|
| 69 |
+
|
| 70 |
+
# Apply filter
|
| 71 |
+
filtered_data = signal.filtfilt(b, a, data)
|
| 72 |
+
return filtered_data
|
| 73 |
+
except Exception as e:
|
| 74 |
+
print(f"Low-pass filter failed: {e}")
|
| 75 |
+
return data
|
| 76 |
+
|
| 77 |
+
@staticmethod
|
| 78 |
+
def apply_reverb(data: np.ndarray, samplerate: int, room_size: float = 0.5) -> np.ndarray:
|
| 79 |
+
"""Apply simple reverb effect using delay and feedback."""
|
| 80 |
+
try:
|
| 81 |
+
# Simple reverb using multiple delayed copies
|
| 82 |
+
delay_samples = int(0.1 * samplerate) # 100ms delay
|
| 83 |
+
decay = 0.3 * room_size
|
| 84 |
+
|
| 85 |
+
# Create reverb buffer
|
| 86 |
+
reverb_data = np.copy(data)
|
| 87 |
+
|
| 88 |
+
# Add delayed copies with decay
|
| 89 |
+
for i in range(3):
|
| 90 |
+
delay = delay_samples * (i + 1)
|
| 91 |
+
if delay < len(data):
|
| 92 |
+
gain = decay ** (i + 1)
|
| 93 |
+
reverb_data[delay:] += data[:-delay] * gain
|
| 94 |
+
|
| 95 |
+
# Mix original with reverb
|
| 96 |
+
return 0.7 * data + 0.3 * reverb_data
|
| 97 |
+
except Exception as e:
|
| 98 |
+
print(f"Reverb effect failed: {e}")
|
| 99 |
+
return data
|
| 100 |
+
|
| 101 |
+
@staticmethod
|
| 102 |
+
def apply_bass_boost(data: np.ndarray, samplerate: int, boost_db: float = 6.0) -> np.ndarray:
|
| 103 |
+
"""Apply bass boost using low-frequency shelving filter."""
|
| 104 |
+
try:
|
| 105 |
+
# Design low-shelf filter for bass boost
|
| 106 |
+
freq = 250.0 # Bass frequency cutoff
|
| 107 |
+
nyquist = samplerate / 2
|
| 108 |
+
normalized_freq = freq / nyquist
|
| 109 |
+
|
| 110 |
+
# Convert dB to linear gain
|
| 111 |
+
gain = 10 ** (boost_db / 20)
|
| 112 |
+
|
| 113 |
+
# Simple bass boost: amplify low frequencies
|
| 114 |
+
b, a = signal.butter(2, normalized_freq, btype='low', analog=False)
|
| 115 |
+
low_freq = signal.filtfilt(b, a, data)
|
| 116 |
+
|
| 117 |
+
# Mix boosted lows with original
|
| 118 |
+
boosted_data = data + (low_freq * (gain - 1))
|
| 119 |
+
|
| 120 |
+
# Normalize to prevent clipping
|
| 121 |
+
max_val = np.max(np.abs(boosted_data))
|
| 122 |
+
if max_val > 0.95:
|
| 123 |
+
boosted_data = boosted_data * 0.95 / max_val
|
| 124 |
+
|
| 125 |
+
return boosted_data
|
| 126 |
+
except Exception as e:
|
| 127 |
+
print(f"Bass boost failed: {e}")
|
| 128 |
+
return data
|
| 129 |
+
|
| 130 |
+
@staticmethod
|
| 131 |
+
def process_with_effects(audio_file: str, active_effects: Dict[str, bool]) -> str:
|
| 132 |
+
"""Process audio file with active effects and return processed version."""
|
| 133 |
+
try:
|
| 134 |
+
# Check if audio_file is valid
|
| 135 |
+
if not audio_file or not os.path.exists(audio_file):
|
| 136 |
+
print(f"Invalid audio file: {audio_file}")
|
| 137 |
+
return audio_file
|
| 138 |
+
|
| 139 |
+
# Read audio file
|
| 140 |
+
data, samplerate = sf.read(audio_file)
|
| 141 |
+
|
| 142 |
+
# Handle stereo to mono conversion if needed
|
| 143 |
+
if len(data.shape) > 1:
|
| 144 |
+
data = np.mean(data, axis=1)
|
| 145 |
+
|
| 146 |
+
# Apply effects based on active states
|
| 147 |
+
processed_data = np.copy(data)
|
| 148 |
+
effect_names = []
|
| 149 |
+
|
| 150 |
+
if active_effects.get("left_hand", False): # Volume Fade
|
| 151 |
+
processed_data = AudioEffectsProcessor.apply_volume_fade(processed_data, "out")
|
| 152 |
+
effect_names.append("fade")
|
| 153 |
+
|
| 154 |
+
if active_effects.get("right_hand", False): # High Pass Filter
|
| 155 |
+
processed_data = AudioEffectsProcessor.apply_high_pass_filter(processed_data, samplerate)
|
| 156 |
+
effect_names.append("hpf")
|
| 157 |
+
|
| 158 |
+
if active_effects.get("left_leg", False): # Reverb
|
| 159 |
+
processed_data = AudioEffectsProcessor.apply_reverb(processed_data, samplerate)
|
| 160 |
+
effect_names.append("rev")
|
| 161 |
+
|
| 162 |
+
if active_effects.get("right_leg", False): # Low Pass Filter
|
| 163 |
+
processed_data = AudioEffectsProcessor.apply_low_pass_filter(processed_data, samplerate)
|
| 164 |
+
effect_names.append("lpf")
|
| 165 |
+
|
| 166 |
+
if active_effects.get("tongue", False): # Bass Boost
|
| 167 |
+
processed_data = AudioEffectsProcessor.apply_bass_boost(processed_data, samplerate)
|
| 168 |
+
effect_names.append("bass")
|
| 169 |
+
|
| 170 |
+
# Create unique filename based on active effects
|
| 171 |
+
base_name = os.path.splitext(audio_file)[0]
|
| 172 |
+
effects_suffix = "_".join(effect_names) if effect_names else "clean"
|
| 173 |
+
processed_file = f"{base_name}_fx_{effects_suffix}.wav"
|
| 174 |
+
|
| 175 |
+
# Save processed audio
|
| 176 |
+
# sf.write(processed_file, processed_data, samplerate)
|
| 177 |
+
print(f"ποΈ Audio processed with effects: {effects_suffix} (FILE SAVING DISABLED)")
|
| 178 |
+
|
| 179 |
+
# Return absolute path (using original file since saving is disabled)
|
| 180 |
+
return os.path.abspath(audio_file)
|
| 181 |
+
|
| 182 |
+
except Exception as e:
|
| 183 |
+
print(f"Audio processing failed: {e}")
|
| 184 |
+
return os.path.abspath(audio_file) if audio_file else None
|
| 185 |
+
|
| 186 |
+
class SoundManager:
|
| 187 |
+
"""
|
| 188 |
+
Manages cyclic sound composition for motor imagery classification.
|
| 189 |
+
Supports full-cycle composition with user-customizable movement-sound mappings.
|
| 190 |
+
"""
|
| 191 |
+
|
| 192 |
+
def __init__(self, sound_dir: str = "sounds", include_neutral_in_cycle: bool = False):
|
| 193 |
+
self.sound_dir = Path(sound_dir)
|
| 194 |
+
self.include_neutral_in_cycle = include_neutral_in_cycle
|
| 195 |
+
|
| 196 |
+
# Composition state
|
| 197 |
+
self.composition_layers = [] # All layers across all cycles
|
| 198 |
+
self.current_cycle = 0
|
| 199 |
+
self.current_step = 0 # Current step within cycle (0-5)
|
| 200 |
+
self.cycle_complete = False
|
| 201 |
+
self.completed_cycles = 0 # Track completed cycles for session management
|
| 202 |
+
self.max_cycles = 2 # Rehabilitation session limit
|
| 203 |
+
|
| 204 |
+
# DJ Effects phase management
|
| 205 |
+
self.current_phase = "building" # "building" or "dj_effects"
|
| 206 |
+
self.mixed_composition_file = None # Path to current mixed composition
|
| 207 |
+
self.active_effects = { # Track which effects are currently active
|
| 208 |
+
"left_hand": False, # Volume fade
|
| 209 |
+
"right_hand": False, # Filter sweep
|
| 210 |
+
"left_leg": False, # Reverb
|
| 211 |
+
"right_leg": False, # Tempo modulation
|
| 212 |
+
"tongue": False # Bass boost
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
# All possible movements (neutral is optional for composition)
|
| 216 |
+
self.all_movements = ["left_hand", "right_hand", "neutral", "left_leg", "tongue", "right_leg"]
|
| 217 |
+
|
| 218 |
+
# Active movements that contribute to composition (excluding neutral)
|
| 219 |
+
self.active_movements = ["left_hand", "right_hand", "left_leg", "tongue", "right_leg"]
|
| 220 |
+
|
| 221 |
+
# Current cycle's random movement sequence (shuffled each cycle)
|
| 222 |
+
self.current_movement_sequence = []
|
| 223 |
+
self.movements_completed = set() # Track which movements have been successfully completed
|
| 224 |
+
self._generate_new_sequence()
|
| 225 |
+
|
| 226 |
+
# User-customizable sound mapping (can be updated after each cycle)
|
| 227 |
+
self.current_sound_mapping = {
|
| 228 |
+
"left_hand": "1_SoundHelix-Song-6_(Bass).wav",
|
| 229 |
+
"right_hand": "1_SoundHelix-Song-6_(Drums).wav",
|
| 230 |
+
"neutral": None, # No sound for neutral/rest state
|
| 231 |
+
"left_leg": "1_SoundHelix-Song-6_(Other).wav",
|
| 232 |
+
"tongue": "1_SoundHelix-Song-6_(Vocals).wav",
|
| 233 |
+
"right_leg": "1_SoundHelix-Song-6_(Bass).wav" # Can be remapped
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
# Available sound files
|
| 237 |
+
self.available_sounds = [
|
| 238 |
+
"1_SoundHelix-Song-6_(Bass).wav",
|
| 239 |
+
"1_SoundHelix-Song-6_(Drums).wav",
|
| 240 |
+
"1_SoundHelix-Song-6_(Other).wav",
|
| 241 |
+
"1_SoundHelix-Song-6_(Vocals).wav"
|
| 242 |
+
]
|
| 243 |
+
|
| 244 |
+
# Load sound files
|
| 245 |
+
self.loaded_sounds = {}
|
| 246 |
+
self._load_sound_files()
|
| 247 |
+
|
| 248 |
+
# Cycle statistics
|
| 249 |
+
self.cycle_stats = {
|
| 250 |
+
'total_cycles': 0,
|
| 251 |
+
'successful_classifications': 0,
|
| 252 |
+
'total_attempts': 0
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
def _load_sound_files(self):
|
| 256 |
+
"""Load all available sound files into memory."""
|
| 257 |
+
for class_name, filename in self.current_sound_mapping.items():
|
| 258 |
+
if filename is None:
|
| 259 |
+
continue
|
| 260 |
+
|
| 261 |
+
file_path = self.sound_dir / filename
|
| 262 |
+
if file_path.exists():
|
| 263 |
+
try:
|
| 264 |
+
data, sample_rate = sf.read(str(file_path))
|
| 265 |
+
self.loaded_sounds[class_name] = {
|
| 266 |
+
'data': data,
|
| 267 |
+
'sample_rate': sample_rate,
|
| 268 |
+
'file_path': str(file_path)
|
| 269 |
+
}
|
| 270 |
+
print(f"Loaded sound for {class_name}: {filename}")
|
| 271 |
+
except Exception as e:
|
| 272 |
+
print(f"Error loading {filename}: {e}")
|
| 273 |
+
else:
|
| 274 |
+
print(f"Sound file not found: {file_path}")
|
| 275 |
+
|
| 276 |
+
def get_sound_for_class(self, class_name: str) -> Optional[Dict]:
|
| 277 |
+
"""Get sound data for a specific motor imagery class."""
|
| 278 |
+
return self.loaded_sounds.get(class_name)
|
| 279 |
+
|
| 280 |
+
def _generate_new_sequence(self):
|
| 281 |
+
"""Generate a new random sequence of movements for the current cycle."""
|
| 282 |
+
import random
|
| 283 |
+
# Choose which movements to include based on configuration
|
| 284 |
+
movements_for_cycle = self.all_movements.copy() if self.include_neutral_in_cycle else self.active_movements.copy()
|
| 285 |
+
random.shuffle(movements_for_cycle)
|
| 286 |
+
self.current_movement_sequence = movements_for_cycle
|
| 287 |
+
self.movements_completed = set()
|
| 288 |
+
|
| 289 |
+
cycle_size = len(movements_for_cycle)
|
| 290 |
+
print(f"π― New random sequence ({cycle_size} movements): {' β '.join([m.replace('_', ' ').title() for m in self.current_movement_sequence])}")
|
| 291 |
+
|
| 292 |
+
def get_current_target_movement(self) -> str:
|
| 293 |
+
"""Get the current movement the user should imagine."""
|
| 294 |
+
if self.current_step < len(self.current_movement_sequence):
|
| 295 |
+
return self.current_movement_sequence[self.current_step]
|
| 296 |
+
return "cycle_complete"
|
| 297 |
+
|
| 298 |
+
def get_next_random_movement(self) -> str:
|
| 299 |
+
"""Get a random movement from those not yet completed in this cycle."""
|
| 300 |
+
# Use the same movement set as the current cycle
|
| 301 |
+
cycle_movements = self.all_movements if self.include_neutral_in_cycle else self.active_movements
|
| 302 |
+
remaining_movements = [m for m in cycle_movements if m not in self.movements_completed]
|
| 303 |
+
if not remaining_movements:
|
| 304 |
+
return "cycle_complete"
|
| 305 |
+
|
| 306 |
+
import random
|
| 307 |
+
return random.choice(remaining_movements)
|
| 308 |
+
|
| 309 |
+
def process_classification(self, predicted_class: str, confidence: float, threshold: float = 0.7) -> Dict:
|
| 310 |
+
"""
|
| 311 |
+
Process a classification result in the context of the current cycle.
|
| 312 |
+
Uses random movement prompting - user can choose any movement at any time.
|
| 313 |
+
|
| 314 |
+
Args:
|
| 315 |
+
predicted_class: The predicted motor imagery class
|
| 316 |
+
confidence: Confidence score
|
| 317 |
+
threshold: Minimum confidence to add sound layer
|
| 318 |
+
|
| 319 |
+
Returns:
|
| 320 |
+
Dictionary with processing results and next action
|
| 321 |
+
"""
|
| 322 |
+
self.cycle_stats['total_attempts'] += 1
|
| 323 |
+
|
| 324 |
+
# Check if this movement was already completed in this cycle
|
| 325 |
+
already_completed = predicted_class in self.movements_completed
|
| 326 |
+
|
| 327 |
+
result = {
|
| 328 |
+
'predicted_class': predicted_class,
|
| 329 |
+
'confidence': confidence,
|
| 330 |
+
'above_threshold': confidence >= threshold,
|
| 331 |
+
'already_completed': already_completed,
|
| 332 |
+
'sound_added': False,
|
| 333 |
+
'cycle_complete': False,
|
| 334 |
+
'next_action': 'continue',
|
| 335 |
+
'movements_remaining': len(self.current_movement_sequence) - len(self.movements_completed),
|
| 336 |
+
'movements_completed': list(self.movements_completed)
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
# Check if prediction is above threshold and not already completed
|
| 340 |
+
if result['above_threshold'] and not already_completed:
|
| 341 |
+
# Get sound file for this movement
|
| 342 |
+
sound_file = self.current_sound_mapping.get(predicted_class)
|
| 343 |
+
|
| 344 |
+
if sound_file is None:
|
| 345 |
+
# No sound for this movement (e.g., neutral), but still count as completed
|
| 346 |
+
self.movements_completed.add(predicted_class)
|
| 347 |
+
result['sound_added'] = False # No sound, but movement completed
|
| 348 |
+
self.cycle_stats['successful_classifications'] += 1
|
| 349 |
+
cycle_size = len(self.current_movement_sequence)
|
| 350 |
+
print(f"β Cycle {self.current_cycle+1}: {predicted_class} completed (no sound) ({len(self.movements_completed)}/{cycle_size} complete)")
|
| 351 |
+
elif predicted_class in self.loaded_sounds:
|
| 352 |
+
# Add sound layer
|
| 353 |
+
layer_info = {
|
| 354 |
+
'cycle': self.current_cycle,
|
| 355 |
+
'step': len(self.movements_completed), # Step based on number completed
|
| 356 |
+
'movement': predicted_class,
|
| 357 |
+
'sound_file': sound_file,
|
| 358 |
+
'confidence': confidence,
|
| 359 |
+
'timestamp': time.time(),
|
| 360 |
+
'sound_data': self.loaded_sounds[predicted_class]
|
| 361 |
+
}
|
| 362 |
+
self.composition_layers.append(layer_info)
|
| 363 |
+
self.movements_completed.add(predicted_class)
|
| 364 |
+
result['sound_added'] = True
|
| 365 |
+
|
| 366 |
+
print(f"DEBUG: Added layer {len(self.composition_layers)}, total layers now: {len(self.composition_layers)}")
|
| 367 |
+
print(f"DEBUG: Composition layers: {[layer['movement'] for layer in self.composition_layers]}")
|
| 368 |
+
|
| 369 |
+
# Return individual sound file for this classification
|
| 370 |
+
sound_path = os.path.join(self.sound_dir, sound_file)
|
| 371 |
+
result['audio_file'] = sound_path if os.path.exists(sound_path) else None
|
| 372 |
+
|
| 373 |
+
# Also create mixed composition for potential saving (but don't return it)
|
| 374 |
+
mixed_file = self.get_current_mixed_composition()
|
| 375 |
+
result['mixed_composition'] = mixed_file
|
| 376 |
+
|
| 377 |
+
self.cycle_stats['successful_classifications'] += 1
|
| 378 |
+
cycle_size = len(self.current_movement_sequence)
|
| 379 |
+
print(f"β Cycle {self.current_cycle+1}: Added {sound_file} for {predicted_class} ({len(self.movements_completed)}/{cycle_size} complete)")
|
| 380 |
+
|
| 381 |
+
# Check if cycle is complete (all movements in current sequence completed)
|
| 382 |
+
cycle_movements = self.all_movements if self.include_neutral_in_cycle else self.active_movements
|
| 383 |
+
if len(self.movements_completed) >= len(cycle_movements):
|
| 384 |
+
result['cycle_complete'] = True
|
| 385 |
+
result['next_action'] = 'remap_sounds'
|
| 386 |
+
self.cycle_complete = True
|
| 387 |
+
cycle_size = len(cycle_movements)
|
| 388 |
+
print(f"π΅ Cycle {self.current_cycle+1} complete! All {cycle_size} movements successfully classified!")
|
| 389 |
+
|
| 390 |
+
return result
|
| 391 |
+
|
| 392 |
+
def start_new_cycle(self, new_sound_mapping: Dict[str, str] = None):
|
| 393 |
+
"""Start a new composition cycle with optional new sound mapping."""
|
| 394 |
+
if new_sound_mapping:
|
| 395 |
+
self.current_sound_mapping.update(new_sound_mapping)
|
| 396 |
+
print(f"Updated sound mapping for cycle {self.current_cycle+2}")
|
| 397 |
+
|
| 398 |
+
self.current_cycle += 1
|
| 399 |
+
self.current_step = 0
|
| 400 |
+
self.cycle_complete = False
|
| 401 |
+
self.cycle_stats['total_cycles'] += 1
|
| 402 |
+
|
| 403 |
+
# Generate new random sequence for this cycle
|
| 404 |
+
self._generate_new_sequence()
|
| 405 |
+
|
| 406 |
+
print(f"π Starting Cycle {self.current_cycle+1}")
|
| 407 |
+
if self.current_cycle == 1:
|
| 408 |
+
print("πͺ Let's create your first brain-music composition!")
|
| 409 |
+
elif self.current_cycle == 2:
|
| 410 |
+
print("οΏ½ Great progress! Let's create your second composition!")
|
| 411 |
+
else:
|
| 412 |
+
print("οΏ½π― Imagine ANY movement - you can choose the order!")
|
| 413 |
+
|
| 414 |
+
def should_end_session(self) -> bool:
|
| 415 |
+
"""Check if the rehabilitation session should end after max cycles."""
|
| 416 |
+
return self.completed_cycles >= self.max_cycles
|
| 417 |
+
|
| 418 |
+
def complete_current_cycle(self):
|
| 419 |
+
"""Mark current cycle as complete and track progress."""
|
| 420 |
+
self.completed_cycles += 1
|
| 421 |
+
print(f"β
Cycle {self.current_cycle} completed! ({self.completed_cycles}/{self.max_cycles} compositions finished)")
|
| 422 |
+
|
| 423 |
+
def transition_to_dj_phase(self):
|
| 424 |
+
"""Transition from building phase to DJ effects phase."""
|
| 425 |
+
if len(self.movements_completed) >= 5: # All movements completed
|
| 426 |
+
self.current_phase = "dj_effects"
|
| 427 |
+
# Create mixed composition for DJ effects
|
| 428 |
+
self._create_mixed_composition()
|
| 429 |
+
print("π΅ Composition Complete! Transitioning to DJ Effects Phase...")
|
| 430 |
+
print("π§ You are now the DJ! Use movements to control effects:")
|
| 431 |
+
print(" π Left Hand: Volume Fade")
|
| 432 |
+
print(" π Right Hand: High Pass Filter")
|
| 433 |
+
print(" 𦡠Left Leg: Reverb Effect")
|
| 434 |
+
print(" 𦡠Right Leg: Low Pass Filter")
|
| 435 |
+
print(" π
Tongue: Bass Boost")
|
| 436 |
+
return True
|
| 437 |
+
return False
|
| 438 |
+
|
| 439 |
+
def _create_mixed_composition(self):
|
| 440 |
+
"""Create a mixed audio file from all completed layers."""
|
| 441 |
+
try:
|
| 442 |
+
import hashlib
|
| 443 |
+
movement_hash = hashlib.md5(str(sorted(self.movements_completed)).encode()).hexdigest()[:8]
|
| 444 |
+
self.mixed_composition_file = os.path.abspath(f"mixed_composition_{movement_hash}.wav")
|
| 445 |
+
|
| 446 |
+
# FILE SAVING DISABLED: Use existing base audio file instead
|
| 447 |
+
# Try to use the first available completed movement's audio file
|
| 448 |
+
for movement in self.movements_completed:
|
| 449 |
+
if movement in self.current_sound_mapping and self.current_sound_mapping[movement] is not None:
|
| 450 |
+
sound_file = os.path.join(self.sound_dir, self.current_sound_mapping[movement])
|
| 451 |
+
if os.path.exists(sound_file):
|
| 452 |
+
self.mixed_composition_file = os.path.abspath(sound_file)
|
| 453 |
+
print(f"π Using existing audio as mixed composition: {self.mixed_composition_file} (FILE SAVING DISABLED)")
|
| 454 |
+
return
|
| 455 |
+
|
| 456 |
+
# If file already exists, use it
|
| 457 |
+
if os.path.exists(self.mixed_composition_file):
|
| 458 |
+
print(f"π Using existing mixed composition: {self.mixed_composition_file}")
|
| 459 |
+
return
|
| 460 |
+
|
| 461 |
+
# Create actual mixed composition by layering completed sounds
|
| 462 |
+
mixed_data = None
|
| 463 |
+
sample_rate = 44100 # Default sample rate
|
| 464 |
+
|
| 465 |
+
for movement in self.movements_completed:
|
| 466 |
+
if movement in self.current_sound_mapping and self.current_sound_mapping[movement] is not None:
|
| 467 |
+
sound_file = os.path.join(self.sound_dir, self.current_sound_mapping[movement])
|
| 468 |
+
if os.path.exists(sound_file):
|
| 469 |
+
try:
|
| 470 |
+
data, sr = sf.read(sound_file)
|
| 471 |
+
sample_rate = sr
|
| 472 |
+
|
| 473 |
+
# Convert stereo to mono
|
| 474 |
+
if len(data.shape) > 1:
|
| 475 |
+
data = np.mean(data, axis=1)
|
| 476 |
+
|
| 477 |
+
# Initialize or add to mixed data
|
| 478 |
+
if mixed_data is None:
|
| 479 |
+
mixed_data = data * 0.8 # Reduce volume to prevent clipping
|
| 480 |
+
else:
|
| 481 |
+
# Ensure same length by padding shorter audio
|
| 482 |
+
if len(data) > len(mixed_data):
|
| 483 |
+
mixed_data = np.pad(mixed_data, (0, len(data) - len(mixed_data)))
|
| 484 |
+
elif len(mixed_data) > len(data):
|
| 485 |
+
data = np.pad(data, (0, len(mixed_data) - len(data)))
|
| 486 |
+
|
| 487 |
+
# Mix the audio (layer them)
|
| 488 |
+
mixed_data += data * 0.8
|
| 489 |
+
except Exception as e:
|
| 490 |
+
print(f"Error mixing {sound_file}: {e}")
|
| 491 |
+
|
| 492 |
+
# Save mixed composition or create silent fallback
|
| 493 |
+
if mixed_data is not None:
|
| 494 |
+
# Normalize to prevent clipping
|
| 495 |
+
max_val = np.max(np.abs(mixed_data))
|
| 496 |
+
if max_val > 0.95:
|
| 497 |
+
mixed_data = mixed_data * 0.95 / max_val
|
| 498 |
+
|
| 499 |
+
# sf.write(self.mixed_composition_file, mixed_data, sample_rate)
|
| 500 |
+
print(f"π Mixed composition created: {self.mixed_composition_file} (FILE SAVING DISABLED)")
|
| 501 |
+
else:
|
| 502 |
+
# Create silent fallback file
|
| 503 |
+
silent_data = np.zeros(int(sample_rate * 2)) # 2 seconds of silence
|
| 504 |
+
# sf.write(self.mixed_composition_file, silent_data, sample_rate)
|
| 505 |
+
print(f"π Silent fallback composition created: {self.mixed_composition_file} (FILE SAVING DISABLED)")
|
| 506 |
+
|
| 507 |
+
except Exception as e:
|
| 508 |
+
print(f"Error creating mixed composition: {e}")
|
| 509 |
+
# Create minimal fallback file with actual content
|
| 510 |
+
self.mixed_composition_file = os.path.abspath("mixed_composition_fallback.wav")
|
| 511 |
+
try:
|
| 512 |
+
# Create a short silent audio file as fallback
|
| 513 |
+
sample_rate = 44100
|
| 514 |
+
silent_data = np.zeros(int(sample_rate * 2)) # 2 seconds of silence
|
| 515 |
+
# sf.write(self.mixed_composition_file, silent_data, sample_rate)
|
| 516 |
+
print(f"π Silent fallback composition created: {self.mixed_composition_file} (FILE SAVING DISABLED)")
|
| 517 |
+
except Exception as fallback_error:
|
| 518 |
+
print(f"Failed to create fallback file: {fallback_error}")
|
| 519 |
+
self.mixed_composition_file = None
|
| 520 |
+
|
| 521 |
+
def toggle_dj_effect(self, movement: str) -> dict:
|
| 522 |
+
"""Toggle a DJ effect for the given movement and process audio."""
|
| 523 |
+
if self.current_phase != "dj_effects":
|
| 524 |
+
return {"effect_applied": False, "message": "Not in DJ effects phase"}
|
| 525 |
+
|
| 526 |
+
if movement not in self.active_effects:
|
| 527 |
+
return {"effect_applied": False, "message": f"Unknown movement: {movement}"}
|
| 528 |
+
|
| 529 |
+
# Toggle the effect
|
| 530 |
+
self.active_effects[movement] = not self.active_effects[movement]
|
| 531 |
+
effect_status = "ON" if self.active_effects[movement] else "OFF"
|
| 532 |
+
|
| 533 |
+
effect_names = {
|
| 534 |
+
"left_hand": "Volume Fade",
|
| 535 |
+
"right_hand": "High Pass Filter",
|
| 536 |
+
"left_leg": "Reverb Effect",
|
| 537 |
+
"right_leg": "Low Pass Filter",
|
| 538 |
+
"tongue": "Bass Boost"
|
| 539 |
+
}
|
| 540 |
+
|
| 541 |
+
effect_name = effect_names.get(movement, movement)
|
| 542 |
+
print(f"ποΈ {effect_name}: {effect_status}")
|
| 543 |
+
|
| 544 |
+
# Process audio with current active effects
|
| 545 |
+
if self.mixed_composition_file and os.path.exists(self.mixed_composition_file):
|
| 546 |
+
processed_file = AudioEffectsProcessor.process_with_effects(
|
| 547 |
+
self.mixed_composition_file,
|
| 548 |
+
self.active_effects
|
| 549 |
+
)
|
| 550 |
+
else:
|
| 551 |
+
# If no mixed composition exists, create one from current sounds
|
| 552 |
+
self._create_mixed_composition()
|
| 553 |
+
# Only process if we successfully created a mixed composition
|
| 554 |
+
if self.mixed_composition_file and os.path.exists(self.mixed_composition_file):
|
| 555 |
+
processed_file = AudioEffectsProcessor.process_with_effects(
|
| 556 |
+
self.mixed_composition_file,
|
| 557 |
+
self.active_effects
|
| 558 |
+
)
|
| 559 |
+
else:
|
| 560 |
+
print("Failed to create mixed composition, using fallback")
|
| 561 |
+
processed_file = self.mixed_composition_file
|
| 562 |
+
|
| 563 |
+
return {
|
| 564 |
+
"effect_applied": True,
|
| 565 |
+
"effect_name": effect_name,
|
| 566 |
+
"effect_status": effect_status,
|
| 567 |
+
"mixed_composition": processed_file
|
| 568 |
+
}
|
| 569 |
+
|
| 570 |
+
def get_cycle_success_rate(self) -> float:
|
| 571 |
+
"""Get success rate for current cycle."""
|
| 572 |
+
if self.cycle_stats['total_attempts'] == 0:
|
| 573 |
+
return 0.0
|
| 574 |
+
return self.cycle_stats['successful_classifications'] / self.cycle_stats['total_attempts']
|
| 575 |
+
|
| 576 |
+
def clear_composition(self):
|
| 577 |
+
"""Clear all composition layers."""
|
| 578 |
+
self.composition_layers = []
|
| 579 |
+
self.current_composition = None
|
| 580 |
+
print("Composition cleared")
|
| 581 |
+
|
| 582 |
+
def _get_audio_file_for_gradio(self, movement):
|
| 583 |
+
"""Get the actual audio file path for Gradio audio output"""
|
| 584 |
+
if movement in self.current_sound_mapping:
|
| 585 |
+
sound_file = self.current_sound_mapping[movement]
|
| 586 |
+
audio_path = os.path.join(self.sound_dir, sound_file)
|
| 587 |
+
if os.path.exists(audio_path):
|
| 588 |
+
return audio_path
|
| 589 |
+
return None
|
| 590 |
+
|
| 591 |
+
def _mix_audio_files(self, audio_files: List[str]) -> str:
|
| 592 |
+
"""Mix multiple audio files into a single layered audio file."""
|
| 593 |
+
if not audio_files:
|
| 594 |
+
return None
|
| 595 |
+
|
| 596 |
+
# Load all audio files and convert to mono
|
| 597 |
+
audio_data_list = []
|
| 598 |
+
sample_rate = None
|
| 599 |
+
max_length = 0
|
| 600 |
+
|
| 601 |
+
for file_path in audio_files:
|
| 602 |
+
if os.path.exists(file_path):
|
| 603 |
+
data, sr = sf.read(file_path)
|
| 604 |
+
|
| 605 |
+
# Ensure data is numpy array and handle shape properly
|
| 606 |
+
data = np.asarray(data)
|
| 607 |
+
|
| 608 |
+
# Convert to mono if stereo/multi-channel
|
| 609 |
+
if data.ndim > 1:
|
| 610 |
+
if data.shape[1] > 1: # Multi-channel
|
| 611 |
+
data = np.mean(data, axis=1)
|
| 612 |
+
else: # Single channel with extra dimension
|
| 613 |
+
data = data.flatten()
|
| 614 |
+
|
| 615 |
+
# Ensure final data is 1D
|
| 616 |
+
data = np.asarray(data).flatten()
|
| 617 |
+
|
| 618 |
+
audio_data_list.append(data)
|
| 619 |
+
if sample_rate is None:
|
| 620 |
+
sample_rate = sr
|
| 621 |
+
max_length = max(max_length, len(data))
|
| 622 |
+
|
| 623 |
+
if not audio_data_list:
|
| 624 |
+
return None
|
| 625 |
+
|
| 626 |
+
# Pad all audio to same length and mix
|
| 627 |
+
mixed_audio = np.zeros(max_length)
|
| 628 |
+
for data in audio_data_list:
|
| 629 |
+
# Ensure data is 1D (flatten any remaining multi-dimensional arrays)
|
| 630 |
+
data_flat = np.asarray(data).flatten()
|
| 631 |
+
|
| 632 |
+
# Pad or truncate to match max_length
|
| 633 |
+
if len(data_flat) < max_length:
|
| 634 |
+
padded = np.pad(data_flat, (0, max_length - len(data_flat)), 'constant')
|
| 635 |
+
else:
|
| 636 |
+
padded = data_flat[:max_length]
|
| 637 |
+
|
| 638 |
+
# Add to mix (normalize to prevent clipping)
|
| 639 |
+
mixed_audio += padded / len(audio_data_list)
|
| 640 |
+
|
| 641 |
+
# Create a unique identifier for this composition
|
| 642 |
+
import hashlib
|
| 643 |
+
composition_hash = hashlib.md5(''.join(sorted(audio_files)).encode()).hexdigest()[:8]
|
| 644 |
+
|
| 645 |
+
# FILE SAVING DISABLED: Return first available audio file instead of creating mixed composition
|
| 646 |
+
mixed_audio_path = os.path.join(self.sound_dir, f"mixed_composition_{composition_hash}.wav")
|
| 647 |
+
|
| 648 |
+
# Since file saving is disabled, use the first available audio file from the list
|
| 649 |
+
if audio_files:
|
| 650 |
+
# Use the first audio file as the "mixed" composition
|
| 651 |
+
first_audio_file = os.path.join(self.sound_dir, audio_files[0])
|
| 652 |
+
if os.path.exists(first_audio_file):
|
| 653 |
+
print(f"DEBUG: Using first audio file as mixed composition: {os.path.basename(first_audio_file)} (FILE SAVING DISABLED)")
|
| 654 |
+
return first_audio_file
|
| 655 |
+
|
| 656 |
+
# Fallback: if mixed composition file already exists, use it
|
| 657 |
+
if os.path.exists(mixed_audio_path):
|
| 658 |
+
print(f"DEBUG: Reusing existing mixed audio file: {os.path.basename(mixed_audio_path)}")
|
| 659 |
+
return mixed_audio_path
|
| 660 |
+
|
| 661 |
+
# Final fallback: return first available base audio file
|
| 662 |
+
base_files = ["1_SoundHelix-Song-6_(Vocals).wav", "1_SoundHelix-Song-6_(Drums).wav", "1_SoundHelix-Song-6_(Bass).wav", "1_SoundHelix-Song-6_(Other).wav"]
|
| 663 |
+
for base_file in base_files:
|
| 664 |
+
base_path = os.path.join(self.sound_dir, base_file)
|
| 665 |
+
if os.path.exists(base_path):
|
| 666 |
+
print(f"DEBUG: Using base audio file as fallback: {base_file} (FILE SAVING DISABLED)")
|
| 667 |
+
return base_path
|
| 668 |
+
|
| 669 |
+
return mixed_audio_path
|
| 670 |
+
|
| 671 |
+
def get_current_mixed_composition(self) -> str:
|
| 672 |
+
"""Get the current composition as a mixed audio file."""
|
| 673 |
+
# Get all audio files from current composition layers
|
| 674 |
+
audio_files = []
|
| 675 |
+
for layer in self.composition_layers:
|
| 676 |
+
movement = layer.get('movement')
|
| 677 |
+
if movement and movement in self.current_sound_mapping:
|
| 678 |
+
sound_file = self.current_sound_mapping[movement]
|
| 679 |
+
audio_path = os.path.join(self.sound_dir, sound_file)
|
| 680 |
+
if os.path.exists(audio_path):
|
| 681 |
+
audio_files.append(audio_path)
|
| 682 |
+
|
| 683 |
+
# Debug: print current composition state
|
| 684 |
+
print(f"DEBUG: Current composition has {len(self.composition_layers)} layers: {[layer.get('movement') for layer in self.composition_layers]}")
|
| 685 |
+
print(f"DEBUG: Audio files to mix: {[os.path.basename(f) for f in audio_files]}")
|
| 686 |
+
|
| 687 |
+
return self._mix_audio_files(audio_files)
|
| 688 |
+
|
| 689 |
+
def get_composition_info(self) -> Dict:
|
| 690 |
+
"""Get comprehensive information about current composition."""
|
| 691 |
+
layers_by_cycle = {}
|
| 692 |
+
for layer in self.composition_layers:
|
| 693 |
+
cycle = layer['cycle']
|
| 694 |
+
if cycle not in layers_by_cycle:
|
| 695 |
+
layers_by_cycle[cycle] = []
|
| 696 |
+
layers_by_cycle[cycle].append({
|
| 697 |
+
'step': layer['step'],
|
| 698 |
+
'movement': layer['movement'],
|
| 699 |
+
'sound_file': layer['sound_file'],
|
| 700 |
+
'confidence': layer['confidence']
|
| 701 |
+
})
|
| 702 |
+
|
| 703 |
+
# Also track completed movements without sounds (like neutral)
|
| 704 |
+
completed_without_sound = [mov for mov in self.movements_completed
|
| 705 |
+
if self.current_sound_mapping.get(mov) is None]
|
| 706 |
+
|
| 707 |
+
return {
|
| 708 |
+
'total_cycles': self.current_cycle + (1 if self.composition_layers else 0),
|
| 709 |
+
'current_cycle': self.current_cycle + 1,
|
| 710 |
+
'current_step': len(self.movements_completed) + 1, # Current step in cycle
|
| 711 |
+
'target_movement': self.get_current_target_movement(),
|
| 712 |
+
'movements_completed': len(self.movements_completed),
|
| 713 |
+
'movements_remaining': len(self.current_movement_sequence) - len(self.movements_completed),
|
| 714 |
+
'cycle_complete': self.cycle_complete,
|
| 715 |
+
'total_layers': len(self.composition_layers),
|
| 716 |
+
'completed_movements': list(self.movements_completed),
|
| 717 |
+
'completed_without_sound': completed_without_sound,
|
| 718 |
+
'layers_by_cycle': layers_by_cycle,
|
| 719 |
+
'current_mapping': self.current_sound_mapping.copy(),
|
| 720 |
+
'success_rate': self.get_cycle_success_rate(),
|
| 721 |
+
'stats': self.cycle_stats.copy()
|
| 722 |
+
}
|
| 723 |
+
|
| 724 |
+
def get_sound_mapping_options(self) -> Dict:
|
| 725 |
+
"""Get available sound mapping options for user customization."""
|
| 726 |
+
return {
|
| 727 |
+
'movements': self.all_movements, # All possible movements for mapping
|
| 728 |
+
'available_sounds': self.available_sounds,
|
| 729 |
+
'current_mapping': self.current_sound_mapping.copy()
|
| 730 |
+
}
|
| 731 |
+
|
| 732 |
+
def update_sound_mapping(self, new_mapping: Dict[str, str]) -> bool:
|
| 733 |
+
"""Update the sound mapping for future cycles."""
|
| 734 |
+
try:
|
| 735 |
+
# Validate that all sounds exist
|
| 736 |
+
for movement, sound_file in new_mapping.items():
|
| 737 |
+
if sound_file not in self.available_sounds:
|
| 738 |
+
print(f"Warning: Sound file {sound_file} not available")
|
| 739 |
+
return False
|
| 740 |
+
|
| 741 |
+
self.current_sound_mapping.update(new_mapping)
|
| 742 |
+
print("β Sound mapping updated successfully")
|
| 743 |
+
return True
|
| 744 |
+
except Exception as e:
|
| 745 |
+
print(f"Error updating sound mapping: {e}")
|
| 746 |
+
return False
|
| 747 |
+
|
| 748 |
+
def reset_composition(self):
|
| 749 |
+
"""Reset the entire composition to start fresh."""
|
| 750 |
+
self.composition_layers = []
|
| 751 |
+
self.current_cycle = 0
|
| 752 |
+
self.current_step = 0
|
| 753 |
+
self.cycle_complete = False
|
| 754 |
+
self.cycle_stats = {
|
| 755 |
+
'total_cycles': 0,
|
| 756 |
+
'successful_classifications': 0,
|
| 757 |
+
'total_attempts': 0
|
| 758 |
+
}
|
| 759 |
+
print("π Composition reset - ready for new session")
|
| 760 |
+
|
| 761 |
+
def save_composition(self, output_path: str = "composition.wav"):
|
| 762 |
+
"""Save the current composition as a mixed audio file."""
|
| 763 |
+
if not self.composition_layers:
|
| 764 |
+
print("No layers to save")
|
| 765 |
+
return False
|
| 766 |
+
|
| 767 |
+
try:
|
| 768 |
+
# For simplicity, we'll just save the latest layer
|
| 769 |
+
# In a real implementation, you'd mix multiple audio tracks
|
| 770 |
+
latest_layer = self.composition_layers[-1]
|
| 771 |
+
sound_data = latest_layer['sound_data']
|
| 772 |
+
|
| 773 |
+
sf.write(output_path, sound_data['data'], sound_data['sample_rate'])
|
| 774 |
+
print(f"Composition saved to {output_path}")
|
| 775 |
+
return True
|
| 776 |
+
|
| 777 |
+
except Exception as e:
|
| 778 |
+
print(f"Error saving composition: {e}")
|
| 779 |
+
return False
|
| 780 |
+
|
| 781 |
+
def get_available_sounds(self) -> List[str]:
|
| 782 |
+
"""Get list of available sound classes."""
|
| 783 |
+
return list(self.loaded_sounds.keys())
|
source/eeg_motor_imagery.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
EEG Motor Imagery Classification with Shallow ConvNet
|
| 3 |
+
-----------------------------------------------------
|
| 4 |
+
This script trains and evaluates a ShallowFBCSPNet model
|
| 5 |
+
on motor imagery EEG data stored in .mat files.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
# === 1. Imports ===
|
| 9 |
+
import scipy.io
|
| 10 |
+
import numpy as np
|
| 11 |
+
import mne
|
| 12 |
+
import torch
|
| 13 |
+
from torch.utils.data import TensorDataset, DataLoader
|
| 14 |
+
from sklearn.model_selection import train_test_split
|
| 15 |
+
from braindecode.models import ShallowFBCSPNet
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import pandas as pd
|
| 18 |
+
|
| 19 |
+
# === 2. Data Loading and Epoching ===
|
| 20 |
+
# Load .mat EEG files, create Raw objects, extract events, and epoch the data
|
| 21 |
+
files = [
|
| 22 |
+
"../data/raw_mat/HaLTSubjectA1602236StLRHandLegTongue.mat",
|
| 23 |
+
"../data/raw_mat/HaLTSubjectA1603086StLRHandLegTongue.mat",
|
| 24 |
+
|
| 25 |
+
]
|
| 26 |
+
|
| 27 |
+
all_epochs = []
|
| 28 |
+
for f in files:
|
| 29 |
+
mat = scipy.io.loadmat(f)
|
| 30 |
+
content = mat['o'][0, 0]
|
| 31 |
+
|
| 32 |
+
labels = content[4].flatten()
|
| 33 |
+
signals = content[5]
|
| 34 |
+
chan_names_raw = content[6]
|
| 35 |
+
channels = [ch[0][0] for ch in chan_names_raw]
|
| 36 |
+
fs = int(content[2][0, 0])
|
| 37 |
+
|
| 38 |
+
df = pd.DataFrame(signals, columns=channels).drop(columns=["X5"], errors="ignore")
|
| 39 |
+
eeg = df.values.T
|
| 40 |
+
ch_names = df.columns.tolist()
|
| 41 |
+
|
| 42 |
+
info = mne.create_info(ch_names=ch_names, sfreq=fs, ch_types="eeg")
|
| 43 |
+
raw = mne.io.RawArray(eeg, info)
|
| 44 |
+
|
| 45 |
+
# Create events
|
| 46 |
+
onsets = np.where((labels[1:] != 0) & (labels[:-1] == 0))[0] + 1
|
| 47 |
+
event_codes = labels[onsets].astype(int)
|
| 48 |
+
events = np.c_[onsets, np.zeros_like(onsets), event_codes]
|
| 49 |
+
|
| 50 |
+
# Keep only relevant events
|
| 51 |
+
mask = np.isin(events[:, 2], np.arange(1, 7))
|
| 52 |
+
events = events[mask]
|
| 53 |
+
|
| 54 |
+
event_id = {
|
| 55 |
+
"left_hand": 1,
|
| 56 |
+
"right_hand": 2,
|
| 57 |
+
"neutral": 3,
|
| 58 |
+
"left_leg": 4,
|
| 59 |
+
"tongue": 5,
|
| 60 |
+
"right_leg": 6,
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
# Epoching
|
| 64 |
+
epochs = mne.Epochs(
|
| 65 |
+
raw,
|
| 66 |
+
events=events,
|
| 67 |
+
event_id=event_id,
|
| 68 |
+
tmin=0,
|
| 69 |
+
tmax=1.5,
|
| 70 |
+
baseline=None,
|
| 71 |
+
preload=True,
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
all_epochs.append(epochs)
|
| 75 |
+
|
| 76 |
+
epochs_all = mne.concatenate_epochs(all_epochs)
|
| 77 |
+
|
| 78 |
+
# === 3. Minimal Preprocessing + Train/Validation Split ===
|
| 79 |
+
# Convert epochs to numpy arrays (N, C, T) and split into train/val sets
|
| 80 |
+
X = epochs_all.get_data().astype("float32")
|
| 81 |
+
y = (epochs_all.events[:, -1] - 1).astype("int64") # classes 0..5
|
| 82 |
+
|
| 83 |
+
X_train, X_val, y_train, y_val = train_test_split(
|
| 84 |
+
X, y, test_size=0.2, random_state=42, stratify=y
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
# === 4. Torch DataLoaders ===
|
| 88 |
+
train_ds = TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
|
| 89 |
+
val_ds = TensorDataset(torch.from_numpy(X_val), torch.from_numpy(y_val))
|
| 90 |
+
|
| 91 |
+
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
|
| 92 |
+
val_loader = DataLoader(val_ds, batch_size=32)
|
| 93 |
+
|
| 94 |
+
# === 5. Model β Shallow ConvNet ===
|
| 95 |
+
# Reference: Schirrmeister et al. (2017)
|
| 96 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 97 |
+
|
| 98 |
+
model = ShallowFBCSPNet(
|
| 99 |
+
n_chans=X.shape[1],
|
| 100 |
+
n_outputs=len(np.unique(y)),
|
| 101 |
+
n_times=X.shape[2],
|
| 102 |
+
final_conv_length="auto"
|
| 103 |
+
).to(device)
|
| 104 |
+
|
| 105 |
+
# Load pretrained weights
|
| 106 |
+
state_dict = torch.load("shallow_weights_all.pth", map_location=device)
|
| 107 |
+
model.load_state_dict(state_dict)
|
| 108 |
+
|
| 109 |
+
# === 6. Training ===
|
| 110 |
+
criterion = nn.CrossEntropyLoss()
|
| 111 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
| 112 |
+
|
| 113 |
+
for epoch in range(1, 21):
|
| 114 |
+
# Training
|
| 115 |
+
model.train()
|
| 116 |
+
correct, total = 0, 0
|
| 117 |
+
for xb, yb in train_loader:
|
| 118 |
+
xb, yb = xb.to(device), yb.to(device)
|
| 119 |
+
optimizer.zero_grad()
|
| 120 |
+
out = model(xb)
|
| 121 |
+
loss = criterion(out, yb)
|
| 122 |
+
loss.backward()
|
| 123 |
+
optimizer.step()
|
| 124 |
+
|
| 125 |
+
pred = out.argmax(dim=1)
|
| 126 |
+
correct += (pred == yb).sum().item()
|
| 127 |
+
total += yb.size(0)
|
| 128 |
+
train_acc = correct / total
|
| 129 |
+
|
| 130 |
+
# Validation
|
| 131 |
+
model.eval()
|
| 132 |
+
correct, total = 0, 0
|
| 133 |
+
with torch.no_grad():
|
| 134 |
+
for xb, yb in val_loader:
|
| 135 |
+
xb, yb = xb.to(device), yb.to(device)
|
| 136 |
+
out = model(xb)
|
| 137 |
+
pred = out.argmax(dim=1)
|
| 138 |
+
correct += (pred == yb).sum().item()
|
| 139 |
+
total += yb.size(0)
|
| 140 |
+
val_acc = correct / total
|
| 141 |
+
|
| 142 |
+
print(f"Epoch {epoch:02d} | Train acc: {train_acc:.3f} | Val acc: {val_acc:.3f}")
|
utils.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility functions for the EEG Motor Imagery Music Composer
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import logging
|
| 7 |
+
import time
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Dict, List, Optional, Tuple
|
| 10 |
+
import json
|
| 11 |
+
|
| 12 |
+
from config import LOG_LEVEL, LOG_FILE, CLASS_NAMES, CLASS_DESCRIPTIONS
|
| 13 |
+
|
| 14 |
+
def setup_logging():
|
| 15 |
+
"""Set up logging configuration."""
|
| 16 |
+
LOG_FILE.parent.mkdir(parents=True, exist_ok=True)
|
| 17 |
+
|
| 18 |
+
logging.basicConfig(
|
| 19 |
+
level=getattr(logging, LOG_LEVEL),
|
| 20 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
| 21 |
+
handlers=[
|
| 22 |
+
logging.FileHandler(LOG_FILE),
|
| 23 |
+
logging.StreamHandler()
|
| 24 |
+
]
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
return logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
def validate_eeg_data(data: np.ndarray) -> bool:
|
| 30 |
+
"""
|
| 31 |
+
Validate EEG data format and dimensions.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
data: EEG data array
|
| 35 |
+
|
| 36 |
+
Returns:
|
| 37 |
+
bool: True if data is valid, False otherwise
|
| 38 |
+
"""
|
| 39 |
+
if not isinstance(data, np.ndarray):
|
| 40 |
+
return False
|
| 41 |
+
|
| 42 |
+
if data.ndim not in [2, 3]:
|
| 43 |
+
return False
|
| 44 |
+
|
| 45 |
+
if data.ndim == 2 and data.shape[0] == 0:
|
| 46 |
+
return False
|
| 47 |
+
|
| 48 |
+
if data.ndim == 3 and (data.shape[0] == 0 or data.shape[1] == 0):
|
| 49 |
+
return False
|
| 50 |
+
|
| 51 |
+
return True
|
| 52 |
+
|
| 53 |
+
def format_confidence(confidence: float) -> str:
|
| 54 |
+
"""Format confidence score as percentage string."""
|
| 55 |
+
return f"{confidence * 100:.1f}%"
|
| 56 |
+
|
| 57 |
+
def format_timestamp(timestamp: float) -> str:
|
| 58 |
+
"""Format timestamp for display."""
|
| 59 |
+
return time.strftime("%H:%M:%S", time.localtime(timestamp))
|
| 60 |
+
|
| 61 |
+
def get_class_emoji(class_name: str) -> str:
|
| 62 |
+
"""Get emoji representation for motor imagery class."""
|
| 63 |
+
emoji_map = {
|
| 64 |
+
"left_hand": "π€",
|
| 65 |
+
"right_hand": "π€",
|
| 66 |
+
"neutral": "π",
|
| 67 |
+
"left_leg": "π¦΅",
|
| 68 |
+
"tongue": "π
",
|
| 69 |
+
"right_leg": "π¦΅"
|
| 70 |
+
}
|
| 71 |
+
return emoji_map.get(class_name, "β")
|
| 72 |
+
|
| 73 |
+
def create_classification_summary(
|
| 74 |
+
predicted_class: str,
|
| 75 |
+
confidence: float,
|
| 76 |
+
probabilities: Dict[str, float],
|
| 77 |
+
timestamp: Optional[float] = None
|
| 78 |
+
) -> Dict:
|
| 79 |
+
"""
|
| 80 |
+
Create a formatted summary of classification results.
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
predicted_class: Predicted motor imagery class
|
| 84 |
+
confidence: Confidence score (0-1)
|
| 85 |
+
probabilities: Dictionary of class probabilities
|
| 86 |
+
timestamp: Optional timestamp
|
| 87 |
+
|
| 88 |
+
Returns:
|
| 89 |
+
Dict: Formatted classification summary
|
| 90 |
+
"""
|
| 91 |
+
if timestamp is None:
|
| 92 |
+
timestamp = time.time()
|
| 93 |
+
|
| 94 |
+
return {
|
| 95 |
+
"predicted_class": predicted_class,
|
| 96 |
+
"confidence": confidence,
|
| 97 |
+
"confidence_percent": format_confidence(confidence),
|
| 98 |
+
"probabilities": probabilities,
|
| 99 |
+
"timestamp": timestamp,
|
| 100 |
+
"formatted_time": format_timestamp(timestamp),
|
| 101 |
+
"emoji": get_class_emoji(predicted_class),
|
| 102 |
+
"description": CLASS_DESCRIPTIONS.get(predicted_class, predicted_class)
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
def save_session_data(session_data: Dict, filepath: str) -> bool:
|
| 106 |
+
"""
|
| 107 |
+
Save session data to JSON file.
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
session_data: Dictionary containing session information
|
| 111 |
+
filepath: Path to save the file
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
bool: True if successful, False otherwise
|
| 115 |
+
"""
|
| 116 |
+
try:
|
| 117 |
+
with open(filepath, 'w') as f:
|
| 118 |
+
json.dump(session_data, f, indent=2, default=str)
|
| 119 |
+
return True
|
| 120 |
+
except Exception as e:
|
| 121 |
+
logging.error(f"Error saving session data: {e}")
|
| 122 |
+
return False
|
| 123 |
+
|
| 124 |
+
def load_session_data(filepath: str) -> Optional[Dict]:
|
| 125 |
+
"""
|
| 126 |
+
Load session data from JSON file.
|
| 127 |
+
|
| 128 |
+
Args:
|
| 129 |
+
filepath: Path to the JSON file
|
| 130 |
+
|
| 131 |
+
Returns:
|
| 132 |
+
Dict or None: Session data if successful, None otherwise
|
| 133 |
+
"""
|
| 134 |
+
try:
|
| 135 |
+
with open(filepath, 'r') as f:
|
| 136 |
+
return json.load(f)
|
| 137 |
+
except Exception as e:
|
| 138 |
+
logging.error(f"Error loading session data: {e}")
|
| 139 |
+
return None
|
| 140 |
+
|
| 141 |
+
def calculate_classification_statistics(history: List[Dict]) -> Dict:
|
| 142 |
+
"""
|
| 143 |
+
Calculate statistics from classification history.
|
| 144 |
+
|
| 145 |
+
Args:
|
| 146 |
+
history: List of classification results
|
| 147 |
+
|
| 148 |
+
Returns:
|
| 149 |
+
Dict: Statistics summary
|
| 150 |
+
"""
|
| 151 |
+
if not history:
|
| 152 |
+
return {"total": 0, "class_counts": {}, "average_confidence": 0.0}
|
| 153 |
+
|
| 154 |
+
class_counts = {}
|
| 155 |
+
total_confidence = 0.0
|
| 156 |
+
|
| 157 |
+
for item in history:
|
| 158 |
+
class_name = item.get("predicted_class", "unknown")
|
| 159 |
+
confidence = item.get("confidence", 0.0)
|
| 160 |
+
|
| 161 |
+
class_counts[class_name] = class_counts.get(class_name, 0) + 1
|
| 162 |
+
total_confidence += confidence
|
| 163 |
+
|
| 164 |
+
return {
|
| 165 |
+
"total": len(history),
|
| 166 |
+
"class_counts": class_counts,
|
| 167 |
+
"average_confidence": total_confidence / len(history),
|
| 168 |
+
"most_common_class": max(class_counts, key=class_counts.get) if class_counts else None
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
def create_progress_bar(value: float, max_value: float = 1.0, width: int = 20) -> str:
|
| 172 |
+
"""
|
| 173 |
+
Create a text-based progress bar.
|
| 174 |
+
|
| 175 |
+
Args:
|
| 176 |
+
value: Current value
|
| 177 |
+
max_value: Maximum value
|
| 178 |
+
width: Width of progress bar in characters
|
| 179 |
+
|
| 180 |
+
Returns:
|
| 181 |
+
str: Progress bar string
|
| 182 |
+
"""
|
| 183 |
+
percentage = min(value / max_value, 1.0)
|
| 184 |
+
filled = int(width * percentage)
|
| 185 |
+
bar = "β" * filled + "β" * (width - filled)
|
| 186 |
+
return f"[{bar}] {percentage * 100:.1f}%"
|
| 187 |
+
|
| 188 |
+
def validate_audio_file(file_path: str) -> bool:
|
| 189 |
+
"""
|
| 190 |
+
Validate if an audio file exists and is readable.
|
| 191 |
+
|
| 192 |
+
Args:
|
| 193 |
+
file_path: Path to audio file
|
| 194 |
+
|
| 195 |
+
Returns:
|
| 196 |
+
bool: True if file is valid, False otherwise
|
| 197 |
+
"""
|
| 198 |
+
path = Path(file_path)
|
| 199 |
+
if not path.exists():
|
| 200 |
+
return False
|
| 201 |
+
|
| 202 |
+
if not path.is_file():
|
| 203 |
+
return False
|
| 204 |
+
|
| 205 |
+
# Check file extension
|
| 206 |
+
valid_extensions = ['.wav', '.mp3', '.flac', '.ogg']
|
| 207 |
+
if path.suffix.lower() not in valid_extensions:
|
| 208 |
+
return False
|
| 209 |
+
|
| 210 |
+
return True
|
| 211 |
+
|
| 212 |
+
def generate_composition_filename(prefix: str = "composition") -> str:
|
| 213 |
+
"""
|
| 214 |
+
Generate a unique filename for composition exports.
|
| 215 |
+
|
| 216 |
+
Args:
|
| 217 |
+
prefix: Filename prefix
|
| 218 |
+
|
| 219 |
+
Returns:
|
| 220 |
+
str: Unique filename with timestamp
|
| 221 |
+
"""
|
| 222 |
+
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
| 223 |
+
return f"{prefix}_{timestamp}.wav"
|
| 224 |
+
|
| 225 |
+
# Initialize logger when module is imported
|
| 226 |
+
logger = setup_logging()
|