Spaces:
Running
Running
import contextlib
import html
import importlib
import itertools
import json
import os
import pickle
import random
import tempfile
import threading
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import quote

import gradio as gr
import pandas as pd
import requests
import yaml
# ---------------- Config & Paths ----------------
ROOT = Path(__file__).parent

# Everything the app writes at runtime (ratings, cache, logs) lives under ./state.
STATE_DIR = ROOT / "state"
STATE_DIR.mkdir(exist_ok=True, parents=True)
LOG_DIR = STATE_DIR

ELO_PATH = STATE_DIR / "elo.pkl"                    # pickled {model_name: rating}
LEADERBOARD_CSV = STATE_DIR / "leaderboard.csv"     # rewritten by leaderboard_df()
VOTES_LOG = LOG_DIR / "votes.jsonl"                 # one JSON object per vote
CACHE_PATH = STATE_DIR / "cache.pkl"                # pickled recommendation cache (see get_recs)
INTERACTIONS_LOG = STATE_DIR / "interactions.jsonl"

# Inputs that ship next to this file.
MODELS_YAML = ROOT / "models.yaml"
TRACKS_CSV = ROOT / "tracks.csv"

# Arena tuning.
TOPK_SHOW = 10        # recommendations rendered per model
K_FACTOR = 16         # Elo step size
START_ELO = 1200.0    # rating for a model with no history
SEED = 343
random.seed(SEED)     # seeds the global `random` module used throughout the app
# ---------------- Model loading ----------------
def load_models():
    """Read models.yaml and return its model entries keyed by name.

    The file must contain a top-level ``models`` list whose entries are
    mappings with at least a ``name`` key.

    Returns:
        Dict mapping each model name to its full config mapping.

    Raises:
        RuntimeError: if the file is missing, empty, not a mapping, has no
            models, has an entry without a ``name``, or repeats a name.
    """
    if not MODELS_YAML.exists():
        raise RuntimeError(f"models.yaml not found at {MODELS_YAML}")
    # safe_load returns None for an empty file; treat that like "no models".
    cfg = yaml.safe_load(MODELS_YAML.read_text(encoding="utf-8")) or {}
    if not isinstance(cfg, dict):
        raise RuntimeError("models.yaml must contain a mapping with a 'models' list")
    models = cfg.get("models") or []
    if not models:
        raise RuntimeError("No models configured in models.yaml")
    for index, entry in enumerate(models):
        if not isinstance(entry, dict) or "name" not in entry:
            raise RuntimeError(f"Model entry #{index} in models.yaml has no 'name'")
    names = [m["name"] for m in models]
    if len(names) != len(set(names)):
        raise RuntimeError("Duplicate model names in models.yaml")
    return {m["name"]: m for m in models}


MODELS = load_models()
# ---------------- Track Validation ----------------
def load_tracks():
    """Load track names and IDs from tracks.csv for validation and Spotify integration.

    Each row with a non-blank ``track_name``, ``primary_artist_name`` and
    ``track_id`` becomes a lowercase key of the form "song by artist".

    Returns:
        ``(track_names, track_id_map)``: a set of those keys and a dict mapping
        each key to its Spotify track ID. Both are empty when the file is
        missing or cannot be parsed (validation is then disabled).
    """
    if not TRACKS_CSV.exists():
        print(f"Warning: {TRACKS_CSV} not found. Track validation disabled.")
        return set(), {}
    try:
        # Read every cell as text and keep blanks as "" instead of NaN. Otherwise
        # a single empty cell raises on .strip() and the except below would
        # throw away the entire catalogue.
        df = pd.read_csv(TRACKS_CSV, dtype=str, keep_default_na=False)
        track_id_map = {}  # formatted lowercase name -> Spotify track ID
        for raw_track, raw_artist, raw_id in zip(
            df["track_name"], df["primary_artist_name"], df["track_id"]
        ):
            track_name = raw_track.strip()
            artist_name = raw_artist.strip()
            track_id = raw_id.strip()
            if track_name and artist_name and track_id:
                key = f"{track_name} by {artist_name}".lower()
                track_id_map[key] = track_id
        track_names_set = set(track_id_map)
        print(f"Loaded {len(track_names_set)} track names for validation")
        print(f"Sample track IDs: {list(track_id_map.items())[:3]}")
        return track_names_set, track_id_map
    except Exception as e:  # best-effort: the app still runs without a catalogue
        print(f"Error loading tracks.csv: {e}. Track validation disabled.")
        return set(), {}
