# nb-tts-rubric / app.py
# (Hugging Face Space metadata: author kathiasi, commit 8dfc10c "New try on progress bar")
import gradio as gr
import os
import csv
import fcntl
from datetime import datetime
import uuid
import yaml # You need to install this: pip install pyyaml
import glob
import random
import json
import pandas as pd
import io
# --- Hugging Face Functionality Notes ---
# To save results to a private Hugging Face dataset, you must:
# 1. Install the required libraries: pip install huggingface_hub datasets
# 2. Set the following environment variables before running the script:
# - HF_TOKEN: Your Hugging Face access token with write permissions.
# - HF_DATASET_ID: The ID of the private dataset repo (e.g., "username/my-dataset").
# If these are not set, saving to HF Hub will be skipped.
# --- Start of Local Mode Implementation ---
IS_LOCAL_MODE = os.environ.get("GRADIO_LOCAL_MODE", "false").lower() in ["true", "1"]
if IS_LOCAL_MODE:
print("Running in LOCAL mode. Hugging Face functionalities are disabled.")
HfApi = None
else:
try:
from huggingface_hub import HfApi, hf_hub_download
print("Hugging Face libraries found. HF push functionality is available.")
except ImportError:
print("Hugging Face libraries not found. HF push functionality will be disabled.")
HfApi = None
# --- End of Local Mode Implementation ---
# --- Configuration Loading ---
def load_config(config_path='config.yaml'):
    """Load the UI and criteria configuration from a YAML file.

    Returns the parsed config dict, or None when the file is missing or
    cannot be parsed/validated (a diagnostic is printed in the latter case).
    """
    try:
        with open(config_path, 'r', encoding='utf-8') as fh:
            cfg = yaml.safe_load(fh)
        # A usable config must carry a list under 'criteria'.
        has_criteria = 'criteria' in cfg and isinstance(cfg['criteria'], list)
        if not has_criteria:
            raise ValueError("Config must contain a list of 'criteria'.")
        return cfg
    except FileNotFoundError:
        return None
    except Exception as exc:
        print(f"ERROR: Could not parse {config_path}: {exc}")
        return None
def find_config_files():
    """Return every YAML config file (*.yaml then *.yml) in the working directory."""
    matches = []
    for pattern in ("*.yaml", "*.yml"):
        matches.extend(glob.glob(pattern))
    return matches
# --- Static & File I/O Functions ---
OUTPUT_CSV = "responses.csv"  # local CSV file where every submitted evaluation row is appended
MAX_CRITERIA = 15 # Maximum number of sliders to support
def list_samples(samples_dir):
    """Return the sorted audio file names found directly inside samples_dir.

    Recognised extensions (case-insensitive): .wav, .mp3, .ogg, .flac.
    Prints a warning and returns an empty list when the directory is absent.
    """
    if not os.path.isdir(samples_dir):
        print(f"WARNING: Samples directory '{samples_dir}' not found.")
        return []
    audio_exts = ('.wav', '.mp3', '.ogg', '.flac')
    return sorted(
        name for name in os.listdir(samples_dir)
        if name.lower().endswith(audio_exts)
    )
