# Source: ACE-Step-Training / src/lora_trainer.py
# Commit 6c32e21 (pedroapfilho): "Use HF dataset repo as source of truth for dataset.json"
"""
HuggingFace Dataset Download Utility for LoRA Training Studio.
Provides a helper to download audio datasets from HuggingFace Hub.
The actual training pipeline lives in acestep/training/.
"""
import logging
import os
import shutil
from pathlib import Path
from typing import Tuple
logger = logging.getLogger(__name__)
# Audio file extensions recognized when selecting files from the HF dataset repo.
AUDIO_SUFFIXES = {".wav", ".mp3", ".flac", ".ogg", ".opus"}
def download_hf_dataset(
    dataset_id: str,
    max_files: int = 50,
    offset: int = 0,
) -> Tuple[str, str]:
    """
    Download a subset of audio files from a HuggingFace dataset repo.

    Also pulls dataset.json from the repo if it exists (restoring labels
    and preprocessed flags from a previous session).

    Uses the HF_TOKEN env var for authentication.

    Args:
        dataset_id: HF dataset repo id, e.g. "user/my-audio-set".
        max_files: Maximum number of audio files to download.
        offset: Index of the first audio file to download (for paging).

    Returns:
        Tuple of (output_dir, status_message); output_dir is "" on failure.
    """
    try:
        from huggingface_hub import HfApi, hf_hub_download

        api = HfApi()
        token = os.environ.get("HF_TOKEN")
        logger.info("Listing files in '%s'...", dataset_id)
        all_files = [
            f.rfilename
            for f in api.list_repo_tree(
                dataset_id, repo_type="dataset", token=token, recursive=True
            )
            if hasattr(f, "rfilename")
            and Path(f.rfilename).suffix.lower() in AUDIO_SUFFIXES
        ]
        total_available = len(all_files)
        selected = all_files[offset:offset + max_files]
        if not selected:
            # Distinguish "repo has no audio at all" from "offset is past the
            # end" so the caller gets an actionable message.
            if total_available:
                return "", (
                    f"Offset {offset} is past the end: only {total_available} "
                    f"audio files in {dataset_id}"
                )
            return "", f"No audio files found in {dataset_id}"
        logger.info(
            "Downloading %d/%d audio files...", len(selected), total_available
        )
        output_dir = Path("lora_training") / "hf" / dataset_id.replace("/", "_")
        output_dir.mkdir(parents=True, exist_ok=True)
        for i, filename in enumerate(selected):
            # Fix: the progress line previously logged the literal text
            # "(unknown)" instead of the file being fetched.
            logger.info("  [%d/%d] %s", i + 1, len(selected), filename)
            cached_path = hf_hub_download(
                repo_id=dataset_id,
                filename=filename,
                repo_type="dataset",
                token=token,
            )
            # Symlink from cache into our working dir so scan_directory finds
            # them; fall back to copying where symlinks are unavailable
            # (e.g. Windows without developer mode, some network filesystems).
            dest = output_dir / Path(filename).name
            if not dest.exists():
                try:
                    dest.symlink_to(cached_path)
                except OSError:
                    shutil.copy2(cached_path, dest)
        # Pull dataset.json from repo if it exists (restores previous session state)
        try:
            cached_json = hf_hub_download(
                repo_id=dataset_id,
                filename="dataset.json",
                repo_type="dataset",
                token=token,
            )
            dest_json = output_dir / "dataset.json"
            shutil.copy2(cached_json, str(dest_json))
            logger.info("Pulled dataset.json from HF repo")
        except Exception:
            # Best-effort: a missing dataset.json just means a fresh session.
            logger.info("No dataset.json in HF repo (first session)")
        status = (
            f"Downloaded {len(selected)} of {total_available} "
            f"audio files from {dataset_id} (offset {offset})"
        )
        logger.info(status)
        return str(output_dir), status
    except ImportError:
        msg = "huggingface_hub is not installed. Run: pip install huggingface_hub"
        logger.error(msg)
        return "", msg
    except Exception as e:
        msg = f"Failed to download dataset: {e}"
        logger.error(msg)
        return "", msg
def upload_dataset_json_to_hf(dataset_id: str, json_path: str) -> str:
    """Push dataset.json to the HF dataset repo for persistence across sessions."""
    try:
        # Imported lazily so the module loads even without huggingface_hub.
        from huggingface_hub import HfApi

        hf_token = os.environ.get("HF_TOKEN")
        if not hf_token:
            # No credentials: treat as a soft skip rather than an error.
            return "HF_TOKEN not set — skipped HF sync"
        HfApi().upload_file(
            path_or_fileobj=json_path,
            path_in_repo="dataset.json",
            repo_id=dataset_id,
            repo_type="dataset",
            token=hf_token,
        )
    except Exception as exc:
        failure = f"HF sync failed: {exc}"
        logger.error(failure)
        return failure
    else:
        return f"Synced dataset.json to {dataset_id}"