codemichaeld committed
Commit fdd626d · verified · 1 Parent(s): 1b040ff

Update app.py

Files changed (1):
  1. app.py +61 -32

app.py CHANGED
@@ -9,6 +9,7 @@ from pathlib import Path
 from huggingface_hub import HfApi, hf_hub_download
 from safetensors.torch import load_file, save_file
 import torch
+import torch.nn.functional as F
 
 try:
     from modelscope.hub.file_download import model_file_download as ms_file_download
@@ -17,8 +18,16 @@ try:
 except ImportError:
     MODELScope_AVAILABLE = False
 
-def convert_safetensors_to_fp8_with_delta(safetensors_path, output_dir, fp8_format, progress=gr.Progress()):
-    progress(0.1, desc="Starting FP8 conversion with delta...")
+def low_rank_decomposition(weight, rank=64):
+    if weight.ndim != 2:
+        return None, None
+    U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
+    U = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank]))
+    Vh = torch.diag(torch.sqrt(S[:rank])) @ Vh[:rank, :]
+    return U.contiguous(), Vh.contiguous()
+
+def convert_safetensors_to_fp8_with_lora(safetensors_path, output_dir, fp8_format, lora_rank=64, progress=gr.Progress()):
+    progress(0.1, desc="Starting FP8 conversion with LoRA extraction...")
 
     try:
         def read_safetensors_metadata(path):
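The added `low_rank_decomposition` is a truncated SVD that splits the singular values evenly, as `sqrt(S)`, between the two factors, so their product is the best rank-r approximation of the input in the Frobenius norm (Eckart-Young). A minimal standalone sketch of the same split, with illustrative shapes:

```python
import torch

# Same split as low_rank_decomposition: fold sqrt(S) into both factors.
W = torch.randn(512, 768)
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
rank = 64
A = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank]))   # (512, rank)
B = torch.diag(torch.sqrt(S[:rank])) @ Vh[:rank, :]  # (rank, 768)
print(torch.linalg.norm(W - A @ B) / torch.linalg.norm(W))  # relative rank-64 error
```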
@@ -40,31 +49,41 @@ def convert_safetensors_to_fp8_with_delta(safetensors_path, output_dir, fp8_form
         fp8_dtype = torch.float8_e4m3fn
 
         sd_fp8 = {}
-        sd_delta = {}
+        lora_weights = {}
         total = len(state_dict)
+        lora_keys = []
 
         for i, key in enumerate(state_dict):
             progress(0.4 + 0.4 * (i / total), desc=f"Processing {i+1}/{total}...")
             weight = state_dict[key]
             if weight.dtype in [torch.float16, torch.float32, torch.bfloat16]:
                 fp8_weight = weight.to(fp8_dtype)
-                fp8_recon = fp8_weight.to(weight.dtype)
-                delta = weight - fp8_recon
                 sd_fp8[key] = fp8_weight
-                sd_delta[f"delta.{key}"] = delta
+
+                # Attempt LoRA decomposition of the FP8 quantization residual (2D tensors only)
+                if weight.ndim == 2 and min(weight.shape) > lora_rank:
+                    try:
+                        U, V = low_rank_decomposition(weight.float() - fp8_weight.to(torch.float32), rank=lora_rank)
+                        if U is not None and V is not None:
+                            lora_weights[f"lora_A.{key}"] = U.to(torch.float16)
+                            lora_weights[f"lora_B.{key}"] = V.to(torch.float16)
+                            lora_keys.append(key)
+                    except Exception:
+                        pass
             else:
                 sd_fp8[key] = weight
 
         base_name = os.path.splitext(os.path.basename(safetensors_path))[0]
         fp8_path = os.path.join(output_dir, f"{base_name}-fp8-{fp8_format}.safetensors")
-        delta_path = os.path.join(output_dir, f"{base_name}-fp8-delta.safetensors")
+        lora_path = os.path.join(output_dir, f"{base_name}-lora-r{lora_rank}.safetensors")
 
         save_file(sd_fp8, fp8_path, metadata={"format": "pt", "fp8_format": fp8_format, **metadata})
-        save_file(sd_delta, delta_path, metadata={"format": "pt", "source": "fp8_delta", "fp8_format": fp8_format})
+        if lora_weights:
+            save_file(lora_weights, lora_path, metadata={"format": "pt", "lora_rank": str(lora_rank)})
 
-        progress(0.9, desc="Saved FP8 and delta files.")
-        progress(1.0, desc="✅ FP8 + delta generation complete!")
-        return True, f"FP8 ({fp8_format}) and delta saved."
+        progress(0.9, desc="Saved FP8 and LoRA files.")
+        progress(1.0, desc="✅ FP8 + LoRA extraction complete!")
+        return True, f"FP8 ({fp8_format}) and rank-{lora_rank} LoRA saved."
 
     except Exception as e:
         return False, str(e)
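Each eligible 2D weight is thus stored as an FP8 cast plus rank-r factors of the cast's residual; how much the factors recover depends on how concentrated the residual's spectrum is. A small round-trip sketch on one matrix, assuming a PyTorch build with float8 support (≥ 2.1):

```python
import torch

W = torch.randn(1024, 1024)
W_fp8 = W.to(torch.float8_e4m3fn)        # lossy 8-bit cast
residual = W - W_fp8.to(torch.float32)   # what the cast discarded

# Rank-64 factors of the residual, same split as low_rank_decomposition
U, S, Vh = torch.linalg.svd(residual, full_matrices=False)
A = U[:, :64] @ torch.diag(torch.sqrt(S[:64]))
B = torch.diag(torch.sqrt(S[:64])) @ Vh[:64, :]

err_fp8 = torch.linalg.norm(residual)
err_both = torch.linalg.norm(W - (W_fp8.to(torch.float32) + A @ B))
print(err_fp8.item(), err_both.item())   # err_both <= err_fp8 by Eckart-Young
```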
@@ -130,6 +149,7 @@ def process_and_upload_fp8(
     repo_url,
     safetensors_filename,
     fp8_format,
+    lora_rank,
     target_type,
     new_repo_id,
     hf_token,
@@ -154,8 +174,10 @@ def process_and_upload_fp8(
             source_type, repo_url, safetensors_filename, hf_token, progress
         )
 
-        progress(0.25, desc="Converting to FP8 with delta...")
-        success, msg = convert_safetensors_to_fp8_with_delta(safetensors_path, output_dir, fp8_format, progress)
+        progress(0.25, desc="Converting to FP8 with LoRA extraction...")
+        success, msg = convert_safetensors_to_fp8_with_lora(
+            safetensors_path, output_dir, fp8_format, lora_rank, progress
+        )
         if not success:
             return None, f"❌ Conversion failed: {msg}", ""
 
@@ -165,42 +187,47 @@ def process_and_upload_fp8(
         )
 
         base_name = os.path.splitext(safetensors_filename)[0]
+        lora_filename = f"{base_name}-lora-r{lora_rank}.safetensors"
         readme = f"""---
 library_name: diffusers
 tags:
 - fp8
 - safetensors
-- delta-compensation
+- lora
+- low-rank
 - diffusion
 - converted-by-gradio
 ---
 
-# FP8 Model with Delta Compensation
+# FP8 Model with Low-Rank LoRA
 
 - **Source**: `{repo_url}`
 - **File**: `{safetensors_filename}`
 - **FP8 Format**: `{fp8_format.upper()}`
-- **Delta File**: `{base_name}-fp8-delta.safetensors`
+- **LoRA Rank**: {lora_rank}
+- **LoRA File**: `{lora_filename}`
 
 ## Usage (Inference)
 
-To restore near-original precision:
-
 ```python
-import torch
 from safetensors.torch import load_file
+import torch
 
+# Load FP8 model
 fp8_state = load_file("{base_name}-fp8-{fp8_format}.safetensors")
-delta_state = load_file("{base_name}-fp8-delta.safetensors")
+lora_state = load_file("{lora_filename}")
 
-restored_state = {{}}
+# Reconstruct approximate original weights
+reconstructed = {{}}
 for key in fp8_state:
-    if f"delta.{{key}}" in delta_state:
+    if f"lora_A.{{key}}" in lora_state and f"lora_B.{{key}}" in lora_state:
+        A = lora_state[f"lora_A.{{key}}"].to(torch.float32)
+        B = lora_state[f"lora_B.{{key}}"].to(torch.float32)
+        lora_weight = A @ B  # (out, rank) @ (rank, in) -> (out, in)
         fp8_weight = fp8_state[key].to(torch.float32)
-        delta = delta_state[f"delta.{{key}}"]
-        restored_state[key] = fp8_weight + delta
+        reconstructed[key] = fp8_weight + lora_weight
     else:
-        restored_state[key] = fp8_state[key].to(torch.float32)
+        reconstructed[key] = fp8_state[key].to(torch.float32)
 ```
 
 > Requires PyTorch ≥ 2.1 for FP8 support.
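Rather than materializing `reconstructed`, the same correction can also be applied LoRA-style at inference, keeping the FP8 base weight and adding the low-rank branch to each matmul. A hypothetical helper for a single linear layer, using the `lora_A`/`lora_B` layout the converter writes (this helper is not part of the commit):

```python
import torch

def linear_with_lora_correction(x, fp8_weight, A, B, bias=None):
    # Base path: dequantize the FP8 weight, shaped (out, in), for the matmul.
    y = x @ fp8_weight.to(torch.float32).T
    # Residual path: x @ (A @ B).T as two thin matmuls; A: (out, rank), B: (rank, in).
    y = y + (x @ B.to(torch.float32).T) @ A.to(torch.float32).T
    return y if bias is None else y + bias
```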
@@ -221,9 +248,9 @@ for key in fp8_state:
         result_html = f"""
 ✅ Success!
 Model uploaded to: <a href="{repo_url_final}" target="_blank">{new_repo_id}</a>
-Includes: FP8 model + delta compensation file.
+Includes: FP8 model + rank-{lora_rank} LoRA.
 """
-        return gr.HTML(result_html), "✅ FP8 + delta upload successful!", ""
+        return gr.HTML(result_html), "✅ FP8 + LoRA upload successful!", ""
 
     except Exception as e:
         return None, f"❌ Error: {str(e)}", ""
@@ -232,9 +259,9 @@ Includes: FP8 model + delta compensation file.
         shutil.rmtree(temp_dir, ignore_errors=True)
         shutil.rmtree(output_dir, ignore_errors=True)
 
-with gr.Blocks(title="FP8 + Delta Converter (HF ↔ ModelScope)") as demo:
-    gr.Markdown("# 🔄 FP8 Pruner with Delta Compensation")
-    gr.Markdown("Convert `.safetensors` → **FP8** + **delta file** for precision recovery. Supports Hugging Face ↔ ModelScope.")
+with gr.Blocks(title="FP8 + LoRA Extractor (HF ↔ ModelScope)") as demo:
+    gr.Markdown("# 🔄 FP8 Pruner with Low-Rank LoRA Extraction")
+    gr.Markdown("Convert `.safetensors` → **FP8** + **compact LoRA** for precision recovery. Supports Hugging Face ↔ ModelScope.")
 
     with gr.Row():
         with gr.Column():
@@ -242,6 +269,7 @@ with gr.Blocks(title="FP8 + Delta Converter (HF ↔ ModelScope)") as demo:
             repo_url = gr.Textbox(label="Repo URL or ID", placeholder="https://huggingface.co/... or modelscope-id")
             safetensors_filename = gr.Textbox(label="Filename", placeholder="model.safetensors")
             fp8_format = gr.Radio(["e4m3fn", "e5m2"], value="e5m2", label="FP8 Format")
+            lora_rank = gr.Slider(minimum=8, maximum=256, step=8, value=64, label="LoRA Rank")
             hf_token = gr.Textbox(label="HF Token (only if using HF)", type="password")
             modelscope_token = gr.Textbox(label="ModelScope Token (optional)", type="password", visible=MODELScope_AVAILABLE)
         with gr.Column():
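The new `lora_rank` slider trades LoRA file size against recovered precision; the factors cost `(out + in) * rank` fp16 elements per layer. A rough, illustrative estimate:

```python
def lora_bytes(out_dim, in_dim, rank, bytes_per_elem=2):
    # Two fp16 factors: (out_dim, rank) and (rank, in_dim).
    return (out_dim * rank + rank * in_dim) * bytes_per_elem

# For a 4096x4096 layer the full fp16 residual is ~32 MiB;
# rank-64 factors come to ~1 MiB.
print(lora_bytes(4096, 4096, 64) / 2**20)  # -> 1.0
```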
@@ -260,6 +288,7 @@ with gr.Blocks(title="FP8 + Delta Converter (HF ↔ ModelScope)") as demo:
             repo_url,
             safetensors_filename,
             fp8_format,
+            lora_rank,
             target_type,
             new_repo_id,
             hf_token,
@@ -272,9 +301,9 @@ with gr.Blocks(title="FP8 + Delta Converter (HF ↔ ModelScope)") as demo:
 
     gr.Examples(
         examples=[
-            ["huggingface", "https://huggingface.co/Yabo/FramePainter/tree/main", "unet_diffusion_pytorch_model.safetensors", "e5m2", "modelscope"]
+            ["huggingface", "https://huggingface.co/Yabo/FramePainter/tree/main", "unet_diffusion_pytorch_model.safetensors", "e5m2", 64, "modelscope"]
         ],
-        inputs=[source_type, repo_url, safetensors_filename, fp8_format, target_type]
+        inputs=[source_type, repo_url, safetensors_filename, fp8_format, lora_rank, target_type]
     )
 
 demo.launch()
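The converter can also be driven without the Gradio UI. A hedged sketch, assuming it runs in app.py's module scope, with a no-op callable standing in for `gr.Progress` and illustrative paths:

```python
import os

os.makedirs("out", exist_ok=True)
ok, msg = convert_safetensors_to_fp8_with_lora(
    "model.safetensors",                     # illustrative input path
    "out",
    fp8_format="e4m3fn",
    lora_rank=64,
    progress=lambda *args, **kwargs: None,   # skip Gradio progress tracking
)
print(ok, msg)
```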
 