def save_responses_to_hf(rows, repo_id: str | None = None, token: str | None = None):
    """
    Append new rows to a CSV file in a private Hugging Face dataset.
    - Reads the existing CSV (if present).
    - Appends new rows.
    - Uploads the updated file back to the repo.
    Each 'row' should be a dict with consistent keys.
    NOTE:
    - Replaces the entire CSV on each update (no true append on the server side,
      so concurrent sessions can race and lose rows).
    - Use small/medium datasets; large ones should use the `datasets` library instead.

    Returns a status dict describing the outcome (pushed / skipped / error).
    """
    if HfApi is None:
        return {"status": "hf_unavailable", "reason": "missing_packages"}
    token = token or os.environ.get("HF_TOKEN")
    repo_id = repo_id or os.environ.get("HF_DATASET_ID")
    if not token or not repo_id:
        return {"status": "hf_skipped", "reason": "missing_token_or_repo_env"}
    api = HfApi(token=token)
    path_in_repo = "data/responses.csv"  # fixed CSV location in repo
    repo_err = None
    # Ensure dataset exists (idempotent thanks to exist_ok=True).
    try:
        api.create_repo(repo_id=repo_id, repo_type="dataset", private=True, exist_ok=True)
    except Exception as e:
        repo_err = str(e)
    # Try downloading existing CSV; start fresh when absent or unreadable.
    existing_df = pd.DataFrame()
    try:
        local_path = hf_hub_download(
            repo_id=repo_id,
            filename=path_in_repo,
            repo_type="dataset",
            token=token,
        )
        existing_df = pd.read_csv(local_path)
    except Exception as e:
        print("file", path_in_repo, "couldn't be found / read", str(e))
    # Convert new rows to DataFrame and append. Skip the concat when nothing
    # exists yet: concatenating with an all-empty frame is deprecated in pandas
    # and can warp dtypes.
    new_df = pd.DataFrame(rows)
    if existing_df.empty:
        combined_df = new_df
    else:
        combined_df = pd.concat([existing_df, new_df], ignore_index=True)
    # FIX: was `print(combined_df)` — dumped every collected response (incl.
    # annotator emails) to the server logs on each save. Log a summary instead.
    print(f"Uploading {len(combined_df)} total response rows to {repo_id}")
    # Serialize to CSV in memory; upload_file accepts raw bytes.
    csv_buffer = io.StringIO()
    combined_df.to_csv(csv_buffer, index=False)
    csv_bytes = csv_buffer.getvalue().encode("utf-8")
    # Upload the updated CSV (full replacement of the remote file).
    try:
        api.upload_file(
            path_or_fileobj=csv_bytes,
            path_in_repo=path_in_repo,
            repo_id=repo_id,
            repo_type="dataset",
        )
    except Exception as e:
        print(str(e))
        return {"status": "hf_push_error", "error": str(e), "repo_error": repo_err}
    return {"status": "hf_pushed", "rows_added": len(rows), "repo": repo_id, "repo_error": repo_err}
def _save_responses_to_hf(rows, repo_id: str | None = None, token: str | None = None):
    """
    Push a list of dict rows to a private HF dataset, one JSON file per row.
    (Unused alternative to save_responses_to_hf — note the leading underscore.)
    NOTE: This approach saves each response as an individual file. While this
    prevents data loss from overwriting a single file, be aware of the following:
    - Performance: Uploading many small files can be slower than a single large one.
    - Scalability: A very large number of files (e.g., millions) can make the
      dataset repository unwieldy to browse or clone.
    - Loading Data: To load this data back into a `datasets.Dataset` object, you
      will need to point to the specific files, for example:
      `load_dataset('json', data_files='path/to/your/repo/data/*.json')`

    Returns a status dict describing the outcome (pushed / skipped / error).
    """
    if HfApi is None:
        return {"status": "hf_unavailable", "reason": "missing_packages"}
    token = token or os.environ.get("HF_TOKEN")
    repo_id = repo_id or os.environ.get("HF_DATASET_ID")
    if not token or not repo_id:
        return {"status": "hf_skipped", "reason": "missing_token_or_repo_env"}
    api = HfApi(token=token)
    repo_err = None
    try:
        api.create_repo(repo_id=repo_id, repo_type="dataset", private=True, exist_ok=True)
    except Exception as e:
        repo_err = str(e)
    # Process each row, uploading it as a separate JSON file
    num_pushed = 0
    errors = []
    for row_dict in rows:
        try:
            # Create a unique filename. Using a UUID is the most robust method.
            filename = f"{uuid.uuid4()}.json"
            # Place files in a 'data' subdirectory to keep the repo root clean.
            # FIX: was a hard-coded literal path that ignored `filename`, so every
            # row would target the same repo path and overwrite the previous one.
            path_in_repo = f"data/{filename}"
            # Convert the dictionary to JSON bytes for uploading
            json_bytes = json.dumps(row_dict, indent=2).encode("utf-8")
            api.upload_file(
                # FIX: keyword is `path_or_fileobj`; the previous `path_or_obj`
                # raised TypeError on every call, so no row was ever pushed.
                path_or_fileobj=json_bytes,
                path_in_repo=path_in_repo,
                repo_id=repo_id,
                repo_type="dataset",
            )
            num_pushed += 1
        except Exception as e:
            errors.append(str(e))
    if errors:
        print("json errors", errors, "repo errors", repo_err)
        return {"status": "hf_push_error", "pushed": num_pushed, "total": len(rows), "errors": errors, "repo_error": repo_err}
    return {"status": "hf_pushed", "rows": len(rows), "repo": repo_id, "repo_error": repo_err}
