codemichaeld committed on
Commit 4f1244e · verified · 1 Parent(s): 311ef01

Update app.py

Files changed (1)
  1. app.py +295 -153
app.py CHANGED
@@ -17,48 +17,92 @@ except ImportError:
    MODELScope_AVAILABLE = False

def low_rank_decomposition(weight, rank=64):
-    """Standard LoRA decomposition for 2D tensors only."""
    if weight.ndim != 2:
        return None, None

    try:
-        U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
-        U = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank]))
-        Vh = torch.diag(torch.sqrt(S[:rank])) @ Vh[:rank, :]
-        return U.contiguous(), Vh.contiguous()
-    except Exception:
        return None, None

def extract_correction_factors(original_weight, fp8_weight):
    """Extract per-channel/tensor correction factors (difference method)."""
    with torch.no_grad():
        orig = original_weight.float()
        quant = fp8_weight.float()
        error = orig - quant

        error_norm = torch.norm(error)
        orig_norm = torch.norm(orig)
        if orig_norm > 1e-6 and error_norm / orig_norm < 0.01:
            return None

-        # For 4D tensors (VAE/conv layers)
        if orig.ndim == 4:
            channel_dim = 0
            channel_mean = error.mean(dim=tuple(i for i in range(1, orig.ndim)), keepdim=True)
            return channel_mean.to(original_weight.dtype)

        # For 2D tensors (linear layers)
        elif orig.ndim == 2:
            row_mean = error.mean(dim=1, keepdim=True)
            return row_mean.to(original_weight.dtype)

-        # For 1D tensors (bias, etc.)
        else:
            return error.mean().to(original_weight.dtype)

-def convert_safetensors_to_fp8_with_recovery(safetensors_path, output_dir, fp8_format, recovery_config, progress=gr.Progress()):
-    progress(0.1, desc="Starting FP8 conversion with precision recovery...")

    try:
        def read_safetensors_metadata(path):
            with open(path, 'rb') as f:
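The removed factorization splits the singular values as √S into each factor; the replacement further down keeps `Vh` as `W_A` and folds all of `S` into `W_B`. Both are truncations of the same SVD, as this standalone sketch shows (toy shapes and names, not app.py's code path):

```python
# Minimal sketch: sqrt-split vs. full-S split of a truncated SVD.
import torch

torch.manual_seed(0)
W = torch.randn(256, 128)
rank = 16

U, S, Vh = torch.linalg.svd(W.float(), full_matrices=False)

# sqrt-split (removed code): each factor carries sqrt(S)
A1 = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank]))   # [256, rank]
B1 = torch.diag(torch.sqrt(S[:rank])) @ Vh[:rank, :]  # [rank, 128]

# full-S split (replacement code): one factor carries all of S
B2 = U[:, :rank] @ torch.diag(S[:rank])               # [256, rank]
A2 = Vh[:rank, :]                                     # [rank, 128]

# both reconstruct the same best rank-16 approximation of W
assert torch.allclose(A1 @ B1, B2 @ A2, atol=1e-4)
print("max |W - W_r|:", (W - A1 @ B1).abs().max().item())
```

The √S split balances the two factors' norms, which mainly matters for fp16 storage range; the reconstructed product is identical either way.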
@@ -69,122 +113,128 @@ def convert_safetensors_to_fp8_with_recovery(safetensors_path, output_dir, fp8_f

        metadata = read_safetensors_metadata(safetensors_path)
        progress(0.2, desc="Loaded metadata.")
        state_dict = load_file(safetensors_path)
-        progress(0.4, desc="Loaded weights.")

-        if fp8_format == "e5m2":
-            fp8_dtype = torch.float8_e5m2
-        else:
-            fp8_dtype = torch.float8_e4m3fn

        sd_fp8 = {}
        recovery_weights = {}
        stats = {
            "total_layers": len(state_dict),
            "processed_layers": 0,
            "skipped_layers": [],
-            "recovery_type_counts": {"lora": 0, "diff": 0}
        }

        total = len(state_dict)
        for i, key in enumerate(state_dict):
-            progress(0.4 + 0.4 * (i / total), desc=f"Processing {i+1}/{total}...")
            weight = state_dict[key]
-            lower_key = key.lower()

            if weight.dtype in [torch.float16, torch.float32, torch.bfloat16]:
                fp8_weight = weight.to(fp8_dtype)
                sd_fp8[key] = fp8_weight
-
-                # Match key against recovery config rules
-                recovery_method = "none"
-                lora_rank = 64
-
-                for rule in recovery_config:
-                    element_pattern = rule.get("element", "").lower()
-                    method = rule.get("method", "none")
-
-                    if element_pattern == "all" or element_pattern in lower_key:
-                        recovery_method = method
-                        if method == "lora":
-                            lora_rank = rule.get("rank", 64)
-                        break
-
-                if recovery_method == "lora" and weight.ndim == 2 and min(weight.shape) > lora_rank:
-                    try:
-                        U, V = low_rank_decomposition(weight, rank=lora_rank)
-                        if U is not None and V is not None:
-                            recovery_weights[f"lora_A.{key}"] = U.to(torch.float16)
-                            recovery_weights[f"lora_B.{key}"] = V.to(torch.float16)
-                            stats["processed_layers"] += 1
-                            stats["recovery_type_counts"]["lora"] += 1
-                    except Exception:
-                        stats["skipped_layers"].append(f"{key}: lora failed")
-
-                elif recovery_method == "diff":
-                    try:
-                        corr = extract_correction_factors(weight, fp8_weight)
-                        if corr is not None:
-                            recovery_weights[f"diff.{key}"] = corr
-                            stats["processed_layers"] += 1
-                            stats["recovery_type_counts"]["diff"] += 1
-                    except Exception:
-                        stats["skipped_layers"].append(f"{key}: diff failed")
-
-                else:
-                    stats["skipped_layers"].append(f"{key}: {recovery_method}")
            else:
                sd_fp8[key] = weight
                stats["skipped_layers"].append(f"{key}: non-float dtype")

        base_name = os.path.splitext(os.path.basename(safetensors_path))[0]
        fp8_path = os.path.join(output_dir, f"{base_name}-fp8-{fp8_format}.safetensors")
-        recovery_path = os.path.join(output_dir, f"{base_name}-recovery.safetensors")
-
        save_file(sd_fp8, fp8_path, metadata={"format": "pt", "fp8_format": fp8_format, **metadata})

        if recovery_weights:
-            save_file(recovery_weights, recovery_path, metadata={
                "format": "pt",
                "fp8_format": fp8_format,
-                "recovery_config": json.dumps(recovery_config),
                "stats": json.dumps(stats)
-            })

        progress(0.9, desc="Saved FP8 and recovery files.")
-        progress(1.0, desc="✅ FP8 + recovery extraction complete!")

-        stats_msg = f"FP8 ({fp8_format}) and recovery saved.\n"
        stats_msg += f"- Total layers: {stats['total_layers']}\n"
-        stats_msg += f"- Processed: {stats['processed_layers']} ({stats['recovery_type_counts']['lora']} LoRA + {stats['recovery_type_counts']['diff']} Diff)\n"

-        if stats["processed_layers"] == 0:
-            stats_msg += "\n⚠️ No recovery weights generated. Check your rules and rank settings."

-        return True, stats_msg, stats

    except Exception as e:
-        return False, str(e), None
-
-def generate_config_from_rules(rules_input):
-    """Parse multi-line rule input into config."""
-    config = []
-    for line in rules_input.strip().split('\n'):
-        line = line.strip()
-        if not line or line.startswith('#'):
-            continue
-        parts = [p.strip() for p in line.split(',')]
-        if len(parts) >= 2:
-            element = parts[0]
-            method = parts[1].lower()
-            rank = 64
-            if method == "lora" and len(parts) >= 3:
-                try:
-                    rank = int(parts[2])
-                except ValueError:
-                    pass
-            config.append({"element": element, "method": method, "rank": rank})
-    return config

def parse_hf_url(url):
    url = url.strip().rstrip("/")
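For reference, the rule text format this commit removes parsed one `pattern, method[, rank]` rule per line, with `#` starting a comment. A re-derived standalone sketch of that behavior (`parse_rules` is an illustrative name, not the deleted function, and its rank handling is slightly simplified):

```python
# What the removed text-rule format parsed to (mechanism re-derived):
# "pattern, method[, rank]" per line; '#' starts a comment line.
def parse_rules(rules_input):
    config = []
    for line in rules_input.strip().split('\n'):
        line = line.strip()
        if not line or line.startswith('#'):
            continue
        parts = [p.strip() for p in line.split(',')]
        if len(parts) >= 2:
            rank = int(parts[2]) if parts[1].lower() == "lora" and len(parts) >= 3 and parts[2].isdigit() else 64
            config.append({"element": parts[0], "method": parts[1].lower(), "rank": rank})
    return config

print(parse_rules("vae, diff\nattn, lora, 128\nall, none"))
# [{'element': 'vae', 'method': 'diff', 'rank': 64},
#  {'element': 'attn', 'method': 'lora', 'rank': 128},
#  {'element': 'all', 'method': 'none', 'rank': 64}]
```

As in the deleted code, a default rank of 64 is attached even to rules that never use it.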
@@ -247,7 +297,7 @@ def process_and_upload_fp8(
    repo_url,
    safetensors_filename,
    fp8_format,
-    recovery_rules,
    target_type,
    new_repo_id,
    hf_token,
@@ -256,15 +306,27 @@
    progress=gr.Progress()
):
    if not re.match(r"^[a-zA-Z0-9._-]+/[a-zA-Z0-9._-]+$", new_repo_id):
-        return None, "❌ Invalid repo ID format. Use 'username/model-name'.", ""
    if source_type == "huggingface" and not hf_token:
-        return None, "❌ Hugging Face token required for source.", ""
    if target_type == "huggingface" and not hf_token:
-        return None, "❌ Hugging Face token required for target.", ""

-    recovery_config = generate_config_from_rules(recovery_rules)
-    if not recovery_config:
-        recovery_config = [{"element": "all", "method": "none"}]

    temp_dir = None
    output_dir = tempfile.mkdtemp()
@@ -274,22 +336,23 @@
            source_type, repo_url, safetensors_filename, hf_token, progress
        )

-        progress(0.25, desc="Converting to FP8 with recovery...")
-        success, msg, stats = convert_safetensors_to_fp8_with_recovery(
-            safetensors_path, output_dir, fp8_format, recovery_config, progress
        )

        if not success:
-            return None, f"❌ Conversion failed: {msg}", ""

        progress(0.9, desc="Uploading...")
        repo_url_final = upload_to_target(
            target_type, new_repo_id, output_dir, fp8_format, hf_token, modelscope_token, private_repo
        )

        base_name = os.path.splitext(safetensors_filename)[0]
-        fp8_filename = f"{base_name}-fp8-{fp8_format}.safetensors"
-        recovery_filename = f"{base_name}-recovery.safetensors"

        readme = f"""---
library_name: diffusers
@@ -300,44 +363,61 @@ tags:
 - mixed-method
 - converted-by-gradio
---
-# FP8 Model with Custom Precision Recovery
-
 - **Source**: `{repo_url}`
- - **File**: `{safetensors_filename}`
 - **FP8 Format**: `{fp8_format.upper()}`
- - **Recovery File**: `{recovery_filename}` (contains both LoRA and Difference weights)

-## Recovery Rules Used
-```
-{recovery_rules}
```

-## Usage
```python
from safetensors.torch import load_file
import torch

fp8_state = load_file("{fp8_filename}")
-recovery_state = load_file("{recovery_filename}")

reconstructed = {{}}
for key in fp8_state:
-    fp8_weight = fp8_state[key].to(torch.float32)

-    # Apply LoRA if present
    if f"lora_A.{{key}}" in recovery_state and f"lora_B.{{key}}" in recovery_state:
        A = recovery_state[f"lora_A.{{key}}"].to(torch.float32)
        B = recovery_state[f"lora_B.{{key}}"].to(torch.float32)
        lora_weight = B @ A
        fp8_weight = fp8_weight + lora_weight

-    # Apply Difference if present
    if f"diff.{{key}}" in recovery_state:
        diff = recovery_state[f"diff.{{key}}"].to(torch.float32)
        fp8_weight = fp8_weight + diff

    reconstructed[key] = fp8_weight
```
"""

        with open(os.path.join(output_dir, "README.md"), "w") as f:
@@ -353,24 +433,39 @@ for key in fp8_state:
        )

        progress(1.0, desc="✅ Done!")
        result_html = f"""
        ✅ Success!
        Model uploaded to: <a href="{repo_url_final}" target="_blank">{new_repo_id}</a>
-        Includes FP8 + custom recovery weights.
        """
-        return gr.HTML(result_html), "✅ FP8 + recovery upload successful!", msg

    except Exception as e:
-        return None, f"❌ Error: {str(e)}", ""

    finally:
        if temp_dir:
            shutil.rmtree(temp_dir, ignore_errors=True)
        shutil.rmtree(output_dir, ignore_errors=True)

-with gr.Blocks(title="FP8 + Custom Recovery Extractor") as demo:
-    gr.Markdown("# 🔄 FP8 Quantizer with Per-Layer Recovery Control")
-    gr.Markdown("Specify **exact recovery method per layer/tensor** using pattern matching. Supports LoRA and Difference methods simultaneously.")

    with gr.Row():
        with gr.Column():
@@ -381,21 +476,39 @@ with gr.Blocks(title="FP8 + Custom Recovery Extractor") as demo:
        with gr.Accordion("FP8 Settings", open=True):
            fp8_format = gr.Radio(["e4m3fn", "e5m2"], value="e5m2", label="FP8 Format")

-        with gr.Accordion("Recovery Rules (Layer/Tensor Level)", open=True):
            gr.Markdown("""
-            Define recovery rules **one per line** in format:
-            `layer_pattern, method [, rank]`

-            - `layer_pattern`: substring to match in weight key (case-insensitive)
-            - `method`: `lora` or `diff` or `none`
-            - `rank`: LoRA rank (only for `lora` method)

-            **Rules are applied in order** – first match wins.
            """)
-            recovery_rules = gr.Textbox(
-                value="vae, diff\nencoder, diff\ndecoder, diff\ntext, lora, 64\nattn, lora, 128\nall, none",
-                lines=8,
-                label="Recovery Rules"
            )

        with gr.Accordion("Authentication", open=False):
@@ -409,6 +522,7 @@ with gr.Blocks(title="FP8 + Custom Recovery Extractor") as demo:

    status_output = gr.Markdown()
    detailed_log = gr.Textbox(label="Processing Log", interactive=False, lines=10)

    convert_btn = gr.Button("🚀 Convert & Upload", variant="primary")
    repo_link_output = gr.HTML()
@@ -420,59 +534,87 @@ with gr.Blocks(title="FP8 + Custom Recovery Extractor") as demo:
            repo_url,
            safetensors_filename,
            fp8_format,
-            recovery_rules,
            target_type,
            new_repo_id,
            hf_token,
            modelscope_token,
            private_repo
        ],
-        outputs=[repo_link_output, status_output, detailed_log],
        show_progress=True
    )

    gr.Examples(
        examples=[
            [
-                "huggingface",
-                "https://huggingface.co/stabilityai/sdxl-vae",
-                "diffusion_pytorch_model.safetensors",
                "e5m2",
-                "vae, diff\nencoder, diff\ndecoder, diff\nall, none",
                "huggingface"
            ],
            [
                "huggingface",
-                "https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/text_encoder",
-                "model.safetensors",
                "e5m2",
-                "text, lora, 64\nemb, lora, 64\nall, none",
                "huggingface"
            ]
        ],
-        inputs=[source_type, repo_url, safetensors_filename, fp8_format, recovery_rules, target_type],
        label="Example Conversions"
    )

    gr.Markdown("""
-    ## 💡 Recovery Strategy Guide

-    ### **Difference Method (Recommended for VAE/Convs)**
-    - Use for: `vae`, `encoder`, `decoder`, `conv` layers
-    - Captures exact quantization error
-    - Works with 4D tensors that LoRA cannot handle

-    ### **LoRA Method (Recommended for Attention/Linear)**
-    - Use for: `text`, `attn`, `mlp`, `transformer` layers
-    - Use rank 32-128 depending on layer importance
-    - Only works on 2D tensors

    ### **Rule Ordering Tips**
-    - Put specific patterns first (`vae.encoder`) before general ones (`vae`)
-    - End with `all, none` to set default behavior
-    - Layer names are **case-insensitive**

-    > This implementation restores the successful VAE difference method while adding full per-layer control.
    """)

demo.launch()
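The difference method in both versions stores one mean-error value per output channel (or per row, or per tensor), broadcast back at load time. A standalone sketch on a toy 4D conv weight (shapes illustrative; FP8 dtypes need PyTorch 2.1 or newer):

```python
# Sketch: the per-channel "diff" correction on a toy 4D conv weight.
import torch

w = torch.randn(8, 4, 3, 3)                      # [out_ch, in_ch, kH, kW]
w_fp8 = w.to(torch.float8_e4m3fn)                # quantize to FP8
error = w - w_fp8.float()                        # per-element quantization error

# one mean per output channel, keepdim=True so it broadcasts back over w
corr = error.mean(dim=(1, 2, 3), keepdim=True)   # shape [8, 1, 1, 1]
recovered = w_fp8.float() + corr                 # corrected reconstruction

before = (w - w_fp8.float()).norm().item()
after = (w - recovered).norm().item()
rel = w.norm().item()
print(f"relative error: {before / rel:.4f} -> {after / rel:.4f}")
```

The stored correction is tiny on disk but removes only the mean (bias) component of the error per channel, not its per-element detail. The additions side of the diff follows below.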
 
    MODELScope_AVAILABLE = False

def low_rank_decomposition(weight, rank=64):
+    """Standard LoRA decomposition for 2D tensors."""
    if weight.ndim != 2:
        return None, None

    try:
+        weight_f32 = weight.float()
+        U, S, Vh = torch.linalg.svd(weight_f32, full_matrices=False)
+
+        actual_rank = min(rank, len(S))
+        if actual_rank < 4:
+            return None, None
+
+        # Standard LoRA factorization: W = W_B @ W_A
+        W_A = Vh[:actual_rank, :].contiguous()  # [rank, in_features]
+        W_B = U[:, :actual_rank] @ torch.diag(S[:actual_rank])  # [out_features, rank]
+
+        return W_A.to(torch.float16), W_B.to(torch.float16)
+    except Exception as e:
+        print(f"Decomposition error: {e}")
        return None, None
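One thing to note while reading this hunk: the factorization operates on the full weight matrix, while the README generated later reconstructs with `W_fp8 + B @ A`. A hypothetical variant (not part of this commit) that factorizes the quantization residual `W - W_fp8`, which is the quantity an additive reconstruction expects:

```python
# Hypothetical variant (not in this commit): factorize the quantization
# residual rather than the full weight, so that W_fp8 + B @ A ~ W.
import torch

def low_rank_residual(weight: torch.Tensor, fp8_dtype=torch.float8_e4m3fn, rank: int = 64):
    residual = weight.float() - weight.to(fp8_dtype).float()  # what FP8 lost
    U, S, Vh = torch.linalg.svd(residual, full_matrices=False)
    r = min(rank, len(S))
    A = Vh[:r, :].contiguous()           # [r, in_features]
    B = (U[:, :r] * S[:r]).contiguous()  # [out_features, r], columns scaled by S
    return A.to(torch.float16), B.to(torch.float16)

w = torch.randn(512, 512)
A, B = low_rank_residual(w, rank=64)
recon = w.to(torch.float8_e4m3fn).float() + (B.float() @ A.float())
print("relative error:", ((w - recon).norm() / w.norm()).item())
```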
 
def extract_correction_factors(original_weight, fp8_weight):
    """Extract per-channel/tensor correction factors (difference method)."""
    with torch.no_grad():
+        # Convert to float32 for precision
        orig = original_weight.float()
        quant = fp8_weight.float()
+
+        # Compute error (what needs to be added to FP8 to recover original)
        error = orig - quant

+        # Skip if error is negligible
        error_norm = torch.norm(error)
        orig_norm = torch.norm(orig)
        if orig_norm > 1e-6 and error_norm / orig_norm < 0.01:
            return None

+        # For 4D tensors (common in VAE, CNNs)
        if orig.ndim == 4:
+            # Channel dimension is typically dimension 0 (output channels)
            channel_dim = 0
+            # Compute mean error per output channel
            channel_mean = error.mean(dim=tuple(i for i in range(1, orig.ndim)), keepdim=True)
            return channel_mean.to(original_weight.dtype)

        # For 2D tensors (linear layers)
        elif orig.ndim == 2:
+            # Compute mean error per output row
            row_mean = error.mean(dim=1, keepdim=True)
            return row_mean.to(original_weight.dtype)

+        # For 1D tensors (bias, batchnorm)
        else:
            return error.mean().to(original_weight.dtype)

+def analyze_model_architecture(state_dict):
+    """Auto-detect model architecture and components."""
+    keys = " ".join(state_dict.keys()).lower()
+    components = {
+        "text_encoder": False,
+        "unet": False,
+        "vae": False,
+        "clip": False,
+        "transformer": False
+    }

+    # Detect components based on key patterns
+    if "text" in keys or "emb" in keys or ("encoder" in keys and "vae" not in keys):
+        components["text_encoder"] = True
+    if "clip" in keys or "vision" in keys:
+        components["clip"] = True
+
+    if "unet" in keys or ("down_blocks" in keys and "up_blocks" in keys) or ("input_blocks" in keys and "output_blocks" in keys):
+        components["unet"] = True
+    if "transformer" in keys or "attn" in keys:
+        components["transformer"] = True
+
+    if "vae" in keys or ("encoder" in keys and "decoder" in keys) or "quant_conv" in keys or "post_quant" in keys:
+        components["vae"] = True
+
+    return components
+
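Since every key is joined into a single string before matching, detection here is file-level rather than per-layer. A compact re-derivation of three of the checks above, on illustrative VAE-style key names, shows how coarse that heuristic is:

```python
# Sketch of the detector's mechanism: all keys are joined into one string,
# so substring matching is file-level, not per-layer. Keys are illustrative.
keys = " ".join([
    "encoder.down_blocks.0.resnets.0.conv1.weight",
    "decoder.up_blocks.0.resnets.0.conv1.weight",
    "quant_conv.weight",
]).lower()

vae = "vae" in keys or ("encoder" in keys and "decoder" in keys) or "quant_conv" in keys
unet = "unet" in keys or ("down_blocks" in keys and "up_blocks" in keys)
text = "text" in keys or "emb" in keys or ("encoder" in keys and "vae" not in keys)

print(vae, unet, text)  # True True True: a bare VAE also trips the unet/text heuristics
```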
+def convert_safetensors_to_fp8_with_recovery(safetensors_path, output_dir, fp8_format,
+                                             recovery_configs, progress=gr.Progress()):
+    """Convert model to FP8 with customizable per-element recovery strategies."""
+    progress(0.1, desc="Starting FP8 conversion with precision recovery...")
    try:
        def read_safetensors_metadata(path):
            with open(path, 'rb') as f:

        metadata = read_safetensors_metadata(safetensors_path)
        progress(0.2, desc="Loaded metadata.")
+
+        # Load model
        state_dict = load_file(safetensors_path)
+        progress(0.3, desc="Loaded model weights.")

+        # Auto-detect architecture
+        detected_components = analyze_model_architecture(state_dict)
+        print(f"Detected components: {detected_components}")
+
+        # Setup FP8 format
+        fp8_dtype = torch.float8_e5m2 if fp8_format == "e5m2" else torch.float8_e4m3fn

+        # Initialize outputs
        sd_fp8 = {}
        recovery_weights = {}
        stats = {
            "total_layers": len(state_dict),
            "processed_layers": 0,
            "skipped_layers": [],
+            "detected_components": detected_components,
+            "recovery_counts": {"lora": 0, "diff": 0}
        }

+        # Create a mapping from layer keys to recovery config
+        layer_recovery_map = {}
+        for config in recovery_configs:
+            element_pattern = config["element"].lower()
+            for key in state_dict:
+                if element_pattern == "all" or element_pattern in key.lower():
+                    # Only set if not already set (first match wins)
+                    if key not in layer_recovery_map:
+                        layer_recovery_map[key] = config
+
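A condensed, standalone rendering of the mapping built above (toy keys and configs): configs are scanned in their listed order and each key keeps the first config whose pattern matches, so specific patterns must precede the `all` fallback:

```python
# First-match-wins layer -> config mapping, as in the loop above (toy data).
recovery_configs = [
    {"element": "attn", "method": "lora", "rank": 128},
    {"element": "all", "method": "none"},
]
state_dict = {"unet.attn1.to_q.weight": None, "unet.conv_in.weight": None}

layer_recovery_map = {}
for config in recovery_configs:
    pattern = config["element"].lower()
    for key in state_dict:
        if pattern == "all" or pattern in key.lower():
            layer_recovery_map.setdefault(key, config)  # first match wins

print(layer_recovery_map["unet.attn1.to_q.weight"]["method"])  # lora
print(layer_recovery_map["unet.conv_in.weight"]["method"])     # none
```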
+        # Process each tensor
        total = len(state_dict)
        for i, key in enumerate(state_dict):
+            progress(0.3 + 0.5 * (i / total), desc=f"Processing {i+1}/{total}: {key.split('.')[-1]}")
            weight = state_dict[key]

+            # Convert to FP8
            if weight.dtype in [torch.float16, torch.float32, torch.bfloat16]:
                fp8_weight = weight.to(fp8_dtype)
                sd_fp8[key] = fp8_weight
            else:
                sd_fp8[key] = weight
                stats["skipped_layers"].append(f"{key}: non-float dtype")
+                continue
+
+            # Get recovery config for this layer
+            recovery_config = layer_recovery_map.get(key)
+            if not recovery_config or recovery_config["method"] == "none":
+                stats["skipped_layers"].append(f"{key}: no recovery configured")
+                continue
+
+            try:
+                method = recovery_config["method"]
+                if method == "lora" and weight.ndim == 2:
+                    # LoRA recovery for 2D tensors only
+                    rank = recovery_config.get("rank", 64)
+                    # Adjust rank for smaller matrices
+                    adjusted_rank = min(rank, min(weight.shape) // 2)
+                    if adjusted_rank >= 4:
+                        A, B = low_rank_decomposition(weight, rank=adjusted_rank)
+                        if A is not None and B is not None:
+                            recovery_weights[f"lora_A.{key}"] = A
+                            recovery_weights[f"lora_B.{key}"] = B
+                            stats["processed_layers"] += 1
+                            stats["recovery_counts"]["lora"] += 1
+                            continue
+
+                if method == "diff":
+                    # Difference/correction recovery for any tensor type
+                    corr = extract_correction_factors(weight, fp8_weight)
+                    if corr is not None:
+                        recovery_weights[f"diff.{key}"] = corr
+                        stats["processed_layers"] += 1
+                        stats["recovery_counts"]["diff"] += 1
+                        continue
+
+                # If we get here, recovery was configured but couldn't be applied
+                reason = "2D tensor required" if method == "lora" and weight.ndim != 2 else "decomposition failed"
+                stats["skipped_layers"].append(f"{key}: {method} recovery failed ({reason})")
+
+            except Exception as e:
+                stats["skipped_layers"].append(f"{key}: error - {str(e)}")

+        # Save FP8 model
        base_name = os.path.splitext(os.path.basename(safetensors_path))[0]
        fp8_path = os.path.join(output_dir, f"{base_name}-fp8-{fp8_format}.safetensors")
        save_file(sd_fp8, fp8_path, metadata={"format": "pt", "fp8_format": fp8_format, **metadata})

+        # Save recovery weights if any were generated
+        recovery_path = None
        if recovery_weights:
+            recovery_path = os.path.join(output_dir, f"{base_name}-recovery.safetensors")
+            recovery_metadata = {
                "format": "pt",
                "fp8_format": fp8_format,
+                "recovery_config": json.dumps(recovery_configs),
                "stats": json.dumps(stats)
+            }
+            save_file(recovery_weights, recovery_path, metadata=recovery_metadata)

        progress(0.9, desc="Saved FP8 and recovery files.")

+        # Generate stats message
+        stats_msg = f"FP8 ({fp8_format}) conversion complete with precision recovery:\n"
        stats_msg += f"- Total layers: {stats['total_layers']}\n"
+        stats_msg += f"- Layers with recovery: {stats['processed_layers']}\n"
+        stats_msg += f"  - LoRA recovery: {stats['recovery_counts']['lora']}\n"
+        stats_msg += f"  - Difference recovery: {stats['recovery_counts']['diff']}\n"

+        if not recovery_weights:
+            stats_msg += "\n⚠️ No recovery weights were generated. All layers use pure FP8."

+        progress(1.0, desc="✅ FP8 conversion with precision recovery complete!")
+        return True, stats_msg, stats, fp8_path, recovery_path

    except Exception as e:
+        import traceback
+        error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
+        return False, error_msg, None, None, None
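The body of `read_safetensors_metadata` is elided by this hunk. For reference, a safetensors file begins with an 8-byte little-endian header length followed by a JSON header whose optional `__metadata__` field holds user metadata; a minimal reader along those lines (a sketch of what the elided helper plausibly does, not the commit's code):

```python
# Minimal safetensors metadata reader: 8-byte little-endian header size,
# then a JSON header; user metadata lives under "__metadata__".
# A sketch of what the elided helper plausibly does, not the commit's code.
import json
import struct

def read_safetensors_metadata(path):
    with open(path, 'rb') as f:
        header_size = struct.unpack('<Q', f.read(8))[0]
        header = json.loads(f.read(header_size))
    return header.get("__metadata__", {})
```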
 
def parse_hf_url(url):
    url = url.strip().rstrip("/")

    repo_url,
    safetensors_filename,
    fp8_format,
+    recovery_configs_json,
    target_type,
    new_repo_id,
    hf_token,

    progress=gr.Progress()
):
    if not re.match(r"^[a-zA-Z0-9._-]+/[a-zA-Z0-9._-]+$", new_repo_id):
+        return None, "❌ Invalid repo ID format. Use 'username/model-name'.", "", ""
    if source_type == "huggingface" and not hf_token:
+        return None, "❌ Hugging Face token required for source.", "", ""
    if target_type == "huggingface" and not hf_token:
+        return None, "❌ Hugging Face token required for target.", "", ""

+    # Parse recovery configs
+    try:
+        recovery_configs = json.loads(recovery_configs_json)
+    except json.JSONDecodeError:
+        return None, "❌ Invalid recovery configuration JSON.", "", ""
+
+    # Validate config format
+    valid_methods = ["none", "lora", "diff"]
+    for config in recovery_configs:
+        if "element" not in config or "method" not in config:
+            return None, "❌ Invalid config format: each config needs 'element' and 'method'", "", ""
+        if config["method"] not in valid_methods:
+            return None, f"❌ Invalid method: {config['method']}. Use 'none', 'lora', or 'diff'", "", ""
+        if config["method"] == "lora" and "rank" not in config:
+            return None, "❌ LoRA method requires 'rank' parameter", "", ""
    temp_dir = None
    output_dir = tempfile.mkdtemp()

            source_type, repo_url, safetensors_filename, hf_token, progress
        )

+        progress(0.2, desc="Converting to FP8 with precision recovery...")
+        success, msg, stats, fp8_path, recovery_path = convert_safetensors_to_fp8_with_recovery(
+            safetensors_path, output_dir, fp8_format, recovery_configs, progress
        )

        if not success:
+            return None, f"❌ Conversion failed: {msg}", "", ""

        progress(0.9, desc="Uploading...")
        repo_url_final = upload_to_target(
            target_type, new_repo_id, output_dir, fp8_format, hf_token, modelscope_token, private_repo
        )

+        # Generate README
        base_name = os.path.splitext(safetensors_filename)[0]
+        fp8_filename = os.path.basename(fp8_path)
+        recovery_filename = os.path.basename(recovery_path) if recovery_path else ""

        readme = f"""---
library_name: diffusers

 - mixed-method
 - converted-by-gradio
---
+# FP8 Model with Mixed Precision Recovery
 - **Source**: `{repo_url}`
+- **Original File**: `{safetensors_filename}`
 - **FP8 Format**: `{fp8_format.upper()}`
+- **FP8 File**: `{fp8_filename}`
+- **Recovery File**: `{recovery_filename if recovery_filename else "None"}`

+## Recovery Configuration
+```json
+{json.dumps(recovery_configs, indent=2)}
```

+## Usage (Inference)
```python
from safetensors.torch import load_file
import torch

+# Load FP8 model
fp8_state = load_file("{fp8_filename}")

+# Load recovery weights if available
+recovery_state = load_file("{recovery_filename}") if "{recovery_filename}" and os.path.exists("{recovery_filename}") else {{}}
+
+# Reconstruct high-precision weights
reconstructed = {{}}
for key in fp8_state:
+    fp8_weight = fp8_state[key].to(torch.float32)  # Convert to float32 for computation

+    # Apply LoRA recovery if available
    if f"lora_A.{{key}}" in recovery_state and f"lora_B.{{key}}" in recovery_state:
        A = recovery_state[f"lora_A.{{key}}"].to(torch.float32)
        B = recovery_state[f"lora_B.{{key}}"].to(torch.float32)
+        # Reconstruct the low-rank approximation
        lora_weight = B @ A
        fp8_weight = fp8_weight + lora_weight

+    # Apply difference recovery if available
    if f"diff.{{key}}" in recovery_state:
        diff = recovery_state[f"diff.{{key}}"].to(torch.float32)
        fp8_weight = fp8_weight + diff

    reconstructed[key] = fp8_weight
+
+# Use reconstructed weights in your model
+model.load_state_dict(reconstructed)
```
+
+> **Note**: For best results, use the same recovery configuration during inference as was used during extraction.
+> Requires PyTorch ≥ 2.1 for FP8 support.
+
+## Statistics
+- **Total layers**: {stats['total_layers']}
+- **Layers with recovery**: {stats['processed_layers']}
+  - LoRA recovery: {stats['recovery_counts']['lora']}
+  - Difference recovery: {stats['recovery_counts']['diff']}
"""

        with open(os.path.join(output_dir, "README.md"), "w") as f:
 
        )

        progress(1.0, desc="✅ Done!")
+
+        # Generate result HTML
+        recovery_links = []
+        if recovery_path:
+            recovery_links.append(f"- **Recovery weights**: `{recovery_filename}`")
+
        result_html = f"""
        ✅ Success!
        Model uploaded to: <a href="{repo_url_final}" target="_blank">{new_repo_id}</a>
+        Includes:
+        - FP8 model: `{fp8_filename}`
+        - {chr(10).join(recovery_links)}
        """
+
+        recovery_details = f"Recovery file: {recovery_filename}" if recovery_filename else "No recovery weights generated"
+        return (gr.HTML(result_html),
+                "✅ FP8 conversion with precision recovery successful!",
+                msg,
+                recovery_details)

    except Exception as e:
+        import traceback
+        error_details = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
+        return None, error_details, "", ""

    finally:
        if temp_dir:
            shutil.rmtree(temp_dir, ignore_errors=True)
        shutil.rmtree(output_dir, ignore_errors=True)
466
+ with gr.Blocks(title="Advanced FP8 Quantizer with Mixed Precision Recovery") as demo:
467
+ gr.Markdown("# πŸ”„ Advanced FP8 Quantizer with Per-Layer Precision Recovery")
468
+ gr.Markdown("Convert `.safetensors` β†’ **FP8** + **customizable precision recovery**. Full control over LoRA and difference methods per layer.")
469
 
470
  with gr.Row():
471
  with gr.Column():
 
476
  with gr.Accordion("FP8 Settings", open=True):
477
  fp8_format = gr.Radio(["e4m3fn", "e5m2"], value="e5m2", label="FP8 Format")
478
 
479
+ with gr.Accordion("Per-Layer Recovery Configuration", open=True):
480
  gr.Markdown("""
481
+ ### Configure recovery strategy for each layer type
482
+
483
+ Format: JSON array of configuration objects:
484
+ ```json
485
+ [
486
+ {"element": "pattern1", "method": "lora", "rank": 64},
487
+ {"element": "pattern2", "method": "diff"},
488
+ {"element": "all", "method": "none"}
489
+ ]
490
+ ```
491
 
492
+ - `element`: Substring to match in weight keys (case-insensitive). Use "all" for default.
493
+ - `method`: "none" (pure FP8), "lora" (low-rank adaptation), or "diff" (difference/correction)
494
+ - `rank`: Required for "lora" method. Higher = better quality but larger file.
495
 
496
+ **Rules are applied in order** - first match wins. Always end with an "all" rule.
497
  """)
498
+
499
+ recovery_configs_json = gr.Textbox(
500
+ value="""[
501
+ {"element": "vae", "method": "diff"},
502
+ {"element": "encoder", "method": "diff"},
503
+ {"element": "decoder", "method": "diff"},
504
+ {"element": "text", "method": "lora", "rank": 64},
505
+ {"element": "emb", "method": "lora", "rank": 64},
506
+ {"element": "attn", "method": "lora", "rank": 128},
507
+ {"element": "all", "method": "none"}
508
+ ]""",
509
+ lines=10,
510
+ label="Recovery Configuration (JSON)",
511
+ interactive=True
512
  )
513
 
514
  with gr.Accordion("Authentication", open=False):
 
    status_output = gr.Markdown()
    detailed_log = gr.Textbox(label="Processing Log", interactive=False, lines=10)
+    recovery_summary = gr.Textbox(label="Recovery Files Generated", interactive=False, lines=3)

    convert_btn = gr.Button("🚀 Convert & Upload", variant="primary")
    repo_link_output = gr.HTML()
 
            repo_url,
            safetensors_filename,
            fp8_format,
+            recovery_configs_json,
            target_type,
            new_repo_id,
            hf_token,
            modelscope_token,
            private_repo
        ],
+        outputs=[repo_link_output, status_output, detailed_log, recovery_summary],
        show_progress=True
    )

    gr.Examples(
        examples=[
            [
+                "huggingface",
+                "https://huggingface.co/stabilityai/sdxl-vae",
+                "diffusion_pytorch_model.safetensors",
+                "e4m3fn",
+                """[
+    {"element": "vae", "method": "diff"},
+    {"element": "encoder", "method": "diff"},
+    {"element": "decoder", "method": "diff"},
+    {"element": "all", "method": "none"}
+]""",
+                "huggingface"
+            ],
+            [
+                "huggingface",
+                "https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/text_encoder",
+                "model.safetensors",
                "e5m2",
+                """[
+    {"element": "text", "method": "lora", "rank": 64},
+    {"element": "emb", "method": "lora", "rank": 64},
+    {"element": "all", "method": "none"}
+]""",
                "huggingface"
            ],
            [
                "huggingface",
+                "https://huggingface.co/Yabo/FramePainter/tree/main",
+                "unet_diffusion_pytorch_model.safetensors",
                "e5m2",
+                """[
+    {"element": "attn", "method": "lora", "rank": 128},
+    {"element": "transformer", "method": "lora", "rank": 96},
+    {"element": "conv", "method": "diff"},
+    {"element": "resnet", "method": "diff"},
+    {"element": "all", "method": "none"}
+]""",
                "huggingface"
            ]
        ],
+        inputs=[source_type, repo_url, safetensors_filename, fp8_format, recovery_configs_json, target_type],
        label="Example Conversions"
    )

    gr.Markdown("""
+    ## 💡 Precision Recovery Strategy Guide

+    ### **LoRA Method** (best for attention/linear layers)
+    - **Use for**: `text`, `attn`, `transformer`, `emb`, `mlp` layers
+    - **Rank selection**:
+        - Text encoders: 64-128
+        - Attention blocks: 64-128
+        - Other linear layers: 32-64
+    - **Benefits**: Captures weight matrix structure, better for semantic understanding
+    - **Limitations**: Only works on 2D tensors, not suitable for convolutions

+    ### **Difference Method** (best for convolutional layers)
+    - **Use for**: `vae`, `encoder`, `decoder`, `conv`, `resnet` layers
+    - **How it works**: Stores the exact difference between FP8 and original weights
+    - **Benefits**: Works with any tensor shape, more accurate for spatial features
+    - **Limitations**: Larger file size than LoRA for equivalent quality

    ### **Rule Ordering Tips**
+    - Put specific patterns first (`vae.encoder`), general patterns last (`all`)
+    - Always end with an `{"element": "all", "method": "none"}` rule as fallback
+    - Layer names are **case-insensitive** - use lowercase patterns for matching

+    > **Pro Tip**: For diffusion models, use Difference for VAE/convolutional components and LoRA for text/attention components for optimal quality/size tradeoff.
    """)

demo.launch()
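To sanity-check a conversion locally before uploading, the two output files can be compared against the original. A hedged sketch (file names follow the app's `{base}-fp8-{format}.safetensors` and `{base}-recovery.safetensors` convention; the paths are illustrative):

```python
# Local sanity check (illustrative paths): reconstruct each layer from the
# FP8 file plus recovery weights and report the worst relative error.
import torch
from safetensors.torch import load_file

orig = load_file("model.safetensors")
fp8 = load_file("model-fp8-e5m2.safetensors")
rec = load_file("model-recovery.safetensors")

worst = 0.0
for key, w in orig.items():
    w = w.float()
    r = fp8[key].float()
    if f"lora_A.{key}" in rec:  # low-rank term, stored as B @ A
        r = r + rec[f"lora_B.{key}"].float() @ rec[f"lora_A.{key}"].float()
    if f"diff.{key}" in rec:    # broadcastable per-channel correction
        r = r + rec[f"diff.{key}"].float()
    if w.norm() > 0:
        worst = max(worst, ((w - r).norm() / w.norm()).item())
print(f"worst relative error: {worst:.4f}")
```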