def validate_track_name(track_name: str, valid_tracks: set) -> Tuple[bool, str]:
    """Check whether ``track_name`` refers to a track in ``valid_tracks``.

    Matching is case-insensitive and ignores surrounding whitespace. An exact
    match wins; otherwise a track is accepted when either string contains the
    other. An empty ``valid_tracks`` means the catalogue could not be loaded
    (load_tracks reports "Track validation disabled"), so every non-empty
    name is accepted.

    Args:
        track_name: The track name to validate, e.g. "Hey Jude by The Beatles".
        valid_tracks: Set of valid lowercase "song by artist" names.

    Returns:
        Tuple of (is_valid, message).
    """
    if not track_name or not track_name.strip():
        return False, "Empty track name"
    if not valid_tracks:
        # Honour "validation disabled" instead of rejecting every song.
        return True, "Track validation disabled"
    track_lower = track_name.lower().strip()
    if track_lower in valid_tracks:
        return True, "Track found"
    # Fuzzy fallback. Report the alphabetically first match so the message is
    # stable across runs (set iteration order is not).
    matching_tracks = [t for t in valid_tracks if track_lower in t or t in track_lower]
    if matching_tracks:
        return True, f"Similar track found: {min(matching_tracks)}"
    return False, "Track not found in database"
# Catalogue loaded once at import: the set of lowercase "song by artist" names
# and the mapping from those names to Spotify track IDs.
(VALID_TRACKS, TRACK_ID_MAP) = load_tracks()
def get_spotify_track_id(track_name: str) -> Optional[str]:
    """Look up the Spotify track ID for ``track_name`` in TRACK_ID_MAP.

    Args:
        track_name: Track name in format "Song by Artist" (case-insensitive).

    Returns:
        The ID of the exact match when there is one; otherwise the ID of the
        first catalogue entry (in TRACK_ID_MAP order) that contains, or is
        contained in, the name. ``None`` when nothing matches or the name is
        empty/blank.
    """
    if not track_name:
        return None
    track_lower = track_name.lower().strip()
    if not track_lower:
        # "" is a substring of every key, so a blank name would otherwise
        # "match" the first track in the catalogue.
        return None
    if track_lower in TRACK_ID_MAP:
        return TRACK_ID_MAP[track_lower]
    return next(
        (
            track_id
            for stored_track, track_id in TRACK_ID_MAP.items()
            if track_lower in stored_track or stored_track in track_lower
        ),
        None,
    )
def create_spotify_url(track_id: str) -> str:
    """Return the open.spotify.com page URL for ``track_id``.

    Args:
        track_id: Spotify track ID.

    Returns:
        The track's public Spotify URL.
    """
    base_url = "https://open.spotify.com/track"
    return f"{base_url}/{track_id}"
def create_spotify_player_html(track_id: str, width: str = "100%", height: str = "152") -> str:
    """Create HTML for a Spotify web-player embed of one track.

    ``track_id`` may come from a remote model's output. It is percent-encoded
    before being placed in the URL, and every attribute value is HTML-escaped,
    so a malformed ID cannot inject markup. A well-formed base62 Spotify ID
    passes through unchanged.

    Args:
        track_id: Spotify track ID; a falsy value yields a placeholder.
        width: Player width attribute (default: "100%").
        height: Player height attribute (default: "152").

    Returns:
        HTML string for the Spotify player, or "<p>No preview available</p>".
    """
    if not track_id:
        return "<p>No preview available</p>"
    url = f"https://open.spotify.com/embed/track/{quote(str(track_id), safe='')}?utm_source=generator"
    src = html.escape(url, quote=True)
    safe_width = html.escape(str(width), quote=True)
    safe_height = html.escape(str(height), quote=True)
    return f'''
<iframe style="border-radius:12px"
    src="{src}"
    width="{safe_width}"
    height="{safe_height}"
    frameBorder="0"
    allowfullscreen=""
    allow="autoplay; clipboard-write; encrypted-media; fullscreen; picture-in-picture"
    loading="lazy">
</iframe>
'''
def get_spotify_player(track_name: str) -> str:
    """Get the Spotify web player HTML for a typed track name.

    Args:
        track_name: Track name to look up, e.g. "Hey Jude by The Beatles".

    Returns:
        A "Now Playing" heading plus the player iframe when the track's
        Spotify ID is known; otherwise an HTML message listing a few
        catalogue tracks.
    """
    if not track_name or not track_name.strip():
        return "<p>Please enter a track name first</p>"
    track_id = get_spotify_track_id(track_name)
    # The name is user input echoed back into HTML, so escape it.
    safe_name = html.escape(track_name)
    if track_id:
        player_html = create_spotify_player_html(track_id)
        return f"<h3>🎵 Now Playing: {safe_name}</h3>{player_html}"
    # islice avoids copying the whole catalogue just to show three examples.
    examples = ", ".join(html.escape(t) for t in itertools.islice(VALID_TRACKS, 3))
    return (
        f"<p>❌ No preview available for: {safe_name}</p>"
        "<p>Make sure the track exists in our database.</p>"
        f"<p>Available tracks: {examples}</p>"
    )