def save_response(sample, audio_path, annotator, session_id, user_email, comment, scores, config):
    """Saves a response row locally (CSV append) and attempts to push to Hugging Face Hub.

    Args:
        sample: audio file name being evaluated.
        audio_path: full path to that file.
        annotator / session_id / user_email: identity of this evaluation session.
        comment: free-text feedback.
        scores: slider values; only the first len(config['criteria']) are kept.
        config: loaded YAML config (supplies criterion labels for the header).

    Returns a status dict: {"status": "saved", "sample": ..., "hf": <push result or None>}.
    """
    os.makedirs(os.path.dirname(OUTPUT_CSV) or '.', exist_ok=True)
    criteria_labels = [c['label'] for c in config['criteria']]
    header = ["timestamp", "sample", "audio_path", "annotator", "session_id", "user_email"] + criteria_labels + ["comment"]
    # Keep only the scores belonging to active criteria; pad so the row always
    # lines up with the header even if fewer scores were passed.
    active_scores = list(scores)[:len(criteria_labels)]
    active_scores += [None] * (len(criteria_labels) - len(active_scores))
    # FIX: datetime.utcnow() is deprecated (naive UTC); use an aware timestamp.
    # The ISO string now carries an explicit "+00:00" offset.
    from datetime import timezone  # module header imports only `datetime` itself
    timestamp = datetime.now(timezone.utc).isoformat()
    row = [timestamp, sample, audio_path, annotator, session_id, user_email] + active_scores + [comment]
    write_header = not os.path.exists(OUTPUT_CSV)
    with open(OUTPUT_CSV, "a", newline='', encoding='utf-8') as f:
        # Advisory lock to guard concurrent writers (multiple annotators on one
        # host); best-effort only — failures to lock are deliberately ignored.
        try:
            fcntl.flock(f.fileno(), fcntl.LOCK_EX)
        except Exception:
            pass
        writer = csv.writer(f)
        if write_header:
            writer.writerow(header)
        writer.writerow(row)
        try:
            fcntl.flock(f.fileno(), fcntl.LOCK_UN)
        except Exception:
            pass
    # --- Hugging Face Push Logic ---
    hf_result = None
    if not IS_LOCAL_MODE:
        try:
            hf_record = dict(zip(header, row))
            hf_result = save_responses_to_hf([hf_record])
        except Exception as e:
            # Never let a Hub failure lose the local save; report it instead.
            print(e)
            hf_result = {"status": "hf_error", "error": str(e)}
    return {"status": "saved", "sample": sample, "hf": hf_result}
