# File size: 3,393 Bytes
# 16d51ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#!/usr/bin/env python3
import argparse
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

# Known edit types mapped to the number of parquet files available for each.
# The keys are the accepted values of --edit-type and the filename prefix of
# every parquet file ("<edit type>_<5-digit index>.parquet"); the counts are
# the upper bound for -n and the number of files fetched when -n is omitted.
EDIT_TYPES = {
    "color": 1984,
    "motion change": 128,
    "style": 1600,
    "replace": 1566,
    "remove": 1388,
    "add": 1213,
    "background change": 1091,
}

# URL prefix; each parquet filename is appended to it after a "/".
BASE_URL = "https://huggingface.co/datasets/WeiChow/CrispEdit-2M/resolve/main/data"
MAX_WORKERS = 8  # thread-pool size, i.e. number of parallel downloads


def parse_args():
    """Parse the downloader's command-line options.

    Returns:
        The ``argparse.Namespace`` with three attributes: ``n`` (files per
        edit type, ``None`` when omitted), ``edit_type`` (list of chosen edit
        types, ``None`` when omitted) and ``output_dir`` (``"."`` by default).
    """
    cli = argparse.ArgumentParser(
        description="Download CrispEdit-2M parquet files with wget in parallel."
    )

    # One (flags, keyword options) entry per argument, in help-output order.
    option_specs = (
        (
            ("-n",),
            {
                "type": int,
                "default": None,
                "help": "Number of parquet files per edit type (default: all available)",
            },
        ),
        (
            ("-e", "--edit-type"),
            {
                "choices": EDIT_TYPES.keys(),
                "nargs": "+",
                "help": "Edit types to download (default: all types)",
            },
        ),
        (
            ("-o", "--output-dir"),
            {
                "default": ".",
                "help": "Directory to save downloaded files (default: current directory)",
            },
        ),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)

    return cli.parse_args()


def validate_args(args):
    """Validate the parsed arguments and resolve which edit types to download.

    Args:
        args: Namespace produced by ``parse_args``. Reads ``args.n`` (int or
            ``None``) and ``args.edit_type`` (list of edit-type names, or
            ``None``/empty meaning "all types").

    Returns:
        The selected edit types as a list, in the order they were first given
        and with repeats removed; every key of ``EDIT_TYPES`` when none were
        given.

    Exits:
        Prints a message to stderr and exits with status 1 when ``-n`` is not
        positive or is larger than the file count of any selected edit type.
    """
    if args.n is not None and args.n <= 0:
        print("Error: -n must be a positive integer.", file=sys.stderr)
        sys.exit(1)

    # Drop repeated edit types while keeping their order. Without this,
    # `-e color color` yields the same (url, dest) job twice, so two wget
    # processes would write to the same destination file at the same time.
    selected_types = list(dict.fromkeys(args.edit_type or EDIT_TYPES))

    # Ensure n is not larger than the available parquet count for any
    # selected type; the first offending type aborts the run.
    if args.n is not None:
        for et in selected_types:
            max_files = EDIT_TYPES[et]
            if args.n > max_files:
                print(
                    f"Error: requested -n={args.n} but edit type '{et}' "
                    f"only has {max_files} parquet files.",
                    file=sys.stderr,
                )
                sys.exit(1)

    return selected_types


def download_file(url: str, dest: Path):
    """Fetch ``url`` into ``dest`` by running ``wget`` as a child process.

    The command line is echoed to stdout before it runs. ``wget`` is called
    with ``-O dest``, a forced progress bar and ``-c``; an already existing
    destination is not skipped.

    Raises:
        RuntimeError: if ``wget`` finishes with a non-zero exit status.
    """
    wget_argv = ["wget", "-O", str(dest), "--progress=bar:force", "-c", url]
    print(" ".join(wget_argv))

    exit_code = subprocess.run(wget_argv).returncode
    if exit_code == 0:
        return
    raise RuntimeError(f"wget failed for {url} (exit code {exit_code})")


def main():
    """Download the selected parquet files in parallel.

    Parses and validates the command line, creates the output directory if
    needed, builds one (url, destination) job per edit type and file index,
    and runs the jobs on a pool of ``MAX_WORKERS`` threads.

    A failed download is reported on stderr and does not stop the remaining
    jobs. If at least one job failed, a summary is printed and the process
    exits with status 1 so that calling scripts can detect the incomplete
    download.
    """
    args = parse_args()
    edit_types = validate_args(args)

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Build download jobs: one (url, dest) pair per (edit_type, index).
    jobs = []
    for et in edit_types:
        # Without -n, fetch every file known for this edit type.
        num_files = args.n if args.n is not None else EDIT_TYPES[et]
        for idx in range(num_files):
            filename = f"{et}_{idx:05d}.parquet"
            jobs.append((f"{BASE_URL}/{filename}", out_dir / filename))

    print(f"Downloading {len(jobs)} files for edit types: {', '.join(edit_types)}")

    failed = 0
    # Parallel downloads
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        future_to_job = {
            executor.submit(download_file, url, dest): (url, dest)
            for url, dest in jobs
        }

        for future in as_completed(future_to_job):
            url, dest = future_to_job[future]
            try:
                future.result()
            except Exception as e:
                # Keep going: one bad file should not abort the whole batch,
                # but remember it so the exit status reflects the failure.
                failed += 1
                print(f"[error] {url} -> {dest}: {e}", file=sys.stderr)

    if failed:
        print(f"{failed} of {len(jobs)} downloads failed.", file=sys.stderr)
        sys.exit(1)


# Run the downloader only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()