codemichaeld commited on
Commit
7153add
Β·
verified Β·
1 Parent(s): 0d6c60b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -21
app.py CHANGED
@@ -62,6 +62,37 @@ def convert_safetensors_to_fp8(safetensors_path, output_dir, fp8_format, progres
62
  except Exception as e:
63
  return False, str(e)
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  # --- Source download helper ---
66
  def download_safetensors_file(
67
  source_type,
@@ -74,15 +105,14 @@ def download_safetensors_file(
74
  temp_dir = tempfile.mkdtemp()
75
  try:
76
  if source_type == "huggingface":
77
- clean_url = repo_url.strip().rstrip("/")
78
- if "huggingface.co" not in clean_url:
79
- raise ValueError("Invalid Hugging Face URL")
80
- src_repo_id = clean_url.replace("https://huggingface.co/", "")
81
  safetensors_path = hf_hub_download(
82
- repo_id=src_repo_id,
83
  filename=filename,
 
84
  cache_dir=temp_dir,
85
- token=hf_token
 
86
  )
87
  elif source_type == "modelscope":
88
  if not MODELScope_AVAILABLE:
@@ -148,7 +178,6 @@ def upload_to_target(
148
  api = ModelScopeApi()
149
  if modelscope_token:
150
  api.login(modelscope_token)
151
- # ModelScope requires model_type and license
152
  api.push_model(
153
  model_id=new_repo_id,
154
  model_dir=output_dir,
@@ -190,8 +219,7 @@ def process_and_upload_fp8(
190
  output_dir = tempfile.mkdtemp()
191
 
192
  try:
193
- # Authenticate & download
194
- progress(0.05, desc="Authenticating and downloading...")
195
  safetensors_path, temp_dir = download_safetensors_file(
196
  source_type=source_type,
197
  repo_url=repo_url,
@@ -202,12 +230,10 @@ def process_and_upload_fp8(
202
  )
203
  progress(0.25, desc="Download complete.")
204
 
205
- # Convert
206
  success, msg = convert_safetensors_to_fp8(safetensors_path, output_dir, fp8_format, progress)
207
  if not success:
208
  return None, f"❌ Conversion failed: {msg}", ""
209
 
210
- # Upload
211
  progress(0.92, desc="Uploading model...")
212
  repo_url_final = upload_to_target(
213
  target_type=target_type,
@@ -220,7 +246,6 @@ def process_and_upload_fp8(
220
  progress=progress
221
  )
222
 
223
- # README
224
  base_name = os.path.splitext(safetensors_filename)[0]
225
  fp8_filename = f"{base_name}-fp8-{fp8_format}.safetensors"
226
  readme = f"""---
@@ -241,14 +266,11 @@ File: `{safetensors_filename}` β†’ `{fp8_filename}`
241
 
242
  Quantization: **FP8 ({fp8_format.upper()})**
243
  Converted on: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
244
-
245
- > ⚠️ Requires PyTorch β‰₯ 2.1 and compatible hardware for FP8 acceleration.
246
  """
247
  readme_path = os.path.join(output_dir, "README.md")
248
  with open(readme_path, "w") as f:
249
  f.write(readme)
250
 
251
- # Re-upload README if needed (for ModelScope, already included; for HF, upload separately)
252
  if target_type == "huggingface":
253
  HfApi(token=hf_token).upload_file(
254
  path_or_fileobj=readme_path,
@@ -277,6 +299,7 @@ Source: {source_type.title()} β†’ Target: {target_type.title()}
277
  with gr.Blocks(title="Safetensors β†’ FP8 Pruner (HF + ModelScope)") as demo:
278
  gr.Markdown("# πŸ”„ Safetensors to FP8 Pruner")
279
  gr.Markdown("Convert `.safetensors` models to **FP8** and upload to **Hugging Face** or **ModelScope**.")
 
280
 
281
  with gr.Row():
282
  with gr.Column():
@@ -287,18 +310,17 @@ with gr.Blocks(title="Safetensors β†’ FP8 Pruner (HF + ModelScope)") as demo:
287
  )
288
  repo_url = gr.Textbox(
289
  label="Source Repository URL",
290
- placeholder="e.g., https://huggingface.co/Yabo/FramePainter OR your-modelscope-id",
291
- info="Hugging Face URL or ModelScope model ID"
292
  )
293
  safetensors_filename = gr.Textbox(
294
  label="Safetensors Filename",
295
- placeholder="unet_diffusion_pytorch_model.safetensors"
296
  )
297
  fp8_format = gr.Radio(
298
  choices=["e4m3fn", "e5m2"],
299
  value="e5m2",
300
- label="FP8 Format",
301
- info="E5M2: wider range; E4M3FN: better near-zero precision"
302
  )
303
  hf_token = gr.Textbox(
304
  label="Hugging Face Token (if using HF)",
@@ -346,7 +368,7 @@ with gr.Blocks(title="Safetensors β†’ FP8 Pruner (HF + ModelScope)") as demo:
346
 
347
  gr.Examples(
348
  examples=[
349
- ["huggingface", "https://huggingface.co/Yabo/FramePainter", "unet_diffusion_pytorch_model.safetensors", "e5m2", "huggingface"]
350
  ],
351
  inputs=[source_type, repo_url, safetensors_filename, fp8_format, target_type]
352
  )
 
62
  except Exception as e:
63
  return False, str(e)
64
 
65
+ # --- Parse HF URL with optional subfolder ---
66
+ def parse_hf_url(url):
67
+ """
68
+ Parses a Hugging Face URL like:
69
+ - https://huggingface.co/username/repo
70
+ - https://huggingface.co/username/repo/tree/main/subfolder
71
+ Returns (repo_id, subfolder)
72
+ """
73
+ url = url.strip().rstrip("/")
74
+ if not url.startswith("https://huggingface.co/"):
75
+ raise ValueError("URL must start with https://huggingface.co/")
76
+ path = url.replace("https://huggingface.co/", "")
77
+ parts = path.split("/")
78
+
79
+ if len(parts) < 2:
80
+ raise ValueError("Invalid repo format")
81
+
82
+ # repo_id is always first two parts
83
+ repo_id = "/".join(parts[:2])
84
+
85
+ # Check if "/tree/branch/" is present
86
+ subfolder = ""
87
+ if len(parts) > 3 and parts[2] == "tree":
88
+ # everything after branch is subfolder
89
+ subfolder = "/".join(parts[4:]) if len(parts) > 4 else ""
90
+ elif len(parts) > 2:
91
+ # old style: username/repo/subfolder (not standard, but support)
92
+ subfolder = "/".join(parts[2:])
93
+
94
+ return repo_id, subfolder
95
+
96
  # --- Source download helper ---
97
  def download_safetensors_file(
98
  source_type,
 
105
  temp_dir = tempfile.mkdtemp()
106
  try:
107
  if source_type == "huggingface":
108
+ repo_id, subfolder = parse_hf_url(repo_url)
 
 
 
109
  safetensors_path = hf_hub_download(
110
+ repo_id=repo_id,
111
  filename=filename,
112
+ subfolder=subfolder or None,
113
  cache_dir=temp_dir,
114
+ token=hf_token,
115
+ resume_download=True
116
  )
117
  elif source_type == "modelscope":
118
  if not MODELScope_AVAILABLE:
 
178
  api = ModelScopeApi()
179
  if modelscope_token:
180
  api.login(modelscope_token)
 
181
  api.push_model(
182
  model_id=new_repo_id,
183
  model_dir=output_dir,
 
219
  output_dir = tempfile.mkdtemp()
220
 
221
  try:
222
+ progress(0.05, desc="Parsing URL and downloading...")
 
223
  safetensors_path, temp_dir = download_safetensors_file(
224
  source_type=source_type,
225
  repo_url=repo_url,
 
230
  )
231
  progress(0.25, desc="Download complete.")
232
 
 
233
  success, msg = convert_safetensors_to_fp8(safetensors_path, output_dir, fp8_format, progress)
234
  if not success:
235
  return None, f"❌ Conversion failed: {msg}", ""
236
 
 
237
  progress(0.92, desc="Uploading model...")
238
  repo_url_final = upload_to_target(
239
  target_type=target_type,
 
246
  progress=progress
247
  )
248
 
 
249
  base_name = os.path.splitext(safetensors_filename)[0]
250
  fp8_filename = f"{base_name}-fp8-{fp8_format}.safetensors"
251
  readme = f"""---
 
266
 
267
  Quantization: **FP8 ({fp8_format.upper()})**
268
  Converted on: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
 
 
269
  """
270
  readme_path = os.path.join(output_dir, "README.md")
271
  with open(readme_path, "w") as f:
272
  f.write(readme)
273
 
 
274
  if target_type == "huggingface":
275
  HfApi(token=hf_token).upload_file(
276
  path_or_fileobj=readme_path,
 
299
  with gr.Blocks(title="Safetensors β†’ FP8 Pruner (HF + ModelScope)") as demo:
300
  gr.Markdown("# πŸ”„ Safetensors to FP8 Pruner")
301
  gr.Markdown("Convert `.safetensors` models to **FP8** and upload to **Hugging Face** or **ModelScope**.")
302
+ gr.Markdown("Supports subfolders: e.g., `https://huggingface.co/lixiaowen/diffuEraser/tree/main/brushnet`")
303
 
304
  with gr.Row():
305
  with gr.Column():
 
310
  )
311
  repo_url = gr.Textbox(
312
  label="Source Repository URL",
313
+ placeholder="https://huggingface.co/lixiaowen/diffuEraser/tree/main/brushnet",
314
+ info="Full URL including subfolder (if any)"
315
  )
316
  safetensors_filename = gr.Textbox(
317
  label="Safetensors Filename",
318
+ placeholder="diffusion_pytorch_model.safetensors"
319
  )
320
  fp8_format = gr.Radio(
321
  choices=["e4m3fn", "e5m2"],
322
  value="e5m2",
323
+ label="FP8 Format"
 
324
  )
325
  hf_token = gr.Textbox(
326
  label="Hugging Face Token (if using HF)",
 
368
 
369
  gr.Examples(
370
  examples=[
371
+ ["huggingface", "https://huggingface.co/lixiaowen/diffuEraser/tree/main/brushnet", "diffusion_pytorch_model.safetensors", "e5m2", "huggingface"]
372
  ],
373
  inputs=[source_type, repo_url, safetensors_filename, fp8_format, target_type]
374
  )