codemichaeld committed on
Commit
0d6c60b
·
verified ·
1 Parent(s): 7f615fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +210 -79
app.py CHANGED
@@ -10,16 +10,20 @@ from huggingface_hub import HfApi, hf_hub_download
10
  from safetensors.torch import load_file, save_file
11
  import torch
12
 
13
- # --- Conversion Function: Safetensors β†’ FP8 Safetensors (E4M3FN or E5M2) ---
 
 
 
 
 
 
 
 
 
14
  def convert_safetensors_to_fp8(safetensors_path, output_dir, fp8_format, progress=gr.Progress()):
15
- """
16
- Loads a .safetensors file and saves a pruned FP8 version.
17
- fp8_format: 'e4m3fn' or 'e5m2'
18
- """
19
  progress(0.1, desc="Starting FP8 conversion...")
20
 
21
  try:
22
- # Read metadata
23
  def read_safetensors_metadata(path):
24
  with open(path, 'rb') as f:
25
  header_size = int.from_bytes(f.read(8), 'little')
@@ -30,28 +34,23 @@ def convert_safetensors_to_fp8(safetensors_path, output_dir, fp8_format, progres
30
  metadata = read_safetensors_metadata(safetensors_path)
31
  progress(0.3, desc="Loaded model metadata.")
32
 
33
- # Load state dict
34
  state_dict = load_file(safetensors_path)
35
  progress(0.5, desc="Loaded model weights.")
36
 
37
- # Select FP8 dtype
38
  if fp8_format == "e5m2":
39
  fp8_dtype = torch.float8_e5m2
40
- else: # default to e4m3fn
41
  fp8_dtype = torch.float8_e4m3fn
42
 
43
- # Convert to FP8
44
  sd_pruned = {}
45
  total = len(state_dict)
46
  for i, key in enumerate(state_dict):
47
  progress(0.5 + 0.4 * (i / total), desc=f"Converting tensor {i+1}/{total} to FP8 ({fp8_format})...")
48
- # Only convert float tensors
49
  if state_dict[key].dtype in [torch.float16, torch.float32, torch.bfloat16]:
50
  sd_pruned[key] = state_dict[key].to(fp8_dtype)
51
  else:
52
- sd_pruned[key] = state_dict[key] # keep non-float as-is (e.g., int for embeddings)
53
 
54
- # Save FP8 safetensors
55
  base_name = os.path.splitext(os.path.basename(safetensors_path))[0]
56
  output_path = os.path.join(output_dir, f"{base_name}-fp8-{fp8_format}.safetensors")
57
  save_file(sd_pruned, output_path, metadata={"format": "pt", "fp8_format": fp8_format, **metadata})
@@ -63,56 +62,165 @@ def convert_safetensors_to_fp8(safetensors_path, output_dir, fp8_format, progres
63
  except Exception as e:
64
  return False, str(e)
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  # --- Main Processing Function ---
67
- def process_and_upload_fp8(repo_url, safetensors_filename, fp8_format, hf_token, new_repo_id, private_repo, progress=gr.Progress()):
68
- if not all([repo_url, safetensors_filename, fp8_format, hf_token, new_repo_id]):
69
- return None, "❌ Error: Please fill in all fields.", ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  if not re.match(r"^[a-zA-Z0-9._-]+/[a-zA-Z0-9._-]+$", new_repo_id):
72
- return None, "❌ Error: Invalid repository ID format. Use 'username/model-name'.", ""
73
 
74
- temp_dir = tempfile.mkdtemp()
75
  output_dir = tempfile.mkdtemp()
76
 
77
  try:
78
- # Authenticate
79
- progress(0.05, desc="Logging into Hugging Face...")
80
- api = HfApi(token=hf_token)
81
- user_info = api.whoami()
82
- user_name = user_info['name']
83
- progress(0.1, desc=f"Logged in as {user_name}.")
84
-
85
- # Parse source repo
86
- clean_url = repo_url.strip().rstrip("/")
87
- if "huggingface.co" not in clean_url:
88
- return None, "❌ Source must be a Hugging Face model repo.", ""
89
- src_repo_id = clean_url.replace("https://huggingface.co/", "")
90
-
91
- # Download specified safetensors file
92
- progress(0.15, desc=f"Downloading {safetensors_filename}...")
93
- safetensors_path = hf_hub_download(
94
- repo_id=src_repo_id,
95
  filename=safetensors_filename,
96
- cache_dir=temp_dir,
97
- token=hf_token
 
98
  )
99
  progress(0.25, desc="Download complete.")
100
 
101
- # Convert to FP8
102
  success, msg = convert_safetensors_to_fp8(safetensors_path, output_dir, fp8_format, progress)
103
  if not success:
104
  return None, f"❌ Conversion failed: {msg}", ""
105
 
106
- # Create new repo
107
- progress(0.92, desc="Creating new repository...")
108
- api.create_repo(
109
- repo_id=new_repo_id,
110
- private=private_repo,
111
- repo_type="model",
112
- exist_ok=True
 
 
 
 
113
  )
114
 
115
- # Generate README
116
  base_name = os.path.splitext(safetensors_filename)[0]
117
  fp8_filename = f"{base_name}-fp8-{fp8_format}.safetensors"
118
  readme = f"""---
@@ -128,77 +236,90 @@ tags:
128
 
129
  # FP8 Pruned Model ({fp8_format.upper()})
130
 
131
- Converted from: [`{src_repo_id}`](https://huggingface.co/{src_repo_id})
132
  File: `{safetensors_filename}` β†’ `{fp8_filename}`
133
 
134
  Quantization: **FP8 ({fp8_format.upper()})**
135
- Converted by: {user_name}
136
- Date: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
137
 
138
- > ⚠️ FP8 models require PyTorch β‰₯ 2.1 and compatible hardware (e.g., NVIDIA Ada/Hopper) for full acceleration. May fall back to FP16 on older GPUs.
139
  """
140
- with open(os.path.join(output_dir, "README.md"), "w") as f:
 
141
  f.write(readme)
142
 
143
- # Upload
144
- progress(0.95, desc="Uploading to Hugging Face Hub...")
145
- api.upload_folder(
146
- repo_id=new_repo_id,
147
- folder_path=output_dir,
148
- repo_type="model",
149
- token=hf_token,
150
- commit_message=f"Upload FP8 ({fp8_format}) pruned safetensors model"
151
- )
152
 
153
  progress(1.0, desc="βœ… Done!")
154
  result_html = f"""
155
  βœ… Success!
156
- Your FP8 ({fp8_format}) model is uploaded to: [{new_repo_id}](https://huggingface.co/{new_repo_id})
157
- Visibility: {'Private' if private_repo else 'Public'}
158
  """
159
  return gr.HTML(result_html), "βœ… FP8 conversion and upload successful!", ""
160
 
161
  except Exception as e:
162
  return None, f"❌ Error: {str(e)}", ""
163
  finally:
164
- shutil.rmtree(temp_dir, ignore_errors=True)
 
165
  shutil.rmtree(output_dir, ignore_errors=True)
166
 
167
  # --- Gradio UI ---
168
- with gr.Blocks(title="Safetensors β†’ FP8 Pruner") as demo:
169
  gr.Markdown("# πŸ”„ Safetensors to FP8 Pruner")
170
- gr.Markdown("Converts any `.safetensors` file from a Hugging Face model repo to **FP8 (E4M3FN or E5M2)** for compact storage and faster inference.")
171
 
172
  with gr.Row():
173
  with gr.Column():
 
 
 
 
 
174
  repo_url = gr.Textbox(
175
- label="Source Model Repository URL",
176
- placeholder="https://huggingface.co/Yabo/FramePainter",
177
- info="Hugging Face model repo containing your safetensors file"
178
  )
179
  safetensors_filename = gr.Textbox(
180
  label="Safetensors Filename",
181
- placeholder="unet_diffusion_pytorch_model.safetensors",
182
- info="Name of the .safetensors file in the repo"
183
  )
184
  fp8_format = gr.Radio(
185
  choices=["e4m3fn", "e5m2"],
186
  value="e5m2",
187
  label="FP8 Format",
188
- info="E5M2 has wider dynamic range; E4M3FN has higher precision near zero."
189
  )
190
  hf_token = gr.Textbox(
191
- label="Hugging Face Token",
 
 
 
 
192
  type="password",
193
- info="Write-access token from https://huggingface.co/settings/tokens"
194
  )
195
  with gr.Column():
 
 
 
 
 
196
  new_repo_id = gr.Textbox(
197
  label="New Repository ID",
198
- placeholder="your-username/my-model-fp8",
199
- info="Format: username/model-name"
200
  )
201
- private_repo = gr.Checkbox(label="Make Private", value=False)
202
 
203
  convert_btn = gr.Button("πŸš€ Convert & Upload", variant="primary")
204
 
@@ -208,16 +329,26 @@ with gr.Blocks(title="Safetensors β†’ FP8 Pruner") as demo:
208
 
209
  convert_btn.click(
210
  fn=process_and_upload_fp8,
211
- inputs=[repo_url, safetensors_filename, fp8_format, hf_token, new_repo_id, private_repo],
 
 
 
 
 
 
 
 
 
 
212
  outputs=[repo_link_output, status_output],
213
  show_progress=True
214
  )
215
 
216
  gr.Examples(
217
  examples=[
218
- ["https://huggingface.co/Yabo/FramePainter", "unet_diffusion_pytorch_model.safetensors", "e5m2"]
219
  ],
220
- inputs=[repo_url, safetensors_filename, fp8_format]
221
  )
222
 
223
  demo.launch()
 
10
  from safetensors.torch import load_file, save_file
11
  import torch
12
 
13
+ # Optional ModelScope integration
14
+ try:
15
+ from modelscope.hub.snapshot_download import snapshot_download as ms_snapshot_download
16
+ from modelscope.hub.file_download import model_file_download as ms_file_download
17
+ from modelscope.hub.api import HubApi as ModelScopeApi
18
+ MODELScope_AVAILABLE = True
19
+ except ImportError:
20
+ MODELScope_AVAILABLE = False
21
+
22
+ # --- Conversion Function: Safetensors β†’ FP8 Safetensors ---
23
  def convert_safetensors_to_fp8(safetensors_path, output_dir, fp8_format, progress=gr.Progress()):
 
 
 
 
24
  progress(0.1, desc="Starting FP8 conversion...")
25
 
26
  try:
 
27
  def read_safetensors_metadata(path):
28
  with open(path, 'rb') as f:
29
  header_size = int.from_bytes(f.read(8), 'little')
 
34
  metadata = read_safetensors_metadata(safetensors_path)
35
  progress(0.3, desc="Loaded model metadata.")
36
 
 
37
  state_dict = load_file(safetensors_path)
38
  progress(0.5, desc="Loaded model weights.")
39
 
 
40
  if fp8_format == "e5m2":
41
  fp8_dtype = torch.float8_e5m2
42
+ else:
43
  fp8_dtype = torch.float8_e4m3fn
44
 
 
45
  sd_pruned = {}
46
  total = len(state_dict)
47
  for i, key in enumerate(state_dict):
48
  progress(0.5 + 0.4 * (i / total), desc=f"Converting tensor {i+1}/{total} to FP8 ({fp8_format})...")
 
49
  if state_dict[key].dtype in [torch.float16, torch.float32, torch.bfloat16]:
50
  sd_pruned[key] = state_dict[key].to(fp8_dtype)
51
  else:
52
+ sd_pruned[key] = state_dict[key]
53
 
 
54
  base_name = os.path.splitext(os.path.basename(safetensors_path))[0]
55
  output_path = os.path.join(output_dir, f"{base_name}-fp8-{fp8_format}.safetensors")
56
  save_file(sd_pruned, output_path, metadata={"format": "pt", "fp8_format": fp8_format, **metadata})
 
62
  except Exception as e:
63
  return False, str(e)
64
 
65
# --- Source download helper ---
def download_safetensors_file(
    source_type,
    repo_url,
    filename,
    hf_token=None,
    modelscope_token=None,
    progress=gr.Progress()
):
    """Download a single .safetensors file from a source hub.

    Args:
        source_type: "huggingface" or "modelscope".
        repo_url: HF repo URL (must contain "huggingface.co"), a
            modelscope.cn URL, or a bare ModelScope model id.
        filename: Path of the .safetensors file inside the repo.
        hf_token: Optional Hugging Face access token.
        modelscope_token: Optional ModelScope access token.
        progress: Gradio progress handle (currently unused here).

    Returns:
        Tuple (safetensors_path, temp_dir). The caller owns temp_dir
        and must remove it when done.

    Raises:
        ValueError: Unknown source type or malformed HF URL.
        ImportError: ModelScope requested but not installed.
    """
    temp_dir = tempfile.mkdtemp()
    try:
        if source_type == "huggingface":
            clean_url = repo_url.strip().rstrip("/")
            if "huggingface.co" not in clean_url:
                raise ValueError("Invalid Hugging Face URL")
            src_repo_id = clean_url.replace("https://huggingface.co/", "")
            safetensors_path = hf_hub_download(
                repo_id=src_repo_id,
                filename=filename,
                cache_dir=temp_dir,
                token=hf_token
            )
        elif source_type == "modelscope":
            if not MODELScope_AVAILABLE:
                raise ImportError("ModelScope not installed. Install with: pip install modelscope")
            clean_url = repo_url.strip().rstrip("/")
            if "modelscope.cn" in clean_url:
                # Last two path segments form the "namespace/model" id.
                src_repo_id = "/".join(clean_url.split("/")[-2:])
            else:
                # Treat the raw input as a bare ModelScope model id.
                src_repo_id = repo_url.strip()
            download_kwargs = {"model_id": src_repo_id, "file_path": filename}
            if modelscope_token:
                # NOTE(review): this mutates a process-wide env var, only on the
                # token path, and never restores it — confirm that is intended.
                os.environ["MODELSCOPE_CACHE"] = temp_dir
                download_kwargs["token"] = modelscope_token
            safetensors_path = ms_file_download(**download_kwargs)
        else:
            raise ValueError("Unknown source type")

        return safetensors_path, temp_dir
    except Exception:
        # Clean up the scratch dir on any failure, then re-raise with the
        # original traceback intact.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
114
+
115
# --- Upload helper ---
def upload_to_target(
    target_type,
    new_repo_id,
    output_dir,
    fp8_format,
    hf_token=None,
    modelscope_token=None,
    private_repo=False,
    progress=gr.Progress()
):
    """Push the contents of output_dir to the chosen hub.

    Args:
        target_type: "huggingface" or "modelscope".
        new_repo_id: Destination repo id ("username/model-name").
        output_dir: Local folder whose files are uploaded.
        fp8_format: FP8 variant name, used in the commit message.
        hf_token: Required when target_type is "huggingface".
        modelscope_token: Optional; used to log in to ModelScope.
        private_repo: Create the HF repo as private (HF only).
        progress: Gradio progress handle (currently unused here).

    Returns:
        The public URL of the destination repository.

    Raises:
        ValueError: Missing HF token or unknown target type.
        ImportError: ModelScope requested but not installed.
    """
    if target_type == "huggingface":
        if not hf_token:
            raise ValueError("Hugging Face token required")
        hub = HfApi(token=hf_token)
        hub.create_repo(
            repo_id=new_repo_id,
            private=private_repo,
            repo_type="model",
            exist_ok=True
        )
        hub.upload_folder(
            repo_id=new_repo_id,
            folder_path=output_dir,
            repo_type="model",
            token=hf_token,
            commit_message=f"Upload FP8 ({fp8_format}) model"
        )
        return f"https://huggingface.co/{new_repo_id}"

    if target_type == "modelscope":
        if not MODELScope_AVAILABLE:
            raise ImportError("ModelScope not installed")
        ms_hub = ModelScopeApi()
        if modelscope_token:
            ms_hub.login(modelscope_token)
        # ModelScope requires model_type and license
        ms_hub.push_model(
            model_id=new_repo_id,
            model_dir=output_dir,
            commit_message=f"Upload FP8 ({fp8_format}) model"
        )
        return f"https://modelscope.cn/models/{new_repo_id}"

    raise ValueError("Unknown target type")
161
+
162
  # --- Main Processing Function ---
163
+ def process_and_upload_fp8(
164
+ source_type,
165
+ repo_url,
166
+ safetensors_filename,
167
+ fp8_format,
168
+ target_type,
169
+ new_repo_id,
170
+ hf_token,
171
+ modelscope_token,
172
+ private_repo,
173
+ progress=gr.Progress()
174
+ ):
175
+ required_fields = [repo_url, safetensors_filename, new_repo_id]
176
+ if source_type == "huggingface":
177
+ required_fields.append(hf_token)
178
+ if target_type == "huggingface":
179
+ required_fields.append(hf_token)
180
+ if target_type == "modelscope" and modelscope_token:
181
+ required_fields.append(modelscope_token)
182
+
183
+ if not all(required_fields):
184
+ return None, "❌ Error: Please fill in all required fields.", ""
185
 
186
  if not re.match(r"^[a-zA-Z0-9._-]+/[a-zA-Z0-9._-]+$", new_repo_id):
187
+ return None, "❌ Invalid repository ID format. Use 'username/model-name'.", ""
188
 
189
+ temp_dir = None
190
  output_dir = tempfile.mkdtemp()
191
 
192
  try:
193
+ # Authenticate & download
194
+ progress(0.05, desc="Authenticating and downloading...")
195
+ safetensors_path, temp_dir = download_safetensors_file(
196
+ source_type=source_type,
197
+ repo_url=repo_url,
 
 
 
 
 
 
 
 
 
 
 
 
198
  filename=safetensors_filename,
199
+ hf_token=hf_token,
200
+ modelscope_token=modelscope_token,
201
+ progress=progress
202
  )
203
  progress(0.25, desc="Download complete.")
204
 
205
+ # Convert
206
  success, msg = convert_safetensors_to_fp8(safetensors_path, output_dir, fp8_format, progress)
207
  if not success:
208
  return None, f"❌ Conversion failed: {msg}", ""
209
 
210
+ # Upload
211
+ progress(0.92, desc="Uploading model...")
212
+ repo_url_final = upload_to_target(
213
+ target_type=target_type,
214
+ new_repo_id=new_repo_id,
215
+ output_dir=output_dir,
216
+ fp8_format=fp8_format,
217
+ hf_token=hf_token,
218
+ modelscope_token=modelscope_token,
219
+ private_repo=private_repo,
220
+ progress=progress
221
  )
222
 
223
+ # README
224
  base_name = os.path.splitext(safetensors_filename)[0]
225
  fp8_filename = f"{base_name}-fp8-{fp8_format}.safetensors"
226
  readme = f"""---
 
236
 
237
  # FP8 Pruned Model ({fp8_format.upper()})
238
 
239
+ Converted from: `{repo_url}`
240
  File: `{safetensors_filename}` β†’ `{fp8_filename}`
241
 
242
  Quantization: **FP8 ({fp8_format.upper()})**
243
+ Converted on: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
 
244
 
245
+ > ⚠️ Requires PyTorch β‰₯ 2.1 and compatible hardware for FP8 acceleration.
246
  """
247
+ readme_path = os.path.join(output_dir, "README.md")
248
+ with open(readme_path, "w") as f:
249
  f.write(readme)
250
 
251
+ # Re-upload README if needed (for ModelScope, already included; for HF, upload separately)
252
+ if target_type == "huggingface":
253
+ HfApi(token=hf_token).upload_file(
254
+ path_or_fileobj=readme_path,
255
+ path_in_repo="README.md",
256
+ repo_id=new_repo_id,
257
+ repo_type="model",
258
+ token=hf_token
259
+ )
260
 
261
  progress(1.0, desc="βœ… Done!")
262
  result_html = f"""
263
  βœ… Success!
264
+ Your FP8 model is uploaded to: <a href="{repo_url_final}" target="_blank">{new_repo_id}</a>
265
+ Source: {source_type.title()} β†’ Target: {target_type.title()}
266
  """
267
  return gr.HTML(result_html), "βœ… FP8 conversion and upload successful!", ""
268
 
269
  except Exception as e:
270
  return None, f"❌ Error: {str(e)}", ""
271
  finally:
272
+ if temp_dir:
273
+ shutil.rmtree(temp_dir, ignore_errors=True)
274
  shutil.rmtree(output_dir, ignore_errors=True)
275
 
276
  # --- Gradio UI ---
277
+ with gr.Blocks(title="Safetensors β†’ FP8 Pruner (HF + ModelScope)") as demo:
278
  gr.Markdown("# πŸ”„ Safetensors to FP8 Pruner")
279
+ gr.Markdown("Convert `.safetensors` models to **FP8** and upload to **Hugging Face** or **ModelScope**.")
280
 
281
  with gr.Row():
282
  with gr.Column():
283
+ source_type = gr.Radio(
284
+ choices=["huggingface", "modelscope"],
285
+ value="huggingface",
286
+ label="Source Platform"
287
+ )
288
  repo_url = gr.Textbox(
289
+ label="Source Repository URL",
290
+ placeholder="e.g., https://huggingface.co/Yabo/FramePainter OR your-modelscope-id",
291
+ info="Hugging Face URL or ModelScope model ID"
292
  )
293
  safetensors_filename = gr.Textbox(
294
  label="Safetensors Filename",
295
+ placeholder="unet_diffusion_pytorch_model.safetensors"
 
296
  )
297
  fp8_format = gr.Radio(
298
  choices=["e4m3fn", "e5m2"],
299
  value="e5m2",
300
  label="FP8 Format",
301
+ info="E5M2: wider range; E4M3FN: better near-zero precision"
302
  )
303
  hf_token = gr.Textbox(
304
+ label="Hugging Face Token (if using HF)",
305
+ type="password"
306
+ )
307
+ modelscope_token = gr.Textbox(
308
+ label="ModelScope Token (optional)",
309
  type="password",
310
+ visible=MODELScope_AVAILABLE
311
  )
312
  with gr.Column():
313
+ target_type = gr.Radio(
314
+ choices=["huggingface", "modelscope"],
315
+ value="huggingface",
316
+ label="Target Platform"
317
+ )
318
  new_repo_id = gr.Textbox(
319
  label="New Repository ID",
320
+ placeholder="your-username/my-model-fp8"
 
321
  )
322
+ private_repo = gr.Checkbox(label="Make Private (HF only)", value=False)
323
 
324
  convert_btn = gr.Button("πŸš€ Convert & Upload", variant="primary")
325
 
 
329
 
330
  convert_btn.click(
331
  fn=process_and_upload_fp8,
332
+ inputs=[
333
+ source_type,
334
+ repo_url,
335
+ safetensors_filename,
336
+ fp8_format,
337
+ target_type,
338
+ new_repo_id,
339
+ hf_token,
340
+ modelscope_token,
341
+ private_repo
342
+ ],
343
  outputs=[repo_link_output, status_output],
344
  show_progress=True
345
  )
346
 
347
  gr.Examples(
348
  examples=[
349
+ ["huggingface", "https://huggingface.co/Yabo/FramePainter", "unet_diffusion_pytorch_model.safetensors", "e5m2", "huggingface"]
350
  ],
351
+ inputs=[source_type, repo_url, safetensors_filename, fp8_format, target_type]
352
  )
353
 
354
  demo.launch()