def check_track_in_database(track_name: str) -> str:
    """Validate ``track_name`` against the loaded catalogue (VALID_TRACKS).

    Intended for direct calls; only the human-readable message from
    validate_track_name is returned, not the boolean verdict.

    Args:
        track_name: The track name to check.

    Returns:
        The validation message.
    """
    return validate_track_name(track_name, VALID_TRACKS)[1]
def find_matching_tracks(query: str, max_results: int = 5) -> List[str]:
    """Find catalogue tracks matching ``query`` for the suggestion dropdown.

    Candidates are considered in three passes: an exact match, then tracks
    containing the query, then tracks contained in the query. Each track
    appears at most once.

    Args:
        query: The search query (case-insensitive, surrounding whitespace ignored).
        max_results: Maximum number of results to return; <= 0 returns [].

    Returns:
        Matching "song by artist" names in title case.
    """
    if not query or not query.strip() or max_results <= 0:
        return []
    query_lower = query.lower().strip()

    def _candidates():
        # Lazily yields lowercase catalogue keys in priority order.
        if query_lower in VALID_TRACKS:  # exact match: O(1) set lookup
            yield query_lower
        for track in VALID_TRACKS:
            if query_lower in track:
                yield track
        for track in VALID_TRACKS:
            if track in query_lower:
                yield track

    # Dedupe on the lowercase key: results are title-cased, so comparing a key
    # against the results list would never detect a duplicate.
    seen = set()
    matches = []
    for track in _candidates():
        if track in seen:
            continue
        seen.add(track)
        matches.append(track.title())
        if len(matches) >= max_results:
            break
    return matches
def get_random_track() -> str:
    """Pick a random track from the loaded catalogue.

    Returns:
        The track's "song by artist" name in title case, or
        "No tracks available" when the catalogue is empty.
    """
    if VALID_TRACKS:
        pool = list(VALID_TRACKS)
        return random.choice(pool).title()
    return "No tracks available"
# ---------------- Cache ----------------
def load_cache() -> Dict[Tuple[str, str], List[str]]:
    """Load the pickled recommendation cache from CACHE_PATH.

    The cache is best-effort. A missing, empty, truncated or otherwise
    unreadable file yields ``{}`` with a warning instead of stopping startup.
    NOTE: pickle trusts its input; CACHE_PATH must only ever be written by this app.

    Returns:
        The cached mapping, or ``{}``.
    """
    if not CACHE_PATH.exists():
        return {}
    try:
        with CACHE_PATH.open("rb") as f:
            cache = pickle.load(f)
    except (OSError, EOFError, pickle.UnpicklingError, ValueError) as e:
        print(f"Warning: could not read {CACHE_PATH}: {e}. Starting with an empty cache.")
        return {}
    return cache if isinstance(cache, dict) else {}


