OnyxlMunkey Cursor committed
Commit 3cdb1cf · 1 Parent(s): 564c9c9

Add Song Describer pipeline: prepare_song_describer.py, prepare_from_hf, preprocess/train CLI

acestep/training_v2/cli/args.py CHANGED
@@ -108,6 +108,75 @@ def build_root_parser() -> argparse.ArgumentParser:
         help="Random seed (default: 42)",
     )
 
+    # -- from-hf ------------------------------------------------------------
+    p_from_hf = subparsers.add_parser(
+        "from-hf",
+        help="Prepare ACE-Step dataset from a Hugging Face dataset (writes dataset.json + audio)",
+        formatter_class=formatter_class,
+    )
+    p_from_hf.add_argument(
+        "--dataset",
+        type=str,
+        required=True,
+        metavar="NAME",
+        help="Hugging Face dataset id (e.g. ashrafemam/crema-d or polyai/minds14)",
+    )
+    p_from_hf.add_argument(
+        "--output-dir",
+        type=str,
+        required=True,
+        metavar="DIR",
+        help="Output directory for dataset.json and audio/",
+    )
+    p_from_hf.add_argument(
+        "--split",
+        type=str,
+        default="train",
+        help="Dataset split (default: train)",
+    )
+    p_from_hf.add_argument(
+        "--config",
+        type=str,
+        default=None,
+        help="Dataset config name if required",
+    )
+    p_from_hf.add_argument(
+        "--caption-column",
+        type=str,
+        default=None,
+        help="Column to use as caption (default: auto-detect caption/text/sentence)",
+    )
+    p_from_hf.add_argument(
+        "--audio-column",
+        type=str,
+        default=None,
+        help="Column containing audio (default: auto-detect)",
+    )
+    p_from_hf.add_argument(
+        "--max-samples",
+        type=int,
+        default=None,
+        help="Max number of samples to export (default: all)",
+    )
+    p_from_hf.add_argument(
+        "--audio-subdir",
+        type=str,
+        default="audio",
+        help="Subdirectory name for audio under output-dir (default: audio)",
+    )
+    p_from_hf.add_argument(
+        "--json-filename",
+        type=str,
+        default="dataset.json",
+        help="Output JSON filename (default: dataset.json)",
+    )
+    p_from_hf.add_argument(
+        "--trust-remote-code",
+        action="store_true",
+        default=False,
+        help="Allow loading datasets with custom code",
+    )
+
     return root
 
 
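A quick smoke test of the new subparser wiring (a sketch: it assumes build_root_parser is importable as the hunk header suggests, and that the subparsers use dest="subcommand", which train.py's dispatch below implies):

    # Sketch: parse the new "from-hf" subcommand and check its defaults.
    from acestep.training_v2.cli.args import build_root_parser

    args = build_root_parser().parse_args([
        "from-hf",
        "--dataset", "polyai/minds14",
        "--output-dir", "./data/minds14",
    ])
    assert args.subcommand == "from-hf"          # dest inferred from train.py's dispatch
    assert args.split == "train"                 # parser default
    assert args.json_filename == "dataset.json"  # parser default
    assert args.trust_remote_code is False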
acestep/training_v2/prepare_from_hf.py ADDED
@@ -0,0 +1,194 @@
+"""
+Prepare an ACE-Step–compatible dataset from a Hugging Face dataset.
+
+Loads a HF dataset (with an audio column and optional text/caption column),
+writes audio files to a local directory and a dataset JSON in the format
+expected by ``preprocess_audio_files``.
+
+Usage (standalone, only needs ``pip install datasets``):
+    python prepare_from_hf_cli.py --dataset <HF_ID> --output-dir <DIR> [--max-samples N]
+
+Or via train.py (full env):
+    python train.py from-hf --dataset <HF_ID> --output-dir <DIR>
+
+Then preprocess and train:
+    python train.py preprocess --dataset-json <out>/dataset.json --tensor-output <pt_dir> ...
+    python train.py fixed --dataset-dir <pt_dir> ...
+
+Datasets with an "audio" column (HF Audio feature) are supported; each row
+must provide either a path or decoded bytes. Caption is taken from a
+configurable column (default: caption/text/sentence). Note: google/MusicCaps
+on HF has no audio column (YouTube refs only); use a dataset that includes
+audio (e.g. polyai/minds14, ashrafemam/crema-d) or add audio separately.
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+import shutil
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+logger = logging.getLogger(__name__)
+
+# Default column names to use as caption (first present wins)
+DEFAULT_CAPTION_COLUMNS = ("caption", "text", "sentence", "description", "transcript")
+
+
+def _infer_audio_column(column_names: List[str], first_row: Dict[str, Any]) -> Optional[str]:
+    for c in column_names:
+        if c in first_row and first_row[c] is not None:
+            val = first_row[c]
+            if isinstance(val, dict) and ("path" in val or "bytes" in val):
+                return c
+            if isinstance(val, str) and Path(val).suffix.lower() in {".wav", ".mp3", ".flac", ".ogg", ".m4a", ".opus"}:
+                return c
+    return None
+
+
+def _infer_caption_column(column_names: List[str], first_row: Dict[str, Any]) -> Optional[str]:
+    for name in DEFAULT_CAPTION_COLUMNS:
+        if name in column_names and first_row.get(name) and isinstance(first_row[name], str):
+            return name
+    return None
+
+
+def _audio_path_from_row(audio_val: Any, audio_dir: Path, index: int, suffix: str = ".wav") -> Optional[Path]:
+    if audio_val is None:
+        return None
+    if isinstance(audio_val, dict):
+        path = audio_val.get("path")
+        if path and Path(path).is_file():
+            dest = audio_dir / f"sample_{index:06d}{suffix}"
+            try:
+                shutil.copy2(path, dest)
+                return dest
+            except OSError as e:
+                logger.warning("Copy failed for row %d: %s", index, e)
+                return None
+        raw_bytes = audio_val.get("bytes")
+        if raw_bytes is not None:
+            dest = audio_dir / f"sample_{index:06d}{suffix}"
+            try:
+                dest.write_bytes(raw_bytes)
+                return dest
+            except OSError as e:
+                logger.warning("Write failed for row %d: %s", index, e)
+                return None
+        return None
+    if isinstance(audio_val, str) and Path(audio_val).is_file():
+        ext = Path(audio_val).suffix.lower() or suffix
+        dest = audio_dir / f"sample_{index:06d}{ext}"
+        try:
+            shutil.copy2(audio_val, dest)
+            return dest
+        except OSError as e:
+            logger.warning("Copy failed for row %d: %s", index, e)
+            return None
+    return None
+
+
+def prepare_from_hf(
+    dataset_name: str,
+    output_dir: str,
+    *,
+    split: str = "train",
+    config: Optional[str] = None,
+    caption_column: Optional[str] = None,
+    audio_column: Optional[str] = None,
+    max_samples: Optional[int] = None,
+    audio_subdir: str = "audio",
+    json_filename: str = "dataset.json",
+    trust_remote_code: bool = False,
+) -> Dict[str, Any]:
+    """Load a Hugging Face dataset and write ACE-Step dataset JSON + audio files.
+
+    Args:
+        dataset_name: Hugging Face dataset id (e.g. "google/MusicCaps" or "ashrafemam/crema-d").
+        output_dir: Directory to write dataset.json and audio files (into output_dir/<audio_subdir>).
+        split: Dataset split to use (default: "train").
+        config: Dataset config name if required.
+        caption_column: Column to use as caption; if None, inferred (caption/text/sentence/...).
+        audio_column: Column containing audio (path or Audio dict); if None, inferred.
+        max_samples: Limit number of samples (default: no limit).
+        audio_subdir: Subdirectory under output_dir for audio files (default: "audio").
+        json_filename: Name of the dataset JSON file (default: "dataset.json").
+        trust_remote_code: Passed to load_dataset.
+
+    Returns:
+        Dict with keys: output_dir, dataset_json, audio_dir, num_samples, caption_column, audio_column.
+    """
+    try:
+        from datasets import load_dataset
+    except ImportError:
+        raise ImportError("Install the 'datasets' package: pip install datasets") from None
+
+    out_path = Path(output_dir)
+    out_path.mkdir(parents=True, exist_ok=True)
+    audio_dir = out_path / audio_subdir
+    audio_dir.mkdir(parents=True, exist_ok=True)
+
+    load_kw: Dict[str, Any] = {"path": dataset_name, "split": split, "trust_remote_code": trust_remote_code}
+    if config:
+        load_kw["name"] = config
+    ds = load_dataset(**load_kw)
+    if hasattr(ds, "column_names"):
+        column_names = ds.column_names
+        first_row = ds[0] if len(ds) > 0 else {}
+    else:
+        column_names = list(ds[split].column_names)
+        first_row = ds[split][0] if len(ds[split]) > 0 else {}
+
+    audio_col = audio_column or _infer_audio_column(column_names, first_row)
+    if not audio_col:
+        raise ValueError(
+            "No audio column found. Ensure the dataset has an 'audio' column (Audio feature) "
+            "or pass --audio-column. For text-only datasets (e.g. MusicCaps with YouTube refs), "
+            "download audio separately and build the JSON manually."
+        )
+
+    caption_col = caption_column or _infer_caption_column(column_names, first_row)
+    data_split = ds[split] if hasattr(ds, "__getitem__") and split in ds else ds
+    total = len(data_split)
+    if max_samples is not None and max_samples > 0:
+        total = min(total, max_samples)
+
+    samples: List[Dict[str, Any]] = []
+    for i in range(total):
+        row = data_split[i]
+        audio_val = row.get(audio_col)
+        rel_audio_path = _audio_path_from_row(audio_val, audio_dir, i)
+        if rel_audio_path is None:
+            logger.debug("Skipping row %d: no resolvable audio", i)
+            continue
+        caption = (caption_col and row.get(caption_col)) or "[Instrumental]"
+        if not isinstance(caption, str):
+            caption = str(caption) if caption is not None else "[Instrumental]"
+        samples.append({
+            "filename": rel_audio_path.name,
+            "audio_path": str(rel_audio_path),
+            "caption": caption[:512],
+            "lyrics": "[Instrumental]",
+            "genre": "",
+            "bpm": None,
+            "keyscale": "",
+            "timesignature": "",
+            "duration": 0,
+            "is_instrumental": True,
+        })
+
+    dataset_json_path = out_path / json_filename
+    for s in samples:
+        s["audio_path"] = str(Path(audio_subdir) / Path(s["audio_path"]).name)
+    with open(dataset_json_path, "w", encoding="utf-8") as f:
+        json.dump({"samples": samples, "metadata": {"tag_position": "prepend", "genre_ratio": 0, "custom_tag": ""}}, f, indent=2)
+
+    return {
+        "output_dir": str(out_path),
+        "dataset_json": str(dataset_json_path),
+        "audio_dir": str(audio_dir),
+        "num_samples": len(samples),
+        "caption_column": caption_col,
+        "audio_column": audio_col,
+    }
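For programmatic use, prepare_from_hf can also be called directly; a minimal sketch (the "en-US" config name for polyai/minds14 is an assumption taken from that dataset's card, not from this commit):

    # Sketch: call prepare_from_hf() without going through a CLI.
    from acestep.training_v2.prepare_from_hf import prepare_from_hf

    result = prepare_from_hf(
        "polyai/minds14",
        "./data/minds14",
        split="train",
        config="en-US",      # assumed config name; minds14 requires one
        max_samples=100,     # export only the first 100 rows
    )
    print(result["num_samples"], "samples ->", result["dataset_json"])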
prepare_from_hf_cli.py ADDED
@@ -0,0 +1,70 @@
+#!/usr/bin/env python3
+"""
+Standalone CLI to prepare an ACE-Step dataset from a Hugging Face dataset.
+
+Only requires: pip install datasets
+
+Usage:
+    python prepare_from_hf_cli.py --dataset <HF_DATASET_ID> --output-dir <DIR> [options]
+
+Example:
+    python prepare_from_hf_cli.py --dataset polyai/minds14 --output-dir ./data/minds14 --split train
+
+Then preprocess and train:
+    python train.py preprocess --dataset-json ./data/minds14/dataset.json --tensor-output ./pt_minds14 ...
+    python train.py fixed --dataset-dir ./pt_minds14 ...
+"""
+
+from __future__ import annotations
+
+import argparse
+import sys
+
+
+def main() -> int:
+    parser = argparse.ArgumentParser(
+        description="Prepare ACE-Step dataset from a Hugging Face dataset (dataset.json + audio/)",
+    )
+    parser.add_argument("--dataset", required=True, metavar="NAME", help="Hugging Face dataset id")
+    parser.add_argument("--output-dir", required=True, metavar="DIR", help="Output directory for dataset.json and audio/")
+    parser.add_argument("--split", default="train", help="Dataset split (default: train)")
+    parser.add_argument("--config", default=None, help="Dataset config name if required")
+    parser.add_argument("--caption-column", default=None, help="Caption column (default: auto-detect)")
+    parser.add_argument("--audio-column", default=None, help="Audio column (default: auto-detect)")
+    parser.add_argument("--max-samples", type=int, default=None, help="Max samples to export (default: all)")
+    parser.add_argument("--audio-subdir", default="audio", help="Audio subdir under output-dir (default: audio)")
+    parser.add_argument("--json-filename", default="dataset.json", help="Output JSON filename (default: dataset.json)")
+    parser.add_argument("--trust-remote-code", action="store_true", help="Allow datasets with custom code")
+    args = parser.parse_args()
+
+    from acestep.training_v2.prepare_from_hf import prepare_from_hf
+
+    try:
+        result = prepare_from_hf(
+            dataset_name=args.dataset,
+            output_dir=args.output_dir,
+            split=args.split,
+            config=args.config,
+            caption_column=args.caption_column,
+            audio_column=args.audio_column,
+            max_samples=args.max_samples,
+            audio_subdir=args.audio_subdir,
+            json_filename=args.json_filename,
+            trust_remote_code=args.trust_remote_code,
+        )
+    except ImportError as e:
+        print(f"[FAIL] {e}", file=sys.stderr)
+        return 1
+    except Exception as e:
+        print(f"[FAIL] {e}", file=sys.stderr)
+        return 1
+
+    print(f"\n[OK] Prepared {result['num_samples']} samples")
+    print(f" dataset_json: {result['dataset_json']}")
+    print(f" audio_dir: {result['audio_dir']}")
+    print("\nNext: preprocess then train (see train.py preprocess / train.py fixed).")
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())
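Since the wrapper returns conventional exit codes and prints failures to stderr, it can be driven from another script; a sketch with placeholder dataset id, config, and paths:

    # Sketch: invoke the standalone CLI and surface its stderr on failure.
    import subprocess
    import sys

    proc = subprocess.run(
        [sys.executable, "prepare_from_hf_cli.py",
         "--dataset", "polyai/minds14", "--config", "en-US",
         "--output-dir", "./data/minds14", "--max-samples", "50"],
        capture_output=True, text=True,
    )
    if proc.returncode != 0:   # 1 on ImportError or any other failure
        print(proc.stderr, file=sys.stderr)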
prepare_song_describer.py ADDED
@@ -0,0 +1,248 @@
+#!/usr/bin/env python3
+"""
+Prepare ACE-Step dataset from audio.zip + song_describer.csv, then optionally preprocess and train.
+
+Downloads (or uses local) audio.zip and song_describer.csv, unzips audio, builds dataset.json
+in the format expected by train.py preprocess, then runs preprocess and train.
+
+Dataset source (Song Describer Dataset, SDD):
+    Zenodo: https://zenodo.org/records/10072001
+    - audio.zip (~3.3 GB, 706 recordings)
+    - song_describer.csv (~186 KB, ~1.1k captions)
+    Direct file URLs (use if the record page lists these names):
+    - https://zenodo.org/records/10072001/files/audio.zip
+    - https://zenodo.org/records/10072001/files/song_describer.csv
+
+Usage:
+    python prepare_song_describer.py --audio-zip <URL_or_path> --csv <URL_or_path> --output-dir <DIR> [options]
+
+Example (download from Zenodo then preprocess + train):
+    python prepare_song_describer.py --audio-zip "https://zenodo.org/records/10072001/files/audio.zip" --csv "https://zenodo.org/records/10072001/files/song_describer.csv" --output-dir ./data/song_describer --checkpoint-dir ./checkpoints --run-preprocess --run-train
+
+Example (local files):
+    python prepare_song_describer.py --audio-zip ./audio.zip --csv ./song_describer.csv --output-dir ./data/song_describer --checkpoint-dir ./checkpoints --run-preprocess --run-train
+"""
+
+from __future__ import annotations
+
+import argparse
+import csv
+import json
+import shutil
+import subprocess
+import sys
+import urllib.request
+from pathlib import Path
+
+
+AUDIO_SUBDIR = "audio"
+DATASET_JSON_NAME = "dataset.json"
+DEFAULT_CAPTION_COLUMNS = ("caption", "description", "text", "title", "label")
+DEFAULT_AUDIO_COLUMNS = ("filename", "path", "file", "audio", "id", "name")
+
+
+def _is_url(s: str) -> bool:
+    return s.strip().startswith(("http://", "https://"))
+
+
+def _download(url: str, dest: Path) -> None:
+    dest.parent.mkdir(parents=True, exist_ok=True)
+    req = urllib.request.Request(url, headers={"User-Agent": "ACE-Step/1.0"})
+    with urllib.request.urlopen(req) as resp:
+        dest.write_bytes(resp.read())
+    print(f"[INFO] Downloaded {url} -> {dest}", file=sys.stderr)
+
+
+def _ensure_file(src: str, dest: Path) -> Path:
+    if _is_url(src):
+        _download(src, dest)
+        return dest
+    p = Path(src)
+    if not p.is_file():
+        raise FileNotFoundError(f"Not a file: {p}")
+    if p.resolve() != dest.resolve():
+        shutil.copy2(p, dest)
+    return dest
+
+
+def _infer_csv_columns(reader: csv.DictReader) -> tuple[str, str]:
+    names = [c for c in reader.fieldnames or [] if c]
+    if not names:
+        raise ValueError("CSV has no header columns")
+    audio_col = None
+    for c in DEFAULT_AUDIO_COLUMNS:
+        if c in names:
+            audio_col = c
+            break
+    if not audio_col:
+        audio_col = names[0]
+    caption_col = None
+    for c in DEFAULT_CAPTION_COLUMNS:
+        if c in names:
+            caption_col = c
+            break
+    if not caption_col:
+        caption_col = names[1] if len(names) > 1 else names[0]
+    return audio_col, caption_col
+
+
+def build_dataset_json(
+    csv_path: Path,
+    audio_dir: Path,
+    output_json_path: Path,
+    csv_audio_col: str | None = None,
+    csv_caption_col: str | None = None,
+) -> int:
+    audio_dir.mkdir(parents=True, exist_ok=True)
+    existing = {f.name for f in audio_dir.iterdir() if f.is_file()}
+
+    with open(csv_path, newline="", encoding="utf-8", errors="replace") as f:
+        reader = csv.DictReader(f)
+        audio_col, caption_col = _infer_csv_columns(reader)
+        if csv_audio_col:
+            audio_col = csv_audio_col
+        if csv_caption_col:
+            caption_col = csv_caption_col
+
+        samples: list[dict] = []
+        for row in reader:
+            raw_path = (row.get(audio_col) or "").strip()
+            if not raw_path:
+                continue
+            name = Path(raw_path).name
+            if name not in existing:
+                continue
+            caption = (row.get(caption_col) or "").strip() or "[Instrumental]"
+            if len(caption) > 512:
+                caption = caption[:512]
+            rel_audio = f"{AUDIO_SUBDIR}/{name}"
+            samples.append({
+                "filename": name,
+                "audio_path": rel_audio,
+                "caption": caption,
+                "lyrics": "[Instrumental]",
+                "genre": "",
+                "bpm": None,
+                "keyscale": "",
+                "timesignature": "",
+                "duration": 0,
+                "is_instrumental": True,
+            })
+
+    payload = {
+        "samples": samples,
+        "metadata": {"tag_position": "prepend", "genre_ratio": 0, "custom_tag": ""},
+    }
+    output_json_path.parent.mkdir(parents=True, exist_ok=True)
+    with open(output_json_path, "w", encoding="utf-8") as out:
+        json.dump(payload, out, indent=2)
+    print(f"[INFO] Wrote {len(samples)} samples to {output_json_path}", file=sys.stderr)
+    return len(samples)
+
+
+def main() -> int:
+    parser = argparse.ArgumentParser(
+        description="Prepare dataset from audio.zip + song_describer.csv, then preprocess/train."
+    )
+    parser.add_argument("--audio-zip", required=True, help="URL or path to audio.zip")
+    parser.add_argument("--csv", required=True, help="URL or path to song_describer.csv")
+    parser.add_argument("--output-dir", required=True, help="Output directory (dataset.json + audio/)")
+    parser.add_argument("--csv-audio-col", default=None, help="CSV column for audio filename (default: auto)")
+    parser.add_argument("--csv-caption-col", default=None, help="CSV column for caption (default: auto)")
+    parser.add_argument("--checkpoint-dir", default=None, help="Checkpoint dir for preprocess/train")
+    parser.add_argument("--run-preprocess", action="store_true", help="Run train.py preprocess after preparing")
+    parser.add_argument("--run-train", action="store_true", help="Run train.py fixed after preprocessing")
+    parser.add_argument("--tensor-output", default=None, help="Dir for .pt tensors (default: <output-dir>/tensors)")
+    parser.add_argument("--lora-output", default=None, help="Dir for LoRA output (default: <output-dir>/lora_output)")
+    parser.add_argument("--model-variant", default="turbo", help="Model variant (default: turbo)")
+    args = parser.parse_args()
+
+    out_dir = Path(args.output_dir).resolve()
+    work = out_dir / "work"
+    work.mkdir(parents=True, exist_ok=True)
+
+    zip_path = work / "audio.zip"
+    csv_path = work / "song_describer.csv"
+    _ensure_file(args.audio_zip, zip_path)
+    _ensure_file(args.csv, csv_path)
+
+    audio_dir = out_dir / AUDIO_SUBDIR
+    if zip_path.is_file():
+        tmp_extract = work / "audio_extract"
+        tmp_extract.mkdir(parents=True, exist_ok=True)
+        shutil.unpack_archive(str(zip_path), str(tmp_extract))
+        audio_dir.mkdir(parents=True, exist_ok=True)
+        for f in tmp_extract.rglob("*"):
+            if f.is_file():
+                dest = audio_dir / f.name
+                if dest != f.resolve():
+                    shutil.copy2(f, dest)
+        shutil.rmtree(tmp_extract, ignore_errors=True)
+        print(f"[INFO] Unpacked {zip_path} -> {audio_dir} (flattened)", file=sys.stderr)
+    else:
+        audio_dir.mkdir(parents=True, exist_ok=True)
+
+    dataset_json = out_dir / DATASET_JSON_NAME
+    n = build_dataset_json(
+        csv_path,
+        audio_dir,
+        dataset_json,
+        csv_audio_col=args.csv_audio_col,
+        csv_caption_col=args.csv_caption_col,
+    )
+    if n == 0:
+        print("[FAIL] No samples in dataset (CSV rows must match filenames in zip).", file=sys.stderr)
+        return 1
+
+    tensor_output = args.tensor_output or str(out_dir / "tensors")
+    lora_output = args.lora_output or str(out_dir / "lora_output")
+
+    if args.run_preprocess or args.run_train:
+        if not args.checkpoint_dir:
+            print("[FAIL] --checkpoint-dir required for --run-preprocess / --run-train.", file=sys.stderr)
+            return 1
+        train_py = Path(__file__).resolve().parent / "train.py"
+        if not train_py.is_file():
+            print(f"[FAIL] train.py not found: {train_py}", file=sys.stderr)
+            return 1
+
+    if args.run_preprocess:
+        cmd = [
+            sys.executable,
+            str(train_py),
+            "preprocess",
+            "--dataset-json", str(dataset_json),
+            "--tensor-output", tensor_output,
+            "--checkpoint-dir", args.checkpoint_dir,
+            "--model-variant", args.model_variant,
+        ]
+        print(f"[INFO] Running: {' '.join(cmd)}", file=sys.stderr)
+        rc = subprocess.call(cmd, cwd=str(train_py.parent))
+        if rc != 0:
+            return rc
+
+    if args.run_train:
+        cmd = [
+            sys.executable,
+            str(train_py),
+            "fixed",
+            "--dataset-dir", tensor_output,
+            "--checkpoint-dir", args.checkpoint_dir,
+            "--model-variant", args.model_variant,
+            "--output-dir", lora_output,
+        ]
+        print(f"[INFO] Running: {' '.join(cmd)}", file=sys.stderr)
+        rc = subprocess.call(cmd, cwd=str(train_py.parent))
+        if rc != 0:
+            return rc
+
+    print(f"\n[OK] Dataset: {dataset_json} ({n} samples)")
+    if not args.run_preprocess:
+        print(f" Preprocess: python train.py preprocess --dataset-json {dataset_json} --tensor-output {tensor_output} --checkpoint-dir <ckpt> --model-variant {args.model_variant}")
+    if args.run_preprocess and not args.run_train:
+        print(f" Train: python train.py fixed --dataset-dir {tensor_output} --checkpoint-dir <ckpt> --model-variant {args.model_variant} --output-dir {lora_output}")
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())
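Both preparers emit the same JSON shape: a "samples" list plus a small "metadata" block, with each audio_path relative to the output directory. A post-run sanity check might look like this (the output path is a placeholder):

    # Sketch: confirm dataset.json entries point at real files on disk.
    import json
    from pathlib import Path

    root = Path("./data/song_describer")   # hypothetical --output-dir
    payload = json.loads((root / "dataset.json").read_text(encoding="utf-8"))
    missing = [s["audio_path"] for s in payload["samples"]
               if not (root / s["audio_path"]).is_file()]
    print(f"{len(payload['samples'])} samples, {len(missing)} missing audio files")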
train.py CHANGED
@@ -61,7 +61,7 @@ def _has_subcommand() -> bool:
     args = sys.argv[1:]
     if "--help" in args or "-h" in args:
         return True  # let argparse handle help
-    known = {"vanilla", "fixed", "estimate"}
+    known = {"vanilla", "fixed", "estimate", "from-hf"}
     return bool(known & set(args))
 
 
@@ -89,7 +89,10 @@ def _dispatch(args) -> int:
 
     sub = args.subcommand
 
-    # All subcommands need path validation
+    if sub == "from-hf":
+        return _run_from_hf(args)
+
+    # All other subcommands need path validation
     if not validate_paths(args):
         return 1
 
@@ -202,6 +205,52 @@ def _run_preprocess(args) -> int:
     return 0
 
 
+def _run_from_hf(args) -> int:
+    """Prepare dataset from a Hugging Face dataset."""
+    from acestep.training_v2.prepare_from_hf import prepare_from_hf
+
+    out_dir = getattr(args, "output_dir", None)
+    if not out_dir:
+        print("[FAIL] from-hf requires --output-dir.", file=sys.stderr)
+        return 1
+
+    print("\n" + "=" * 60)
+    print(" Prepare from Hugging Face dataset")
+    print("=" * 60)
+    print(f" Dataset: {args.dataset}")
+    print(f" Split: {getattr(args, 'split', 'train')}")
+    print(f" Output: {out_dir}")
+    print("=" * 60)
+
+    try:
+        result = prepare_from_hf(
+            dataset_name=args.dataset,
+            output_dir=out_dir,
+            split=getattr(args, "split", "train"),
+            config=getattr(args, "config", None),
+            caption_column=getattr(args, "caption_column", None),
+            audio_column=getattr(args, "audio_column", None),
+            max_samples=getattr(args, "max_samples", None),
+            audio_subdir=getattr(args, "audio_subdir", "audio"),
+            json_filename=getattr(args, "json_filename", "dataset.json"),
+            trust_remote_code=getattr(args, "trust_remote_code", False),
+        )
+    except Exception as exc:
+        print(f"[FAIL] {exc}", file=sys.stderr)
+        logger.exception("from-hf error")
+        return 1
+
+    print(f"\n[OK] Prepared {result['num_samples']} samples:")
+    print(f" dataset_json: {result['dataset_json']}")
+    print(f" audio_dir: {result['audio_dir']}")
+    print("\nNext steps:")
+    print(f" 1. Preprocess: python train.py preprocess --dataset-json {result['dataset_json']} \\")
+    print(f" --tensor-output <pt_dir> --checkpoint-dir <ckpt> --model-variant turbo")
+    print(f" 2. Train: python train.py fixed --dataset-dir <pt_dir> --checkpoint-dir <ckpt> \\")
+    print(f" --model-variant turbo --output-dir <lora_output>")
+    return 0
+
+
 def _run_estimate(args) -> int:
     """Run gradient sensitivity estimation."""
     import json as _json
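Taken together, the commit wires a three-step flow: prepare (from-hf), preprocess, then train. A scripted sketch of that flow using the flags shown in the hunks above; every path is a placeholder, and the minds14 config name is an assumption from that dataset's card:

    # Sketch: drive the full pipeline through train.py subcommands.
    import subprocess
    import sys

    def run(*cmd: str) -> None:
        subprocess.check_call([sys.executable, "train.py", *cmd])

    run("from-hf", "--dataset", "polyai/minds14", "--config", "en-US",
        "--output-dir", "./data/minds14")
    run("preprocess", "--dataset-json", "./data/minds14/dataset.json",
        "--tensor-output", "./pt_minds14",
        "--checkpoint-dir", "./checkpoints", "--model-variant", "turbo")
    run("fixed", "--dataset-dir", "./pt_minds14",
        "--checkpoint-dir", "./checkpoints", "--model-variant", "turbo",
        "--output-dir", "./lora_output")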