import gradio as gr
import numpy as np
import librosa
import tensorflow as tf
from scipy.fftpack import dct
import os
import tempfile
import shutil
import subprocess
import re
# DSCNN model configuration
MODEL_PATH = "ds_cnn_l_quantized.tflite"
DEFAULT_CONFIG = "u55_eval_with_TA_config_400_and_200_MHz.ini"

# Keywords based on the Speech Commands dataset (12 classes)
KEYWORDS = [
    "silence", "unknown", "yes", "no", "up", "down",
    "left", "right", "on", "off", "stop", "go"
]
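# NOTE: this ordering assumes the common 12-class Speech Commands label layout
# (silence/unknown first). If the checkpoint was trained with a different label
# order, the indices used by classify_audio() below will be wrong.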
| print("Loading DSCNN TensorFlow Lite model...") | |
| try: | |
| # Load the TFLite model | |
| interpreter = tf.lite.Interpreter(model_path=MODEL_PATH) | |
| interpreter.allocate_tensors() | |
| # Get input and output details | |
| input_details = interpreter.get_input_details() | |
| output_details = interpreter.get_output_details() | |
| print(f"β DSCNN model loaded successfully!") | |
| print(f"Input shape: {input_details[0]['shape']}") | |
| print(f"Output shape: {output_details[0]['shape']}") | |
| print(f"Input dtype: {input_details[0]['dtype']}") | |
| print(f"Output dtype: {output_details[0]['dtype']}") | |
| except Exception as e: | |
| print(f"β Error loading DSCNN model: {e}") | |
| interpreter = None | |
# Vela config file is copied from the SR app
def extract_summary_from_log(log_text):
    summary_keys = [
        "Accelerator configuration",
        "Accelerator clock",
        "Total SRAM used",
        "Total On-chip Flash used",
        "CPU operators",
        "NPU operators",
        "Batch Inference time"
    ]
    summary = []
    for key in summary_keys:
        match = re.search(rf"{re.escape(key)}\s+(.+)", log_text)
        if match:
            value = match.group(1).strip()
            if key == "Batch Inference time":
                value = value.split(",")[0].strip()
                key = "Inference time"
            summary.append((key, value))
    return summary
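# For illustration (not copied from a real log): a Vela report line such as
#   "Total SRAM used                 123.45 KiB"
# would be parsed by the regex above into ("Total SRAM used", "123.45 KiB").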
def run_vela(model_file):
    accel = "ethos-u55-128"
    optimise = "Size"
    mem_mode = "Sram_Only"
    sys_config = "Ethos_U55_400MHz_SRAM_3.2_GBs_Flash_0.05_GBs"
    tmpdir = tempfile.mkdtemp()
    try:
        # Use the original uploaded model filename
        original_model_name = os.path.basename(model_file)
        model_path = os.path.join(tmpdir, original_model_name)
        shutil.copy(model_file, model_path)

        config_path = os.path.join(tmpdir, DEFAULT_CONFIG)
        shutil.copy(DEFAULT_CONFIG, config_path)

        output_dir = os.path.join(tmpdir, "vela_out")
        os.makedirs(output_dir, exist_ok=True)

        cmd = [
            "vela",
            f"--accelerator-config={accel}",
            f"--optimise={optimise}",
            f"--config={config_path}",
            f"--memory-mode={mem_mode}",
            f"--system-config={sys_config}",
            model_path,
            "--verbose-cycle-estimate",
            "--verbose-performance",
            f"--output-dir={output_dir}"
        ]
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
        vela_stdout = result.stdout
        # Check for unsupported model patterns in the logs
        unsupported_patterns = [
            "Warning: Unsupported TensorFlow Lite semantics",
            "Network Tops/s nan Tops/s",
            "Neural network macs 0 MACs/batch"
        ]
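        # These are plain substring checks against Vela's textual report, so they
        # are a heuristic and may need updating if the log format changes between
        # Vela releases.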
        if any(pat in vela_stdout for pat in unsupported_patterns):
            summary_html = (
                "<div class='sr110-results' style='background:#fff3f3;border-radius:14px;padding:24px 18px 18px 18px;"
                "max-width:430px;min-width:320px;width:100%;margin:auto;color:#d32f2f;font-family:sans-serif;"
                "font-size:1.1em;text-align:left;font-weight:600;'>"
                "This model contains unsupported layers and needs further investigation.<br>"
                "Please run the Vela compiler on your host machine for a full analysis."
                "</div>"
            )
            # Try to provide per-layer.csv for download if available
            per_layer_csv = None
            for log_fname in os.listdir(output_dir):
                if log_fname.endswith("per-layer.csv"):
                    per_layer_csv = os.path.join("/tmp", log_fname)
                    shutil.copy(os.path.join(output_dir, log_fname), per_layer_csv)
                    break
            return summary_html, None, per_layer_csv
        model_filename = os.path.basename(model_file)
        if model_filename:
            # Tag the uploaded filename onto Vela's "Network summary for <model>"
            # line, keeping the parentheses balanced
            vela_stdout = re.sub(
                r"Network summary for (\S+)",
                f"Network summary for {model_filename} (\\1)",
                vela_stdout,
                count=1,
            )
        summary_items = extract_summary_from_log(vela_stdout)
        # Convert summary_items to a dict for easy access
        summary_dict = dict(summary_items) if summary_items else {}

        # Build 4 cards for the results
        def clean_ops(val):
            # Strip the leading '=' and surrounding spaces from operator counts
            return val.lstrip("= ").strip() if isinstance(val, str) else val
        summary_html = (
            "<div class='sr110-results' style='background:#1e1e2f;border-radius:18px;padding:18px 18px 12px 18px;"
            "max-width:430px;min-width:320px;width:100%;margin:auto;color:#eee;font-family:sans-serif;'>"
            "<h3 class='sr110-title' style='margin-top:0;margin-bottom:12px;font-size:1.35em;color:#00b0ff;text-align:left;'>Estimated Results on SR110</h3>"
            "<div style='display:flex;flex-wrap:wrap;gap:10px;justify-content:center;'>"
            # Card 1: Accelerator
            "<div class='sr110-card' style='flex:1 1 170px;min-width:170px;max-width:180px;background:#23233a;border-radius:12px;padding:10px 10px 8px 10px;'>"
            "<div class='sr110-title' style='font-size:1em;font-weight:520;margin-bottom:6px;color:#00b0ff;'>🚀 Accelerator</div>"
            f"<div style='margin-bottom:2px;'><span class='sr110-label' style='color:#ccc;'>Configuration:</span> <span class='sr110-value' style='color:#fff;font-weight:500'>{summary_dict.get('Accelerator configuration','-')}</span></div>"
            f"<div><span class='sr110-label' style='color:#ccc;'>Accelerator clock:</span> <span class='sr110-value' style='color:#fff;font-weight:500'>{summary_dict.get('Accelerator clock','-')}</span></div>"
            "</div>"
            # Card 2: Memory Usage
            "<div class='sr110-card' style='flex:1 1 170px;min-width:170px;max-width:180px;background:#23233a;border-radius:12px;padding:10px 10px 8px 10px;'>"
            "<div class='sr110-title' style='font-size:1em;font-weight:520;margin-bottom:6px;color:#00b0ff;'>💾 Memory Usage</div>"
            f"<div style='margin-bottom:2px;'><span class='sr110-label' style='color:#ccc;'>Total SRAM:</span> <span class='sr110-value' style='color:#fff;font-weight:500'>{summary_dict.get('Total SRAM used','-')}</span></div>"
            f"<div><span class='sr110-label' style='color:#ccc;'>Total On-chip Flash:</span> <span class='sr110-value' style='color:#fff;font-weight:500'>{summary_dict.get('Total On-chip Flash used','-')}</span></div>"
            "</div>"
            # Card 3: Operator Distribution
            "<div class='sr110-card' style='flex:1 1 170px;min-width:170px;max-width:180px;background:#23233a;border-radius:12px;padding:10px 10px 8px 10px;'>"
            "<div class='sr110-title' style='font-size:1em;font-weight:520;margin-bottom:6px;color:#00b0ff;'>📊 Operator Distribution</div>"
            f"<div style='margin-bottom:2px;'><span class='sr110-label' style='color:#ccc;'>CPU Operators:</span> <span class='sr110-value' style='color:#fff;font-weight:500'>{clean_ops(summary_dict.get('CPU operators','-'))}</span></div>"
            f"<div><span class='sr110-label' style='color:#ccc;'>NPU Operators:</span> <span class='sr110-value' style='color:#fff;font-weight:500'>{clean_ops(summary_dict.get('NPU operators','-'))}</span></div>"
            "</div>"
            # Card 4: Performance
            "<div class='sr110-card' style='flex:1 1 170px;min-width:170px;max-width:180px;background:#23233a;border-radius:12px;padding:10px 10px 8px 10px;'>"
            "<div class='sr110-title' style='font-size:1em;font-weight:520;margin-bottom:6px;color:#00b0ff;'>⚡ Performance</div>"
            f"<div><span class='sr110-label' style='color:#ccc;'>Inference time:</span> <span class='sr110-value' style='color:#fff;font-weight:500'>{summary_dict.get('Inference time','-')}</span></div>"
            "</div>"
            "</div></div>"
        ) if summary_items else "<div style='color:red'>Summary info not found in log.</div>"
        for fname in os.listdir(output_dir):
            if fname.endswith("vela.tflite"):
                final_path = os.path.join("/tmp", fname)
                shutil.copy(os.path.join(output_dir, fname), final_path)
                # Find the per-layer.csv file for logs
                per_layer_csv = None
                for log_fname in os.listdir(output_dir):
                    if log_fname.endswith("per-layer.csv"):
                        per_layer_csv = os.path.join("/tmp", log_fname)
                        shutil.copy(os.path.join(output_dir, log_fname), per_layer_csv)
                        break
                return summary_html, final_path, per_layer_csv

        # If no compiled .tflite was produced, still return per-layer.csv if present
        per_layer_csv = None
        for log_fname in os.listdir(output_dir):
            if log_fname.endswith("per-layer.csv"):
                per_layer_csv = os.path.join("/tmp", log_fname)
                shutil.copy(os.path.join(output_dir, log_fname), per_layer_csv)
                break
        return summary_html, None, per_layer_csv
    finally:
        shutil.rmtree(tmpdir)
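# run_vela() returns (summary_html, compiled_model_path or None, per_layer_csv or None).
# Output files are copied out to /tmp because tmpdir is removed in the finally block.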
# Run Vela analysis on startup and cache the results
print("Running Vela analysis on DSCNN model...")
try:
    vela_html, compiled_model, per_layer_csv = run_vela(MODEL_PATH)
except Exception as e:
    vela_html = f"<div style='color:red'>Vela analysis failed: {str(e)}</div>"
    compiled_model, per_layer_csv = None, None
def extract_mfcc_features(audio_path, target_length=490):
    """
    Extract MFCC features following the recipe in the original DSCNN paper,
    "Hello Edge: Keyword Spotting on Microcontrollers".

    Parameters from the paper:
    - 40ms frame length (640 samples at 16kHz)
    - 20ms stride (320 samples at 16kHz)
    - 10 MFCC features per frame
    - 49 frames total for 1 second → 49×10 = 490 features
    """
    try:
        # Load audio and resample to 16kHz (standard for speech commands)
        audio, sr = librosa.load(audio_path, sr=16000, mono=True)

        # Ensure the audio is exactly 1 second (16000 samples)
        if len(audio) < 16000:
            # Pad with zeros
            audio = np.pad(audio, (0, 16000 - len(audio)), 'constant')
        else:
            # Truncate to 1 second
            audio = audio[:16000]

        # DSCNN paper parameters
        frame_length = 640   # 40ms at 16kHz
        hop_length = 320     # 20ms at 16kHz (50% overlap)
        n_mfcc = 10          # 10 MFCC features as in the paper
        n_fft = 1024         # FFT size
        n_mels = 40          # Mel filter bank size (before DCT)

        # Extract the mel spectrogram
        mel_spec = librosa.feature.melspectrogram(
            y=audio,
            sr=sr,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=frame_length,
            n_mels=n_mels,
            fmin=20,
            fmax=4000
        )

        # Convert to log scale
        log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max)

        # Apply DCT to get MFCC features (keep only the first 10 coefficients)
        mfcc_features = dct(log_mel_spec, axis=0, norm='ortho')[:n_mfcc, :]

        # Should be shape (10, 49) for 1 second of audio
        print(f"MFCC shape before flattening: {mfcc_features.shape}")

        # Flatten to a 1D array (10 × 49 = 490 features)
        features_flat = mfcc_features.flatten()

        # Ensure exactly 490 features
        if len(features_flat) > target_length:
            features_flat = features_flat[:target_length]
        elif len(features_flat) < target_length:
            features_flat = np.pad(features_flat, (0, target_length - len(features_flat)), 'constant')

        print(f"Features length after processing: {len(features_flat)}")

        # Normalize features (zero mean, unit variance)
        features_flat = (features_flat - np.mean(features_flat)) / (np.std(features_flat) + 1e-8)
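        # NOTE: a fully faithful pipeline would quantize with the model's own input
        # quantization parameters (scale and zero-point from
        # input_details[0]['quantization']); the fixed 127.0 scale below is a
        # heuristic stand-in.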
        # Quantize to the INT8 range expected by the DSCNN model,
        # scaled to approximately match the training distribution
        features_int8 = np.clip(features_flat * 127.0, -128, 127).astype(np.int8)

        return features_int8.reshape(1, -1)  # Shape: (1, 490)
    except Exception as e:
        raise Exception(f"Error extracting MFCC features: {str(e)}")
def classify_audio(audio_input):
    """
    Classify the input audio using the DSCNN model and return keyword predictions.
    """
    if audio_input is None:
        return "Please upload an audio file or record audio."
    if interpreter is None:
        return "❌ DSCNN model not loaded. Please refresh the page and try again."
    try:
        # Extract MFCC features
        features = extract_mfcc_features(audio_input)
        print(f"Input features shape: {features.shape}")
        print(f"Input features dtype: {features.dtype}")
        print(f"Input features range: [{features.min()}, {features.max()}]")

        # Set the input tensor
        interpreter.set_tensor(input_details[0]['index'], features)

        # Run inference
        interpreter.invoke()

        # Get the output
        output_data = interpreter.get_tensor(output_details[0]['index'])
        print(f"Raw output shape: {output_data.shape}")
        print(f"Raw output dtype: {output_data.dtype}")
        print(f"Raw output range: [{output_data.min()}, {output_data.max()}]")
        # Handle quantized INT8 output
        if output_data.dtype == np.int8:
            # Dequantize INT8 to float (assuming symmetric quantization,
            # where the scale factor is typically around 1/128)
            logits = output_data.astype(np.float32) / 128.0
        else:
            logits = output_data.astype(np.float32)

        # Apply softmax to get probabilities
        exp_logits = np.exp(logits - np.max(logits))
        probabilities = exp_logits / np.sum(exp_logits)

        # Collect predictions with confidence scores
        predictions = []
        for i, prob in enumerate(probabilities[0]):
            predictions.append({
                'label': KEYWORDS[i],
                'score': float(prob)
            })

        # Sort by confidence score
        predictions = sorted(predictions, key=lambda x: x['score'], reverse=True)

        # Format the results
        results = []
        for i, pred in enumerate(predictions[:5]):
            confidence = pred['score'] * 100
            label = pred['label']
            indicator = "🎯" if i == 0 else " "
            results.append(f"{indicator} {i+1}. **{label}**: {confidence:.1f}%")
        return "\n".join(results)
    except Exception as e:
        error_msg = str(e)
        if "mfcc" in error_msg.lower() or "librosa" in error_msg.lower():
            return "❌ Audio processing error. Please ensure your audio file is in a supported format (WAV, MP3, etc.)."
        elif "model" in error_msg.lower() or "tensor" in error_msg.lower():
            return "❌ Model inference error. Please try recording a clear 1-second audio clip."
        else:
            return f"❌ Error processing audio: {error_msg}\n\nTip: Try recording a clear 1-second word like 'yes' or 'stop'."
def load_example_audio(example_name):
    """Load example audio for demonstration."""
    # This would load pre-recorded examples if any were available
    return None
def compile_uploaded_model(model_file):
    """Compile an uploaded model with Vela and return the results."""
    if model_file is None:
        error_html = (
            "<div class='sr110-results' style='background:#fff3f3;border-radius:14px;padding:24px 18px 18px 18px;"
            "max-width:430px;min-width:320px;width:100%;margin:auto;color:#d32f2f;font-family:sans-serif;"
            "font-size:1.1em;text-align:center;font-weight:600;'>"
            "No model file uploaded."
            "</div>"
        )
        return (
            error_html,
            gr.update(visible=False, value=None),
            gr.update(visible=False, value=None)
        )
    try:
        # Run Vela analysis on the uploaded model
        results_html, compiled_model_path, per_layer_csv = run_vela(model_file)
        return (
            results_html,
            gr.update(visible=compiled_model_path is not None, value=compiled_model_path),
            gr.update(visible=per_layer_csv is not None, value=per_layer_csv)
        )
    except Exception as e:
        error_html = (
            "<div class='sr110-results' style='background:#fff3f3;border-radius:14px;padding:24px 18px 18px 18px;"
            "max-width:430px;min-width:320px;width:100%;margin:auto;color:#d32f2f;font-family:sans-serif;"
            "font-size:1.1em;text-align:center;font-weight:600;'>"
            f"Vela compilation failed: {str(e)}"
            "</div>"
        )
        return (
            error_html,
            gr.update(visible=False, value=None),
            gr.update(visible=False, value=None)
        )
# Create the Gradio interface
with gr.Blocks(
    theme=gr.themes.Default(),
    title="DSCNN Wake Word Detection",
    css="""
    body {
        background: #fafafa !important;
    }
    .gradio-container {
        max-width: 1200px !important;
        margin-left: auto !important;
        margin-right: auto !important;
        background-color: #fafafa !important;
        font-family: 'Inter', 'Segoe UI', -apple-system, sans-serif !important;
    }
    .gr-row {
        display: flex !important;
        justify-content: center !important;
        align-items: flex-start !important;
        gap: 48px !important;
    }
    .gr-column {
        align-items: flex-start !important;
        justify-content: flex-start !important;
    }
    .fixed-upload-box {
        width: 100% !important;
        max-width: 420px !important;
        margin-bottom: 18px !important;
    }
    .download-btn-custom, .compile-btn-custom {
        width: 100% !important;
        margin-bottom: 18px !important;
    }
    .upload-file-box .w-full, .download-file-box .w-full {
        height: 120px !important;
        background: #232b36 !important;
        border-radius: 8px !important;
        color: #fff !important;
        font-weight: 600 !important;
        font-size: 1.1em !important;
        box-shadow: none !important;
        display: flex !important;
        align-items: center !important;
        justify-content: center !important;
    }
    .upload-file-box .w-full .file-preview {
        margin: 0 auto !important;
        text-align: center !important;
        width: 100%;
    }
    #run-vela-btn, .compile-btn, .compile-btn-custom {
        background-color: #007dc3 !important;
        color: white !important;
        font-size: 1.1em;
        border-radius: 8px;
        margin-top: 12px;
        margin-bottom: 18px;
        text-align: center;
        height: 40px !important;
    }
    .results-summary-box, #results-summary {
        margin-left: 0 !important;
    }
    h1, h3, .gr-markdown h1, .gr-markdown h3 { color: #1976d2 !important; }
    p, .gr-markdown p, .gr-markdown span, .gr-markdown { color: #222 !important; }
    .custom-footer {
        display: block !important;
        margin: 40px auto 0 auto !important;
        max-width: 600px !important;
        width: 100% !important;
        background: #e6f4ff !important;
        border-radius: 10px !important;
        box-shadow: 0 2px 2px #0001 !important;
        padding: 24px 32px 24px 32px !important;
        font-size: 1.13em !important;
        color: #0a2540 !important;
        font-family: sans-serif !important;
        text-align: center !important;
        position: relative !important;
        z-index: 10 !important;
    }
    .custom-footer a {
        color: #0074d9 !important;
        text-decoration: underline !important;
        font-weight: 700 !important;
    }
    .card {
        background: white !important;
        border-radius: 12px !important;
        box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06) !important;
        border: 1px solid #e5e7eb !important;
        margin-bottom: 1.5rem !important;
        transition: all 0.2s ease-in-out !important;
        overflow: hidden !important;
    }
    .card > * {
        padding: 0 !important;
        margin: 0 !important;
    }
    .card:hover {
        box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -2px rgba(0, 0, 0, 0.05) !important;
        transform: translateY(-1px) !important;
    }
    .card-header {
        background: linear-gradient(135deg, #1975cf 0%, #1557b0 100%) !important;
        color: white !important;
        padding: 1rem 1.5rem !important;
        border-radius: 12px 12px 0 0 !important;
        font-weight: 600 !important;
        font-size: 1.1rem !important;
    }
    .card-header,
    div.card-header,
    div.card-header span,
    div.card-header * {
        color: white !important;
    }
    .card-content {
        padding: 1.5rem !important;
        color: #4b5563 !important;
        line-height: 1.6 !important;
        background: white !important;
    }
    .stats-grid {
        display: grid !important;
        grid-template-columns: 1fr 1fr !important;
        gap: 1.5rem !important;
        margin-top: 1.5rem !important;
    }
    .stat-item {
        background: #f8fafc !important;
        padding: 1rem !important;
        border-radius: 8px !important;
        border-left: 4px solid #1975cf !important;
    }
    .stat-label {
        font-weight: 600 !important;
        color: #4b5563 !important;
        font-size: 0.9rem !important;
        margin-bottom: 0.5rem !important;
    }
    .stat-value {
        color: #4b5563 !important;
        font-size: 0.85rem !important;
    }
    .btn-example {
        background: #f1f5f9 !important;
        border: 1px solid #cbd5e1 !important;
        color: #4b5563 !important;
        border-radius: 6px !important;
        transition: all 0.2s ease !important;
        margin: 0.35rem !important;
        padding: 0.5rem 1rem !important;
    }
    .btn-example:hover {
        background: #1975cf !important;
        border-color: #1975cf !important;
        color: white !important;
    }
    .btn-primary {
        background: #1975cf !important;
        border-color: #1975cf !important;
        color: white !important;
    }
    .btn-primary:hover {
        background: #1557b0 !important;
        border-color: #1557b0 !important;
    }
    .markdown {
        color: #374151 !important;
    }
    .results-text {
        color: #4b5563 !important;
        font-weight: 500 !important;
        padding: 0 !important;
        margin: 0 !important;
    }
    .results-text p {
        color: #4b5563 !important;
        margin: 0.5rem 0 !important;
    }
    .results-text * {
        color: #4b5563 !important;
    }
    div[data-testid="markdown"] p {
        color: #4b5563 !important;
    }
    .prose {
        color: #4b5563 !important;
    }
    .prose * {
        color: #4b5563 !important;
    }
    .card-header,
    .card-header * {
        color: white !important;
    }
    /* Override grey colors for the SR110 Vela results section - MUST come after the prose rules */
    .prose .sr110-results,
    .prose .sr110-results *,
    .prose .sr110-results h3,
    .prose .sr110-results div,
    .prose .sr110-results span,
    .sr110-results,
    .sr110-results *,
    .sr110-results h3,
    .sr110-results div,
    .sr110-results span {
        color: inherit !important;
    }
    /* Preserve original colors for the dark-theme cards with higher specificity */
    .prose .sr110-results .sr110-card,
    .sr110-results .sr110-card {
        background: #23233a !important;
    }
    .prose .sr110-results .sr110-title,
    .sr110-results .sr110-title {
        color: #00b0ff !important;
    }
    .prose .sr110-results .sr110-label,
    .sr110-results .sr110-label {
        color: #ccc !important;
    }
    .prose .sr110-results .sr110-value,
    .sr110-results .sr110-value {
        color: #fff !important;
    }
    """
) as demo:
| gr.HTML(""" | |
| <div class="main-header"> | |
| <h1>π€ DSCNN Wake Word Detection</h1> | |
| </div> | |
| """) | |
    with gr.Row():
        with gr.Column(scale=1):
            input_audio = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="Record or Upload Audio",
                value=None
            )
            classify_btn = gr.Button(
                "Detect Wake Word",
                variant="primary",
                size="lg",
                elem_classes=["btn-primary"]
            )
            with gr.Group(elem_classes=["card"]):
                gr.HTML('<div class="card-header"><span style="color: white; font-weight: 600;">Supported Keywords</span></div>')
                with gr.Column(elem_classes=["card-content"]):
                    gr.HTML("""
                    <div style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 0.5rem; text-align: center;">
                        <div style="padding: 0.5rem; background: #f8fafc; border-radius: 6px; font-weight: 500;">yes</div>
                        <div style="padding: 0.5rem; background: #f8fafc; border-radius: 6px; font-weight: 500;">no</div>
                        <div style="padding: 0.5rem; background: #f8fafc; border-radius: 6px; font-weight: 500;">up</div>
                        <div style="padding: 0.5rem; background: #f8fafc; border-radius: 6px; font-weight: 500;">down</div>
                        <div style="padding: 0.5rem; background: #f8fafc; border-radius: 6px; font-weight: 500;">left</div>
                        <div style="padding: 0.5rem; background: #f8fafc; border-radius: 6px; font-weight: 500;">right</div>
                        <div style="padding: 0.5rem; background: #f8fafc; border-radius: 6px; font-weight: 500;">on</div>
                        <div style="padding: 0.5rem; background: #f8fafc; border-radius: 6px; font-weight: 500;">off</div>
                        <div style="padding: 0.5rem; background: #f8fafc; border-radius: 6px; font-weight: 500;">stop</div>
                        <div style="padding: 0.5rem; background: #f8fafc; border-radius: 6px; font-weight: 500;">go</div>
                        <div style="padding: 0.5rem; background: #f8fafc; border-radius: 6px; font-weight: 500;">silence</div>
                        <div style="padding: 0.5rem; background: #f8fafc; border-radius: 6px; font-weight: 500;">unknown</div>
                    </div>
                    """)

            # Model upload section
            with gr.Group(elem_classes=["card"]):
                gr.HTML('<div class="card-header"><span style="color: white; font-weight: 600;">Upload Your TFLite Model</span></div>')
                with gr.Column(elem_classes=["card-content"]):
                    model_file = gr.File(
                        file_types=[".tflite"],
                        label="Upload .tflite model for Vela analysis",
                        interactive=True
                    )
                    compile_btn = gr.Button(
                        "Compile Model for SR110",
                        variant="primary",
                        size="lg",
                        elem_classes=["btn-primary"]
                    )
                    download_model_btn = gr.DownloadButton(
                        label="Download Compiled Model",
                        visible=False,
                        elem_classes=["btn-primary"]
                    )
                    download_logs_btn = gr.DownloadButton(
                        label="Download Per-Layer Logs",
                        visible=False,
                        elem_classes=["btn-primary"]
                    )
        with gr.Column(scale=1):
            # Display the Vela analysis results (dynamic)
            vela_results_html = gr.HTML(vela_html)
            with gr.Group(elem_classes=["card"]):
                gr.HTML('<div class="card-header"><span style="color: white; font-weight: 600;">Wake Word Detection Results</span></div>')
                with gr.Column(elem_classes=["card-content"]):
                    output_text = gr.Markdown(
                        value="Record or upload audio to see wake word predictions...",
                        label="",
                        elem_classes=["results-text"]
                    )
    # Set up the event handlers
    classify_btn.click(
        fn=classify_audio,
        inputs=input_audio,
        outputs=output_text
    )

    # Model compile handler
    compile_btn.click(
        fn=compile_uploaded_model,
        inputs=[model_file],
        outputs=[vela_results_html, download_model_btn, download_logs_btn]
    )

    # Auto-classify when audio is uploaded
    input_audio.change(
        fn=classify_audio,
        inputs=input_audio,
        outputs=output_text
    )
# Launch the demo
if __name__ == "__main__":
    demo.launch()