def save_cache(cache: Dict[Tuple[str, str], List[str]]):
    """Persist ``cache`` to CACHE_PATH atomically.

    The pickle is written to a uniquely named temporary file in the same
    directory and then moved over the old file with os.replace. A crash
    mid-write therefore leaves the previous cache intact rather than a
    truncated one.
    """
    fd, tmp_name = tempfile.mkstemp(
        dir=CACHE_PATH.parent, prefix=CACHE_PATH.name + ".", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(cache, f)
        os.replace(tmp_name, CACHE_PATH)
    except BaseException:
        with contextlib.suppress(OSError):
            os.unlink(tmp_name)
        raise
# Process-wide recommendation cache, restored from disk at startup.
CACHE: Dict = load_cache()
# ---------------- Elo ----------------
def expected_score(ra: float, rb: float) -> float:
    """Return the expected score of a player rated ``ra`` against one rated ``rb``.

    Logistic Elo curve on a 400-point scale: equal ratings give 0.5.
    """
    rating_gap = (rb - ra) / 400.0
    return 1.0 / (1.0 + 10 ** rating_gap)
def update_elo(elo: Dict[str, float], a: str, b: str, outcome: str) -> None:
    """Apply one Elo update, in place, for a match between models ``a`` and ``b``.

    Args:
        elo: Ratings table, mutated in place; unseen models start at START_ELO.
        a: Model on side A.
        b: Model on side B.
        outcome: "A" if A won, "B" if B won; any other value is scored as a tie.
    """
    rating_a = elo.get(a, START_ELO)
    rating_b = elo.get(b, START_ELO)
    expected_a = expected_score(rating_a, rating_b)
    expected_b = 1.0 - expected_a
    score_a, score_b = {"A": (1.0, 0.0), "B": (0.0, 1.0)}.get(outcome, (0.5, 0.5))
    elo[a] = rating_a + K_FACTOR * (score_a - expected_a)
    elo[b] = rating_b + K_FACTOR * (score_b - expected_b)
def load_elo() -> Dict[str, float]:
    """Load the Elo table from ELO_PATH, adding START_ELO for any new model.

    Unlike the recommendation cache, a corrupt Elo file is not silently
    reset: pickle errors propagate so the leaderboard is never lost unnoticed.

    Returns:
        Mapping of model name to rating, covering at least every configured model.
    """
    if ELO_PATH.exists():
        with ELO_PATH.open("rb") as f:
            elo = pickle.load(f)
    else:
        elo = {}
    # ensure every configured model has an Elo
    for m in MODELS.keys():
        elo.setdefault(m, START_ELO)
    return elo


def save_elo(elo: Dict[str, float]):
    """Persist ``elo`` to ELO_PATH atomically.

    The pickle is written to a uniquely named temp file in the same
    directory, then moved into place with os.replace. A crash mid-write
    keeps the previous ratings instead of leaving a truncated file that
    load_elo could not read.
    """
    fd, tmp_name = tempfile.mkstemp(
        dir=ELO_PATH.parent, prefix=ELO_PATH.name + ".", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(elo, f)
        os.replace(tmp_name, ELO_PATH)
    except BaseException:
        with contextlib.suppress(OSError):
            os.unlink(tmp_name)
        raise
def leaderboard_df(elo: Dict[str, float]) -> pd.DataFrame:
    """Build the leaderboard (highest Elo first) and mirror it to LEADERBOARD_CSV.

    Args:
        elo: Mapping of model name to rating.

    Returns:
        DataFrame with columns ``model`` and ``elo``, sorted by ``elo`` descending.
    """
    table = pd.DataFrame(list(elo.items()), columns=["model", "elo"])
    ranked = table.sort_values("elo", ascending=False)
    ranked.to_csv(LEADERBOARD_CSV, index=False)
    return ranked
# ---------------- Backends ----------------
def call_http(model_cfg: dict, song_ratings: List[Dict[str, Any]]) -> List[Any]:
    """POST the user's song ratings to an HTTP model and return its recommendations.

    The endpoint receives ``{"song_ratings": [...]}`` as JSON. It may answer
    with ``{"items": [...]}``, ``{"recommendations": [...]}`` or a bare JSON
    list. Each element may be a plain track name or a two-element
    ``[spotify_id, track_name]`` pair. Pairs are returned as ``(str, str)``
    tuples so get_recs can keep the Spotify ID; anything else becomes a str.

    Args:
        model_cfg: Model config with ``endpoint`` and optional ``timeout`` (seconds, default 8).
        song_ratings: The rated songs to send.

    Returns:
        List of track-name strings and/or (spotify_id, track_name) tuples.

    Raises:
        requests.RequestException: on connection errors, timeouts or non-2xx status.
        ValueError: if the response body is not valid JSON (raised by ``r.json()``).
    """
    endpoint = model_cfg["endpoint"]
    timeout = float(model_cfg.get("timeout", 8))
    r = requests.post(
        endpoint,
        json={"song_ratings": song_ratings},
        timeout=timeout,
        headers={"Content-Type": "application/json"},
    )
    r.raise_for_status()
    obj = r.json()
    if isinstance(obj, dict):
        items = obj.get("items") or obj.get("recommendations") or []
    elif isinstance(obj, list):
        items = obj
    else:
        items = []
    result = []
    for x in items:
        if isinstance(x, (list, tuple)) and len(x) == 2:
            # Keep the pair structure; str(x) would turn it into "['id', 'name']".
            result.append((str(x[0]), str(x[1])))
        else:
            result.append(str(x))
    return result
def call_python(model_cfg: dict, song_ratings: List[Dict[str, Any]]) -> List[Tuple[str, str]]:
    """Run an in-process Python model, giving up after its configured timeout.

    ``model_cfg["callable"]`` is a dotted path such as
    "team_alpha.src.recommender.query". The named function is imported and
    called with ``song_ratings`` in a daemon worker thread.

    Note: Python threads cannot be killed, so after a timeout the worker keeps
    running in the background until the model function returns.

    Args:
        model_cfg: Model config with ``callable`` and optional ``timeout``
            (seconds, default 8) and ``name`` (used in the timeout message).
        song_ratings: Argument passed to the model function.

    Returns:
        Whatever the model function returns; expected to be a list of
        (spotify_id, track_name) tuples or of track-name strings.

    Raises:
        TimeoutError: if the call does not finish within the timeout.
        Exception: any exception raised by the model function is re-raised here.
    """
    dotted = model_cfg["callable"]  # e.g., "team_alpha.src.recommender.query"
    timeout = float(model_cfg.get("timeout", 8))
    mod_name, fn_name = dotted.rsplit(".", 1)
    query_fn = getattr(importlib.import_module(mod_name), fn_name)

    outcome = {}  # filled by the worker: "result" or "error"

    def run_model():
        try:
            outcome["result"] = query_fn(song_ratings)
        except Exception as e:  # handed back to the calling thread below
            outcome["error"] = e

    worker = threading.Thread(target=run_model, daemon=True)
    worker.start()
    worker.join(timeout=timeout)
    if worker.is_alive():
        raise TimeoutError(f"Python model '{model_cfg.get('name', 'unknown')}' timed out after {timeout} seconds")
    # Presence check rather than truthiness: any caught exception is re-raised.
    if "error" in outcome:
        raise outcome["error"]
    return outcome.get("result")
def get_recs(model_name: str, song_ratings: List[Dict[str, Any]]) -> List[Tuple[str, str]]:
    """Get recommendations from a configured model, normalised to pairs.

    Args:
        model_name: Name of the model to use (key into MODELS).
        song_ratings: List of song-rating dicts as built by the UI
            (``song``, ``rating``, ``spotify_id``).

    Returns:
        List of (spotify_id, track_name) tuples. Models may return such pairs
        or bare track names; bare names get an empty spotify_id, a missing
        ID becomes "", entries with a blank track name are dropped, and a
        ``None`` result means no recommendations.

    Raises:
        KeyError: if ``model_name`` is not configured.
        ValueError: for an unknown model ``type``.
        Anything raised by the backend call (HTTP errors, TimeoutError, ...).
    """
    cfg = MODELS[model_name]
    backend = cfg["type"]
    if backend == "http":
        items = call_http(cfg, song_ratings)
    elif backend == "python":
        items = call_python(cfg, song_ratings)
    else:
        raise ValueError(f"Unknown model type: {backend}")

    # Normalise item by item, so a list mixing both formats is handled too.
    result = []
    for item in items or []:
        if isinstance(item, (list, tuple)) and len(item) == 2:
            raw_id, raw_name = item
        else:
            raw_id, raw_name = "", item
        track_name = str(raw_name).strip()
        if not track_name:
            continue
        spotify_id = "" if raw_id is None else str(raw_id).strip()
        result.append((spotify_id, track_name))

    # Key by (model, query) so different queries don't overwrite each other.
    # default=str only keeps key construction from failing on odd values.
    key = (model_name, json.dumps(song_ratings, sort_keys=True, ensure_ascii=False, default=str))
    is_new = key not in CACHE
    CACHE[key] = result
    # persist sparingly: once per 20 new entries
    if is_new and len(CACHE) % 20 == 0:
        save_cache(CACHE)
    return result
# ---------------- Logging ----------------
def log_vote(payload: dict):
    """Append ``payload`` to VOTES_LOG as a single JSON line.

    The file is opened in append mode with UTF-8 encoding, and non-ASCII
    characters are written as-is (``ensure_ascii=False``).
    """
    with open(VOTES_LOG, "a", encoding="utf-8") as log_file:
        record = json.dumps(payload, ensure_ascii=False)
        log_file.write(record + "\n")
# ---------------- UI helpers ----------------
# Stylesheet passed to gr.Blocks: recommendation cards (.card, .meta, .items)
# and bold vote buttons (#vote-row). One rule per line; leading and trailing
# newlines are kept.
CSS = "\n".join([
    "",
    ".card { border: 1px solid #e5e7eb; border-radius: 14px; padding: 12px; text-align: left; }",
    ".card h3 { margin: 0 0 8px 0; font-size: 16px; }",
    ".card .meta { color: #6b7280; font-size: 13px; margin-bottom: 8px; }",
    ".items { font-family: ui-monospace, SFMono-Regular, Menlo, monospace; font-size: 14px; }",
    ".items li { margin: 2px 0; }",
    "#vote-row button { font-weight: 700; }",
    "",
])
def render_list(title: str, song: str, items: List[Tuple[str, str]], k: int = TOPK_SHOW) -> str:
    """Render one model's recommendations as an HTML card with Spotify previews.

    Args:
        title: Title for the recommendation list (the model name).
        song: Context text shown after "Song:" (e.g. "3 songs").
        items: List of (spotify_id, track_name) tuples; an empty spotify_id
            means no preview player.
        k: Number of items to show.

    Returns:
        HTML string for the card. ``title``, ``song`` and every track name
        are HTML-escaped, because track names come from remote model output.
    """
    safe_title = html.escape(str(title))
    safe_song = html.escape(str(song))
    if not items:
        return (
            f'<div class="card"><h3>{safe_title}</h3><div class="meta">Song: <b>{safe_song}</b></div>'
            '<em>No items returned.</em></div>'
        )
    top = items[:k]
    list_items = []
    previews = []
    for rank, (spotify_id, track_name) in enumerate(top, start=1):
        safe_name = html.escape(str(track_name))
        list_items.append(f"<li>{safe_name}</li>")
        if spotify_id:
            # Percent-encode the (possibly model-supplied) ID before it reaches
            # the iframe URL; well-formed base62 Spotify IDs are unchanged.
            preview = create_spotify_player_html(quote(str(spotify_id), safe=""), width="100%", height="80")
        else:
            preview = '<p style="color: #6b7280; font-size: 12px;">No preview available</p>'
        previews.append(
            '<div style="margin: 10px 0; padding: 10px; border: 1px solid #e5e7eb; border-radius: 8px;">'
            f'<h4 style="margin: 0 0 5px 0; font-size: 14px;">{rank}. {safe_name}</h4>'
            f"{preview}"
            "</div>"
        )
    list_html = "".join(list_items)
    previews_html = "".join(previews)
    return f"""
<div class="card">
<h3>{safe_title}</h3>
<div class="meta">Song: <b>{safe_song}</b> · Showing top {len(top)}</div>
<ol class="items">{list_html}</ol>
<div style="margin-top: 15px;">
<h4>🎵 Preview Tracks:</h4>
{previews_html}
</div>
</div>
"""
# ---------------- Gradio App ----------------
with gr.Blocks(title="Recommender Arena (Song Ratings → A/B Vote)", css=CSS) as demo:
    gr.Markdown("# 🎶 Tune Duel")
    gr.Markdown("Rate **your favourite songs** (1-5 stars). Pick two models (or random). Compare the recommendations and vote.")
    gr.Markdown("💡 **Tips**: Start typing a song name to see matching tracks, click 🎲 Random to get a random track, or click ▶️ Play to start the Spotify player!")

    # Spotify player display at the top
    with gr.Row():
        spotify_player_display = gr.HTML(
            label="🎵 Now Playing",
            value="<p>Enter a track name and click ▶️ to start playing!</p>",
        )

    # Test button to show sample tracks
    with gr.Row():
        test_btn = gr.Button("🔍 Show Sample Tracks", variant="secondary")

    def show_sample_tracks():
        """Return HTML listing up to five catalogue tracks (escaped)."""
        sample_tracks = itertools.islice(VALID_TRACKS, 5)
        items_html = "".join(f"<li>{html.escape(track)}</li>" for track in sample_tracks)
        return f"<h3>Sample tracks in database:</h3><ul>{items_html}</ul>"

    test_btn.click(show_sample_tracks, outputs=[spotify_player_display])

    model_names = sorted(MODELS.keys())

    # ---- Song inputs ----
    def song_row(number: int, placeholder: str):
        """Build one "song + rating + suggestions + buttons" row.

        Returns:
            (textbox, rating_slider, suggestions_dropdown, random_button, play_button)
        """
        with gr.Row():
            with gr.Column(scale=14):  # main content area
                with gr.Row():
                    song = gr.Textbox(label=f"Song {number}", placeholder=placeholder, lines=1, scale=8)
                    rating = gr.Slider(minimum=1, maximum=5, step=1, value=5, label=f"Rating {number}", scale=2)
                    suggestions = gr.Dropdown(label="Suggestions", choices=[], interactive=True, visible=False, scale=3)
            with gr.Column(scale=1, elem_classes="button-col"):
                random_btn = gr.Button("🎲", variant="secondary", elem_classes="small-btn")
                play_btn = gr.Button("▶️", variant="primary", elem_classes="small-btn")
        return song, rating, suggestions, random_btn, play_btn

    song_rows = [
        song_row(1, "e.g., '22 by Taylor Swift'"),
        song_row(2, "e.g., 'Paranoid Android by Radiohead'"),
        song_row(3, "e.g., 'Hey Jude by The Beatles'"),
    ]
    # Additional songs container (initially hidden)
    additional_songs_container = gr.Column(visible=False)
    with additional_songs_container:
        song_rows.append(song_row(4, "e.g., 'Bohemian Rhapsody by Queen'"))
        song_rows.append(song_row(5, "e.g., 'Stairway to Heaven by Led Zeppelin'"))

    song_boxes = [row[0] for row in song_rows]
    # refresh_lists expects song1, rating1, ..., song5, rating5 in this order.
    song_rating_inputs = [component for row in song_rows for component in row[:2]]

    add_song_btn = gr.Button("Add More Songs (4-5)", variant="secondary")
    fill_all_random_btn = gr.Button("🎲 Fill All Random", variant="primary")

    def toggle_additional_songs():
        return gr.Column(visible=True)

    def fill_all_random():
        """Fill all song fields with random tracks"""
        return [get_random_track() for _ in song_boxes]

    add_song_btn.click(toggle_additional_songs, outputs=[additional_songs_container])
    fill_all_random_btn.click(fill_all_random, outputs=song_boxes)

    # Real-time track suggestions functions
    def update_suggestions(query: str, suggestions_dropdown):
        """Update suggestions dropdown based on query"""
        if not query or len(query.strip()) < 2:
            return gr.Dropdown(choices=[], visible=False)
        matches = find_matching_tracks(query, max_results=8)
        if matches:
            return gr.Dropdown(choices=matches, visible=True)
        return gr.Dropdown(choices=[], visible=False)

    def select_suggestion(suggestion: str, textbox):
        """When user selects a suggestion, update the textbox"""
        return suggestion if suggestion else textbox

    # Wire suggestions, 🎲 and ▶️ for every song row. Components are passed as
    # arguments, so the loop has no late-binding closure issue.
    for song_box, _rating, suggestions_box, random_btn, play_btn in song_rows:
        song_box.change(update_suggestions, inputs=[song_box, suggestions_box], outputs=[suggestions_box])
        suggestions_box.change(select_suggestion, inputs=[suggestions_box, song_box], outputs=[song_box])
        random_btn.click(get_random_track, outputs=[song_box])
        play_btn.click(get_spotify_player, inputs=[song_box], outputs=[spotify_player_display])

    # ---- Model selection, results and voting ----
    if len(model_names) >= 2:
        initial_a, initial_b = random.sample(model_names, 2)  # start with two distinct models
    else:
        initial_a = initial_b = model_names[0]
    with gr.Row():
        model_a = gr.Dropdown(choices=model_names, value=initial_a, label="Model A")
        model_b = gr.Dropdown(choices=model_names, value=initial_b, label="Model B")
        rand_pair_btn = gr.Button("Random Pair")
        recommend_btn = gr.Button("Recommend")
    with gr.Row():
        list_a = gr.HTML()
        list_b = gr.HTML()
    with gr.Row(elem_id="vote-row"):
        btn_a = gr.Button("A Wins", variant="primary")
        btn_tie = gr.Button("Tie", variant="secondary")
        btn_b = gr.Button("B Wins", variant="primary")
        btn_skip = gr.Button("Skip", variant="secondary")
    leaderboard = gr.Dataframe(headers=["model", "elo"], interactive=False, label="Live Leaderboard (Elo)")

    # gr.State gives every browser session its own deep copy of its initial
    # value, so per-session Elo dicts would overwrite each other's votes on
    # save. The authoritative ratings therefore live in this process-wide dict.
    # It is guarded by a lock because the queue runs handlers concurrently.
    elo_lock = threading.Lock()
    shared_elo = load_elo()

    def current_leaderboard():
        """Snapshot the shared ratings and return (and persist) the leaderboard."""
        with elo_lock:
            return leaderboard_df(dict(shared_elo))

    # states
    elo_state = gr.State(dict(shared_elo))
    last_payload = gr.State({})  # remember last (song_ratings, A, B) for voting

    def random_pair(cur_a, cur_b):
        """Pick two distinct random models (keeps the current pair if impossible)."""
        if len(model_names) < 2:
            gr.Warning("Need at least two models.")
            return cur_a, cur_b
        a, b = random.sample(model_names, 2)
        return a, b

    rand_pair_btn.click(random_pair, inputs=[model_a, model_b], outputs=[model_a, model_b])

    def render_empty(title: str, msg: str) -> str:
        return f"""
<div class="card">
<h3>{title}</h3>
<div class="meta"></div>
<em>{msg}</em>
</div>
"""

    def refresh_lists(song1, rating1, song2, rating2, song3, rating3, song4, rating4, song5, rating5,
                      a: str, b: str, elo: dict, prev_payload: Optional[dict] = None):
        """Validate the entered songs and fetch both models' recommendations.

        ``elo`` is accepted for wiring compatibility; the leaderboard is read
        from the shared ratings.

        Returns:
            (html_a, html_b, payload, leaderboard_df). ``payload`` records the
            compared (song_ratings, A, B) for vote(). On any failure the
            previous payload is returned and lists already on screen are left
            untouched.
        """
        prev_payload = prev_payload or {}
        songs_and_ratings = [(song1, rating1), (song2, rating2), (song3, rating3), (song4, rating4), (song5, rating5)]
        song_ratings = []
        validation_messages = []
        for song, rating in songs_and_ratings:
            if not song or not song.strip():
                continue
            is_valid, message = validate_track_name(song, VALID_TRACKS)
            if not is_valid:
                validation_messages.append(f"'{html.escape(song)}': {message}")
                continue
            song_ratings.append({
                "song": song.strip(),
                "rating": int(rating),
                "spotify_id": get_spotify_track_id(song.strip()) or "",
            })
        if validation_messages:
            # Previously collected but never shown: tell the user what was dropped.
            gr.Warning("Ignored songs not found in the database: " + "; ".join(validation_messages))

        df = current_leaderboard()

        def keep_previous(message: str):
            # gr.update() leaves the already-rendered lists as they are, instead
            # of re-querying (possibly failing) models just to redraw them.
            if prev_payload:
                return gr.update(), gr.update(), prev_payload, df
            return render_empty("Model A", message), render_empty("Model B", message), prev_payload, df

        if not song_ratings:
            gr.Warning("Please enter at least one song with a rating.")
            return keep_previous("Enter songs with ratings and click Recommend.")
        if a == b:
            gr.Warning("Pick two different models.")
            return keep_previous("Pick two different models.")
        try:
            items_a = get_recs(a, song_ratings)
            items_b = get_recs(b, song_ratings)
        except Exception as e:  # backend failures are reported, not raised
            gr.Warning(f"Failed to get recommendations: {e}")
            return keep_previous("Error fetching.")

        context = f"{len(song_ratings)} songs"
        payload = {"song_ratings": song_ratings, "A": a, "B": b}
        return render_list(a, context, items_a), render_list(b, context, items_b), payload, df

    # Fetch lists on Recommend and whenever a model changes. All three use the
    # same 14 inputs: refresh_lists needs last_payload as its last argument.
    refresh_inputs = [*song_rating_inputs, model_a, model_b, elo_state, last_payload]
    refresh_outputs = [list_a, list_b, last_payload, leaderboard]
    recommend_btn.click(refresh_lists, inputs=refresh_inputs, outputs=refresh_outputs)
    model_a.change(refresh_lists, inputs=refresh_inputs, outputs=refresh_outputs)
    model_b.change(refresh_lists, inputs=refresh_inputs, outputs=refresh_outputs)

    def vote(action: str, elo: dict, payload: dict, request: gr.Request):
        """Record a vote on the current comparison and update the shared Elo table.

        Returns:
            (elo_snapshot, leaderboard_df, {}). The empty dict clears
            last_payload, so the same comparison cannot be voted on twice.
        """
        if not payload:
            raise gr.Error("Load recommendations first (enter songs with ratings).")
        song_ratings = payload["song_ratings"]
        a = payload["A"]
        b = payload["B"]
        outcome = "Tie" if action == "tie" else ("A" if action == "a" else "B")
        with elo_lock:
            update_elo(shared_elo, a, b, outcome)
            save_elo(shared_elo)
            snapshot = dict(shared_elo)
            df = leaderboard_df(snapshot)
        client = getattr(request, "client", None) if request else None
        log_vote({
            # Naive UTC timestamp, same format utcnow().isoformat() produced.
            "ts": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
            "client_ip": client.host if client else None,
            "song_ratings": song_ratings, "model_a": a, "model_b": b, "outcome": outcome,
        })
        gr.Info("Vote recorded. Click Recommend for a new comparison.")
        return snapshot, df, {}

    vote_outputs = [elo_state, leaderboard, last_payload]
    btn_a.click(vote, inputs=[gr.State("a"), elo_state, last_payload], outputs=vote_outputs)
    btn_b.click(vote, inputs=[gr.State("b"), elo_state, last_payload], outputs=vote_outputs)
    btn_tie.click(vote, inputs=[gr.State("tie"), elo_state, last_payload], outputs=vote_outputs)
    btn_skip.click(lambda elo: (elo, current_leaderboard()), inputs=[elo_state], outputs=[elo_state, leaderboard])
if __name__ == "__main__":
    # Allow up to 20 concurrent handler runs, then serve on all interfaces.
    app = demo.queue(default_concurrency_limit=20)
    app.launch(server_name="0.0.0.0", server_port=7860, share=True)