# --- Gradio UI Definition ---
def make_ui():
    """Build and return the full Gradio Blocks app (setup screen + evaluation screen)."""

    def make_explainer_fn(criterion_index):
        # Factory: returns a change-handler bound to one slider slot, so each
        # slider shows the per-score explanation for its own criterion.
        def explainer(value, config):
            if not config or criterion_index >= len(config.get('criteria', [])): return ""
            criterion = config['criteria'][criterion_index]
            # Slider values arrive as numbers; explanation keys are assumed to
            # be ints in the YAML — TODO confirm 'explanations' uses int keys.
            try: iv = int(value)
            except (ValueError, TypeError): iv = value
            text = criterion['explanations'].get(iv, "No description for this score.")
            return f"**{criterion['label']} ({iv}/{criterion['max']}):** {text}"
        return explainer
    #with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    with gr.Blocks() as demo:
        # --- STATE MANAGEMENT ---
        # Per-session state: sample list (possibly shuffled), cursor into it,
        # the loaded config dict, and a UUID identifying the session.
        samples_list = gr.State()
        current_index = gr.State(0)
        config_state = gr.State()
        session_id_global = gr.State()
        # --- SETUP UI (Visible at start) ---
        with gr.Group() as setup_group:
            gr.Markdown("# Evaluation Setup")
            gr.Markdown("Please provide your details and select the evaluation setup to begin.")
            #config_dropdown = gr.Dropdown(choices=find_config_files(), label="Select Evaluation", value=find_config_files()[0] if find_config_files() else "")
            config_dropdown = gr.Dropdown(choices=find_config_files(), label="Select Evaluation", value=None) # if find_config_files() else "")
            instructions_md = gr.Markdown(visible=False, elem_classes="instructions")
            with gr.Accordion("Annotator Info", open=True):
                annotator_global = gr.Textbox(label="Annotator ID (automatically generated for you)", lines=1)
                user_email_global = gr.Textbox(label="User email (optional)", lines=1)
            start_button = gr.Button("Start Evaluation", variant="primary")
            config_error_md = gr.Markdown("", visible=False)
        # --- MAIN EVALUATION UI (Initially hidden) ---
        with gr.Group(visible=False) as main_group:
            title_md = gr.Markdown("# Evaluation UI")
            header_md = gr.Markdown("")
            progress_md = gr.Markdown("Sample 1 of X")
            # NOTE(review): gr.Progress is normally injected as an event-handler
            # parameter rather than instantiated in the layout — confirm this
            # actually renders/updates a visible bar.
            progress_bar = gr.Progress(track_tqdm=False)
            with gr.Row():
                with gr.Column(scale=1, variant='panel'):
                    sample_name_md = gr.Markdown("### Audio File")
                    gr.Markdown("---")
                    evaluation_audio = gr.Audio(label="Audio for Evaluation")
                    gr.Markdown("---")
                    submit_btn = gr.Button("Save & Next", variant="primary", interactive=False)
                    status = gr.Textbox(label="Status", interactive=False)
                with gr.Column(scale=2, variant='panel'):
                    gr.Markdown("### Scoring Criteria")
                    slider_explanation_md = gr.Markdown("_Move a slider to see the description for each score._")
                    gr.Markdown("---")
                    # Fixed pool of MAX_CRITERIA sliders; the config decides how
                    # many become visible per session.
                    sliders = [gr.Slider(visible=False, interactive=True) for _ in range(MAX_CRITERIA)]
                    gr.Markdown("---")
                    comment = gr.Textbox(label="Comments (optional)", lines=4, placeholder="Enter any additional feedback here...")
        # --- UI ELEMENT LISTS ---
        # Order matters: save_and_next returns updates positionally in this order.
        main_ui_elements = [
            title_md, header_md, progress_md, sample_name_md, evaluation_audio,
            slider_explanation_md, comment, submit_btn, status, *sliders
        ]
        # --- LOGIC & EVENTS ---
        def load_sample(samples, index, config):
            # Build the component-update dict for sample `index`; past the end,
            # hide the evaluation widgets and show a completion message instead.
            total_samples = len(samples)
            updates = {}
            if index >= total_samples:
                completion_msg = f"**All {total_samples} samples completed! Thank you!**"
                for el in main_ui_elements: updates[el] = gr.update(visible=False)
                updates[progress_md] = gr.update(value=completion_msg, visible=True)
                updates[status] = gr.update(value="Finished.", visible=True)
                return updates
            sample = samples[index]
            samples_dir = config.get('samples_directory', 'sample-audios')
            sample_path = os.path.join(samples_dir, sample)
            sample_exists = os.path.exists(sample_path)
            updates = {
                progress_md: gr.update(value=f"Sample **{index + 1}** of **{total_samples}**", visible=True),
                sample_name_md: gr.update(value=f"### File: `{sample}`", visible=True),
                evaluation_audio: gr.update(value=sample_path if sample_exists else None, visible=sample_exists),
                slider_explanation_md: gr.update(value="_Move a slider to see the description for each score._", visible=True),
                comment: gr.update(value="", visible=True),
                # Submit stays disabled until the audio has been played once.
                submit_btn: gr.update(value="Play audio to enable", interactive=False, visible=True),
                status: gr.update(value="Ready.", visible=True)
            }
            # Configure the first num_criteria sliders from the config; hide the rest.
            num_criteria = len(config['criteria'])
            for i in range(MAX_CRITERIA):
                if i < num_criteria:
                    criterion = config['criteria'][i]
                    updates[sliders[i]] = gr.update(
                        label=criterion['label'], minimum=criterion['min'], maximum=criterion['max'],
                        step=criterion['step'], value=criterion['default'], visible=True
                    )
                else:
                    updates[sliders[i]] = gr.update(visible=False, value=0)
            return updates
        def enable_submit_button():
            # Triggered by evaluation_audio.play: unlock "Save & Next".
            return gr.update(value="Save & Next", interactive=True)
        def update_instructions(config_path):
            # Show the selected config's 'instructions_markdown' (if any) on the setup screen.
            if not config_path: return gr.update(value="", visible=False)
            config = load_config(config_path)
            if config and 'instructions_markdown' in config:
                return gr.update(value=config['instructions_markdown'], visible=True)
            return gr.update(value="", visible=False)
        def start_session(config_path, annotator_input=None):
            # Validate the chosen config, gather samples, then swap setup UI for main UI.
            if not config_path or not os.path.exists(config_path):
                return {config_error_md: gr.update(value="**Error:** Please select a valid configuration file.", visible=True)}
            config = load_config(config_path)
            if config is None:
                return {config_error_md: gr.update(value=f"**Error:** Could not load or parse `{config_path}`. Check console for details.", visible=True)}
            samples_dir = config.get('samples_directory', 'sample-audios')
            should_randomize = config.get('randomize_samples', False)
            s_list = list_samples(samples_dir)
            if not s_list:
                return {config_error_md: gr.update(value=f"**Error:** No audio files found in directory: `{samples_dir}`", visible=True)}
            if should_randomize: random.shuffle(s_list)
            session_id = str(uuid.uuid4())
            index = 0
            updates = {
                setup_group: gr.update(visible=False),
                main_group: gr.update(visible=True),
                config_error_md: gr.update(visible=False),
                title_md: gr.update(value=f"# {config.get('title', 'Evaluation UI')}"),
                header_md: gr.update(value=config.get('header_markdown', '')),
                config_state: config,
                session_id_global: session_id,
                samples_list: s_list,
                current_index: index,
            }
            # Determine annotator ID: use provided value or generate a random one
            if annotator_input and str(annotator_input).strip():
                annotator = str(annotator_input).strip()
            else:
                annotator = f"anon-{uuid.uuid4().hex[:8]}"
            # Update annotator textbox in the setup UI so the user sees their assigned ID
            updates[annotator_global] = gr.update(value=annotator)
            sample_updates = load_sample(s_list, index, config)
            updates.update(sample_updates)
            return updates
        def save_and_next(index, samples, annotator, sid, email, comment, config, *scores):
            # Persist the current sample's scores, then advance to the next sample.
            sample = samples[index]
            samples_dir = config.get('samples_directory', 'sample-audios')
            sample_path = os.path.join(samples_dir, sample)
            save_status = save_response(sample, sample_path, annotator, sid, email, comment, scores, config)
            next_index = index + 1
            total_samples = len(samples)
            # Update progress bar
            progress_value = (next_index) / total_samples if total_samples > 0 else 0
            # NOTE(review): calling the gr.Progress instance directly (outside a
            # progress-tracked handler argument) may have no visible effect — confirm.
            progress_bar(progress_value)
            updates_dict = load_sample(samples, next_index, config)
            # Provide more detailed status, including HF info if available
            status_message = f"Saved {sample} locally."
            if save_status.get('hf'):
                hf_stat = save_status['hf'].get('status', 'hf_unknown')
                status_message += f" HF status: {hf_stat}."
            updates_dict[status] = gr.update(value=status_message)
            # Return values must line up positionally with [current_index, *main_ui_elements].
            ordered_updates = [updates_dict.get(el) for el in main_ui_elements]
            return [next_index] + ordered_updates
        # --- Event Wiring ---
        config_dropdown.change(
            update_instructions, inputs=[config_dropdown], outputs=[instructions_md]
        ).then(None, None, None, js="() => { document.getElementById('component-0').scrollIntoView(); }")
        start_button.click(
            start_session,
            inputs=[config_dropdown, annotator_global],
            outputs=[
                setup_group, main_group, config_error_md, annotator_global, *main_ui_elements,
                config_state, session_id_global, samples_list, current_index
            ]
        )
        submit_btn.click(
            save_and_next,
            inputs=[current_index, samples_list, annotator_global, session_id_global, user_email_global, comment, config_state, *sliders],
            outputs=[current_index, *main_ui_elements]
        )
        # One explanation handler per slider, each bound to its criterion index.
        for i, slider in enumerate(sliders):
            slider.change(make_explainer_fn(i), inputs=[slider, config_state], outputs=[slider_explanation_md])
        # Submitting is only possible after the annotator has played the audio.
        evaluation_audio.play(fn=enable_submit_button, inputs=None, outputs=[submit_btn])
        demo.load(update_instructions, inputs=config_dropdown, outputs=instructions_md)
    return demo
if __name__ == "__main__":
    # Build the Gradio app and serve it on all interfaces at port 7860.
    demo_app = make_ui()
    demo_app.launch(server_name="0.0.0.0", server_port=7860)