codemichaeld committed on
Commit 311ef01 · verified · 1 Parent(s): 672b8b5

Update app.py

Files changed (1):
  1. app.py +191 -390

app.py CHANGED
@@ -4,14 +4,11 @@ import tempfile
 import shutil
 import re
 import json
-import datetime
 from pathlib import Path
 from huggingface_hub import HfApi, hf_hub_download
 from safetensors.torch import load_file, save_file
 import torch
 import torch.nn.functional as F
-import traceback
-import math
 try:
     from modelscope.hub.file_download import model_file_download as ms_file_download
     from modelscope.hub.api import HubApi as ModelScopeApi
@@ -19,92 +16,21 @@ try:
 except ImportError:
     MODELScope_AVAILABLE = False

-def get_fp8_dtype(fp8_format):
-    """Get torch FP8 dtype."""
-    if fp8_format == "e5m2":
-        return torch.float8_e5m2
-    else:
-        return torch.float8_e4m3fn
-
-def quantize_and_get_error(weight, fp8_dtype):
-    """Quantize weight to FP8 and return both quantized weight and error."""
-    weight_fp8 = weight.to(fp8_dtype)
-    weight_dequantized = weight_fp8.to(weight.dtype)
-    error = weight - weight_dequantized
-    return weight_fp8, error
-
-def low_rank_decomposition_error(error_tensor, rank=32, min_error_threshold=1e-6):
-    """Decompose error tensor with proper rank reduction."""
-    if error_tensor.ndim not in [2, 4]:
         return None, None

     try:
-        # Calculate error magnitude
-        error_norm = torch.norm(error_tensor.float())
-        if error_norm < min_error_threshold:
-            return None, None
-
-        # For 2D tensors (linear layers)
-        if error_tensor.ndim == 2:
-            U, S, Vh = torch.linalg.svd(error_tensor.float(), full_matrices=False)
-
-            # Calculate rank based on variance explained (keep 95% of error)
-            total_variance = torch.sum(S ** 2)
-            cumulative = torch.cumsum(S ** 2, dim=0)
-            keep_components = torch.sum(cumulative <= 0.95 * total_variance).item() + 1
-
-            # Limit rank to much smaller than original
-            max_rank = min(error_tensor.shape)
-            actual_rank = min(rank, keep_components, max_rank // 2)
-
-            if actual_rank < 2:
-                return None, None
-
-            A = Vh[:actual_rank, :].contiguous()
-            B = U[:, :actual_rank] @ torch.diag(S[:actual_rank]).contiguous()
-
-            return A, B
-
-        # For 4D convolutions
-        elif error_tensor.ndim == 4:
-            out_ch, in_ch, kH, kW = error_tensor.shape
-
-            # Reshape to 2D for decomposition
-            error_2d = error_tensor.view(out_ch, in_ch * kH * kW)
-            U, S, Vh = torch.linalg.svd(error_2d.float(), full_matrices=False)
-
-            # Calculate rank based on variance explained (90% for conv)
-            total_variance = torch.sum(S ** 2)
-            cumulative = torch.cumsum(S ** 2, dim=0)
-            keep_components = torch.sum(cumulative <= 0.90 * total_variance).item() + 1
-
-            # Use even lower rank for conv
-            max_rank = min(error_2d.shape)
-            actual_rank = min(rank // 2, keep_components, max_rank // 4)
-
-            if actual_rank < 2:
-                return None, None
-
-            A = Vh[:actual_rank, :].contiguous()
-            B = U[:, :actual_rank] @ torch.diag(S[:actual_rank]).contiguous()
-
-            # Reshape back for convolutional format
-            if kH == 1 and kW == 1:
-                B = B.view(out_ch, actual_rank, 1, 1)
-                A = A.view(actual_rank, in_ch, 1, 1)
-            else:
-                B = B.view(out_ch, actual_rank, 1, 1)
-                A = A.view(actual_rank, in_ch, kH, kW)
-
-            return A, B
-
-    except Exception as e:
-        print(f"Error decomposition failed: {e}")
-
-    return None, None
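The removed path, reduced to its essentials: quantize to FP8, take the exact residual, keep only its leading singular directions. A minimal standalone sketch of that round trip (the 512x512 shape and rank 32 are illustrative, not values fixed by the code above; requires a torch build with FP8 dtypes, torch >= 2.1):

```python
import torch

torch.manual_seed(0)
w = torch.randn(512, 512)

w_fp8 = w.to(torch.float8_e4m3fn)        # quantize to FP8
error = w - w_fp8.to(w.dtype)            # exact quantization error

# Rank-limited SVD of the error, as in low_rank_decomposition_error
U, S, Vh = torch.linalg.svd(error.float(), full_matrices=False)
rank = 32
A = Vh[:rank, :]                         # (rank, in_features)
B = U[:, :rank] @ torch.diag(S[:rank])   # (out_features, rank)

# FP8 weight plus low-rank error correction
recovered = w_fp8.to(torch.float32) + B @ A
print(torch.norm(error).item())          # raw FP8 error
print(torch.norm(w - recovered).item())  # residual after recovery is smaller
```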
 
 def extract_correction_factors(original_weight, fp8_weight):
-    """Extract simple correction factors for VAE."""
     with torch.no_grad():
         orig = original_weight.float()
         quant = fp8_weight.float()
@@ -112,99 +38,27 @@ def extract_correction_factors(original_weight, fp8_weight):
         error = orig - quant

         error_norm = torch.norm(error)
         orig_norm = torch.norm(orig)
-        if orig_norm > 1e-6 and error_norm / orig_norm < 0.001:
             return None
-
-        # For 4D tensors (VAE), compute per-channel correction
         if orig.ndim == 4:
             channel_mean = error.mean(dim=tuple(i for i in range(1, orig.ndim)), keepdim=True)
             return channel_mean.to(original_weight.dtype)

         elif orig.ndim == 2:
             row_mean = error.mean(dim=1, keepdim=True)
             return row_mean.to(original_weight.dtype)

         else:
             return error.mean().to(original_weight.dtype)
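extract_correction_factors reduces the same residual to one mean offset per output channel (4D), per row (2D), or per tensor (otherwise): far smaller than a LoRA pair, but it can only cancel the mean component of the error. A standalone illustration for the 4D case (the shape is arbitrary; assumes FP8 dtypes are available):

```python
import torch

torch.manual_seed(0)
w = torch.randn(64, 32, 3, 3)                         # (out_ch, in_ch, kH, kW)
w_fp8 = w.to(torch.float8_e4m3fn)
error = w - w_fp8.to(torch.float32)

# One scalar per output channel, broadcast over (in_ch, kH, kW)
correction = error.mean(dim=(1, 2, 3), keepdim=True)  # shape (64, 1, 1, 1)
recovered = w_fp8.to(torch.float32) + correction

# The residual shrinks only by the per-channel bias component of the error
print(torch.norm(error).item(), torch.norm(w - recovered).item())
```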
 
-def get_architecture_settings(architecture, base_rank):
-    """Get optimal settings for different architectures."""
-    settings = {
-        "text_encoder": {
-            "rank": base_rank,
-            "error_threshold": 5e-5,
-            "min_rank": 8,
-            "max_rank_factor": 0.4,
-            "method": "lora"
-        },
-        "transformer": {
-            "rank": base_rank,
-            "error_threshold": 1e-5,
-            "min_rank": 12,
-            "max_rank_factor": 0.35,
-            "method": "lora"
-        },
-        "vae": {
-            "rank": base_rank // 2,
-            "error_threshold": 1e-4,
-            "min_rank": 4,
-            "max_rank_factor": 0.3,
-            "method": "correction"
-        },
-        "unet_conv": {
-            "rank": base_rank // 3,
-            "error_threshold": 2e-5,
-            "min_rank": 8,
-            "max_rank_factor": 0.25,
-            "method": "lora"
-        },
-        "auto": {
-            "rank": base_rank,
-            "error_threshold": 1e-5,
-            "min_rank": 8,
-            "max_rank_factor": 0.3,
-            "method": "lora"
-        },
-        "all": {
-            "rank": base_rank,
-            "error_threshold": 1e-5,
-            "min_rank": 8,
-            "max_rank_factor": 0.3,
-            "method": "lora"
-        }
-    }
-
-    return settings.get(architecture, settings["auto"])
-
-def should_process_layer(key, weight, architecture):
-    """Determine if layer should be processed for LoRA/correction."""
-    lower_key = key.lower()
-
-    # Skip biases and normalization layers
-    if 'bias' in key or 'norm' in key.lower() or 'bn' in key.lower():
-        return False
-
-    if weight.numel() < 100:
-        return False

-    # Architecture-specific filtering
-    if architecture == "text_encoder":
-        return ('text' in lower_key or 'emb' in lower_key or
-                'encoder' in lower_key or 'attn' in lower_key)
-    elif architecture == "transformer":
-        return ('attn' in lower_key or 'transformer' in lower_key or
-                'mlp' in lower_key or 'to_out' in lower_key)
-    elif architecture == "vae":
-        return ('vae' in lower_key or 'encoder' in lower_key or
-                'decoder' in lower_key or 'conv' in lower_key)
-    elif architecture == "unet_conv":
-        return ('conv' in lower_key or 'resnet' in lower_key or
-                'downsample' in lower_key or 'upsample' in lower_key)
-    elif architecture in ["all", "auto"]:
-        return True
-
-    return False
-
-def convert_safetensors_to_fp8_with_lora(safetensors_path, output_dir, fp8_format, lora_rank=128, architecture="auto", progress=gr.Progress()):
-    progress(0.1, desc="Starting FP8 conversion with error recovery...")
     try:
         def read_safetensors_metadata(path):
             with open(path, 'rb') as f:
@@ -215,157 +69,122 @@ def convert_safetensors_to_fp8_with_lora(safetensors_path, output_dir, fp8_forma

         metadata = read_safetensors_metadata(safetensors_path)
         progress(0.2, desc="Loaded metadata.")
-
         state_dict = load_file(safetensors_path)
         progress(0.4, desc="Loaded weights.")

-        # Auto-detect architecture if needed
-        if architecture == "auto":
-            model_keys = " ".join(state_dict.keys()).lower()
-            if "vae" in model_keys or ("encoder" in model_keys and "decoder" in model_keys):
-                architecture = "vae"
-            elif "text" in model_keys or "emb" in model_keys:
-                architecture = "text_encoder"
-            elif "attn" in model_keys or "transformer" in model_keys:
-                architecture = "transformer"
-            elif "conv" in model_keys or "resnet" in model_keys:
-                architecture = "unet_conv"
-            else:
-                architecture = "all"
-
-        settings = get_architecture_settings(architecture, lora_rank)
-        fp8_dtype = get_fp8_dtype(fp8_format)

         sd_fp8 = {}
-        lora_weights = {}
-        correction_factors = {}
         stats = {
             "total_layers": len(state_dict),
-            "eligible_layers": 0,
-            "layers_with_error": 0,
             "processed_layers": 0,
-            "correction_layers": 0,
             "skipped_layers": [],
-            "architecture": architecture,
-            "method": settings["method"],
-            "error_magnitudes": []
         }

         total = len(state_dict)
-
         for i, key in enumerate(state_dict):
             progress(0.4 + 0.4 * (i / total), desc=f"Processing {i+1}/{total}...")
             weight = state_dict[key]

             if weight.dtype in [torch.float16, torch.float32, torch.bfloat16]:
-                # Quantize to FP8 and calculate error
-                weight_fp8, error = quantize_and_get_error(weight, fp8_dtype)
-                sd_fp8[key] = weight_fp8
-
-                # Calculate error magnitude
-                error_norm = torch.norm(error.float())
-                weight_norm = torch.norm(weight.float())
-                relative_error = (error_norm / weight_norm).item() if weight_norm > 0 else 0
-
-                stats["error_magnitudes"].append({
-                    "key": key,
-                    "relative_error": relative_error
-                })
-
-                # Check if layer should be processed
-                should_process = should_process_layer(key, weight, architecture)
-
-                if should_process:
-                    stats["eligible_layers"] += 1
-
-                    # Only process if error is significant
-                    if relative_error > settings["error_threshold"]:
-                        stats["layers_with_error"] += 1
-
-                        if settings["method"] == "correction":
-                            # Use correction factors for VAE
-                            correction = extract_correction_factors(weight, weight_fp8)
-                            if correction is not None:
-                                correction_factors[f"correction.{key}"] = correction
-                                stats["correction_layers"] += 1
-                                stats["processed_layers"] += 1
-                        else:
-                            # Use LoRA decomposition for other architectures
-                            try:
-                                A, B = low_rank_decomposition_error(
-                                    error,
-                                    rank=settings["rank"],
-                                    min_error_threshold=settings["error_threshold"]
-                                )
-
-                                if A is not None and B is not None:
-                                    lora_weights[f"lora_A.{key}"] = A.to(torch.float16)
-                                    lora_weights[f"lora_B.{key}"] = B.to(torch.float16)
-                                    stats["processed_layers"] += 1
-                                else:
-                                    stats["skipped_layers"].append(f"{key}: decomposition failed")
-                            except Exception as e:
-                                stats["skipped_layers"].append(f"{key}: error - {str(e)}")
-                    else:
-                        stats["skipped_layers"].append(f"{key}: error too small ({relative_error:.6f})")
             else:
                 sd_fp8[key] = weight
                 stats["skipped_layers"].append(f"{key}: non-float dtype")

-        # Calculate average error
-        if stats["error_magnitudes"]:
-            errors = [e["relative_error"] for e in stats["error_magnitudes"]]
-            stats["avg_error"] = sum(errors) / len(errors) if errors else 0
-            stats["max_error"] = max(errors) if errors else 0
-
         base_name = os.path.splitext(os.path.basename(safetensors_path))[0]
         fp8_path = os.path.join(output_dir, f"{base_name}-fp8-{fp8_format}.safetensors")

         save_file(sd_fp8, fp8_path, metadata={"format": "pt", "fp8_format": fp8_format, **metadata})

-        # Save precision recovery weights
-        if lora_weights:
-            lora_path = os.path.join(output_dir, f"{base_name}-lora-r{lora_rank}-{architecture}.safetensors")
-            lora_metadata = {
-                "format": "pt",
-                "lora_rank": str(lora_rank),
-                "architecture": architecture,
-                "stats": json.dumps(stats),
-                "method": "lora"
-            }
-            save_file(lora_weights, lora_path, metadata=lora_metadata)
-
-        if correction_factors:
-            correction_path = os.path.join(output_dir, f"{base_name}-correction-{architecture}.safetensors")
-            correction_metadata = {
                 "format": "pt",
-                "architecture": architecture,
-                "stats": json.dumps(stats),
-                "method": "correction"
-            }
-            save_file(correction_factors, correction_path, metadata=correction_metadata)
-
-        progress(0.9, desc="Saved FP8 and precision recovery files.")
-        progress(1.0, desc="✅ FP8 + precision recovery extraction complete!")

-        stats_msg = f"FP8 ({fp8_format}) with precision recovery saved.\n"
-        stats_msg += f"Architecture: {architecture}\n"
-        stats_msg += f"Method: {settings['method']}\n"
-        stats_msg += f"Average quantization error: {stats.get('avg_error', 0):.6f}\n"

-        if settings["method"] == "correction":
-            stats_msg += f"Correction factors generated for {stats['correction_layers']} layers."
-        else:
-            stats_msg += f"LoRA generated for {stats['processed_layers']}/{stats['eligible_layers']} eligible layers (rank {lora_rank})."

-        if stats['processed_layers'] == 0 and stats['correction_layers'] == 0:
-            stats_msg += "\n⚠️ No precision recovery weights were generated. FP8 quantization error may be too small."

         return True, stats_msg, stats

     except Exception as e:
-        error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
-        return False, error_msg, None

 def parse_hf_url(url):
     url = url.strip().rstrip("/")
@@ -428,8 +247,7 @@ def process_and_upload_fp8(
     repo_url,
     safetensors_filename,
     fp8_format,
-    lora_rank,
-    architecture,
     target_type,
     new_repo_id,
     hf_token,
@@ -443,8 +261,10 @@
         return None, "❌ Hugging Face token required for source.", ""
     if target_type == "huggingface" and not hf_token:
         return None, "❌ Hugging Face token required for target.", ""
-    if lora_rank < 8:
-        return None, "❌ LoRA rank must be at least 8.", ""

     temp_dir = None
     output_dir = tempfile.mkdtemp()
@@ -454,9 +274,9 @@
             source_type, repo_url, safetensors_filename, hf_token, progress
         )

-        progress(0.25, desc="Converting to FP8 with precision recovery...")
-        success, msg, stats = convert_safetensors_to_fp8_with_lora(
-            safetensors_path, output_dir, fp8_format, lora_rank, architecture, progress
         )

         if not success:
@@ -469,16 +289,7 @@

         base_name = os.path.splitext(safetensors_filename)[0]
         fp8_filename = f"{base_name}-fp8-{fp8_format}.safetensors"
-
-        # Determine which precision recovery file was generated
-        precision_recovery_file = ""
-        precision_recovery_type = ""
-        if stats.get("method") == "correction" and stats.get("correction_layers", 0) > 0:
-            precision_recovery_file = f"{base_name}-correction-{architecture}.safetensors"
-            precision_recovery_type = "Correction Factors"
-        elif stats.get("method") == "lora" and stats.get("processed_layers", 0) > 0:
-            precision_recovery_file = f"{base_name}-lora-r{lora_rank}-{architecture}.safetensors"
-            precision_recovery_type = "LoRA"

         readme = f"""---
 library_name: diffusers
@@ -486,60 +297,49 @@ tags:
 - fp8
 - safetensors
 - precision-recovery
-- diffusion
 - converted-by-gradio
 ---
-# FP8 Model with Precision Recovery

 - **Source**: `{repo_url}`
 - **File**: `{safetensors_filename}`
 - **FP8 Format**: `{fp8_format.upper()}`
-- **Architecture**: {architecture}
-- **Precision Recovery Type**: {precision_recovery_type}
-- **Precision Recovery File**: `{precision_recovery_file}` if available
-- **FP8 File**: `{fp8_filename}`

-## Usage (Inference)
 ```python
 from safetensors.torch import load_file
 import torch

-# Load FP8 model
 fp8_state = load_file("{fp8_filename}")

-# Load precision recovery file if available
-recovery_state = {{}}
-if "{precision_recovery_file}":
-    recovery_state = load_file("{precision_recovery_file}")
-
-# Reconstruct high-precision weights
 reconstructed = {{}}
 for key in fp8_state:
-    # Dequantize FP8 to target precision
-    fp_weight = fp8_state[key].to(torch.float32)

-    if recovery_state:
-        # For LoRA approach
-        if f"lora_A.{{key}}" in recovery_state and f"lora_B.{{key}}" in recovery_state:
-            A = recovery_state[f"lora_A.{{key}}"].to(torch.float32)
-            B = recovery_state[f"lora_B.{{key}}"].to(torch.float32)
-            error_correction = B @ A
-            reconstructed[key] = fp_weight + error_correction
-        # For correction factor approach
-        elif f"correction.{{key}}" in recovery_state:
-            correction = recovery_state[f"correction.{{key}}"].to(torch.float32)
-            reconstructed[key] = fp_weight + correction
-        else:
-            reconstructed[key] = fp_weight
-    else:
-        reconstructed[key] = fp_weight
-
-print("Model reconstructed with FP8 error recovery")
 ```
-
-> **Note**: This precision recovery targets FP8 quantization errors.
-> Average quantization error: {stats.get('avg_error', 0):.6f}
 """
-
         with open(os.path.join(output_dir, "README.md"), "w") as f:
             f.write(readme)
 
@@ -553,31 +353,24 @@ print("Model reconstructed with FP8 error recovery")
         )

         progress(1.0, desc="✅ Done!")
-
         result_html = f"""
         ✅ Success!
         Model uploaded to: <a href="{repo_url_final}" target="_blank">{new_repo_id}</a>
-        Includes: FP8 model + precision recovery ({precision_recovery_type}).
-        Average quantization error: {stats.get('avg_error', 0):.6f}
         """
-
-        if stats['processed_layers'] > 0 or stats['correction_layers'] > 0:
-            result_html += f"<br>Precision recovery applied to {stats['processed_layers'] + stats['correction_layers']} layers."
-
-        return gr.HTML(result_html), "✅ FP8 + precision recovery upload successful!", msg
-
     except Exception as e:
-        error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
-        return None, error_msg, ""

     finally:
         if temp_dir:
             shutil.rmtree(temp_dir, ignore_errors=True)
         shutil.rmtree(output_dir, ignore_errors=True)

-with gr.Blocks(title="FP8 + Precision Recovery Extractor") as demo:
-    gr.Markdown("# 🔄 FP8 Converter with Architecture-Specific Precision Recovery")
-    gr.Markdown("Convert models to **FP8** with **error-based precision recovery**.")

     with gr.Row():
         with gr.Column():
@@ -585,31 +378,33 @@ with gr.Blocks(title="FP8 + Precision Recovery Extractor") as demo:
             repo_url = gr.Textbox(label="Repo URL or ID", placeholder="https://huggingface.co/... or modelscope-id")
             safetensors_filename = gr.Textbox(label="Filename", placeholder="model.safetensors")

-            with gr.Accordion("Advanced Settings", open=True):
                 fp8_format = gr.Radio(["e4m3fn", "e5m2"], value="e5m2", label="FP8 Format")
-                lora_rank = gr.Slider(minimum=8, maximum=256, step=8, value=128,
-                                      label="LoRA Rank (for text/transformers)")
-                architecture = gr.Dropdown(
-                    choices=[
-                        ("Auto-detect architecture", "auto"),
-                        ("Text Encoder (LoRA)", "text_encoder"),
-                        ("Transformer blocks (LoRA)", "transformer"),
-                        ("VAE (Correction Factors)", "vae"),
-                        ("UNet Convolutions (LoRA)", "unet_conv"),
-                        ("All layers (LoRA where applicable)", "all")
-                    ],
-                    value="auto",
-                    label="Target Architecture"
                 )

             with gr.Accordion("Authentication", open=False):
                 hf_token = gr.Textbox(label="Hugging Face Token", type="password")
-                modelscope_token = gr.Textbox(label="ModelScope Token (optional)", type="password",
-                                              visible=MODELScope_AVAILABLE)

         with gr.Column():
             target_type = gr.Radio(["huggingface", "modelscope"], value="huggingface", label="Target")
-            new_repo_id = gr.Textbox(label="New Repo ID", placeholder="user/model-fp8-precision")
             private_repo = gr.Checkbox(label="Private Repository (HF only)", value=False)

         status_output = gr.Markdown()
@@ -625,8 +420,7 @@ with gr.Blocks(title="FP8 + Precision Recovery Extractor") as demo:
             repo_url,
             safetensors_filename,
             fp8_format,
-            lora_rank,
-            architecture,
             target_type,
             new_repo_id,
             hf_token,
@@ -639,39 +433,46 @@ with gr.Blocks(title="FP8 + Precision Recovery Extractor") as demo:

     gr.Examples(
         examples=[
-            ["huggingface", "https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/text_encoder",
-             "model.safetensors", "e5m2", 96, "text_encoder"],
-            ["huggingface", "https://huggingface.co/stabilityai/sdxl-vae",
-             "diffusion_pytorch_model.safetensors", "e4m3fn", 64, "vae"],
-            ["huggingface", "https://huggingface.co/Yabo/FramePainter/tree/main",
-             "unet_diffusion_pytorch_model.safetensors", "e5m2", 128, "transformer"]
         ],
-        inputs=[source_type, repo_url, safetensors_filename, fp8_format, lora_rank, architecture],
         label="Example Conversions"
     )

     gr.Markdown("""
-    ## 🎯 What This Tool Does
-
-    Unlike traditional LoRA fine-tuning, this tool:
-
-    1. **Quantizes** the model to FP8 (loses precision)
-    2. **Measures** the quantization error for each weight
-    3. **Extracts recovery weights** that specifically recover this error
-    4. **Only applies** recovery where error is significant (>0.001%)

-    ## 💡 Recommended Settings

-    - **Text Encoders**: rank 64-96 (text is sensitive)
-    - **Transformers**: rank 96-128
-    - **VAE**: Uses correction factors (no rank needed)
-    - **UNet Convolutions**: rank 32-64

-    ## ⚠️ Important Notes

-    - This recovers **FP8 quantization errors**, not fine-tuning changes
-    - If FP8 error is tiny (<0.0001%), recovery may not be generated
-    - Higher rank ≠ better for error recovery (use recommended ranges)
     """)

 demo.launch()
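One number frames both the removed and the new version: whether a layer receives recovery weights at all hinges on its relative quantization error, ||W - Q(W)|| / ||W||. A standalone sketch of that gate (the tensor is synthetic; the thresholds are taken from the two versions of the code):

```python
import torch

torch.manual_seed(0)
w = torch.randn(128, 128)
err = w - w.to(torch.float8_e5m2).to(torch.float32)
rel = (torch.norm(err) / torch.norm(w)).item()

print(rel)          # relative FP8 (e5m2) error on random float32 weights
print(rel > 1e-5)   # old code: above the "transformer" error_threshold, so recovery runs
print(rel < 0.01)   # new code: extract_correction_factors returns None below this
```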
 
app.py (after changes):

 import shutil
 import re
 import json
 from pathlib import Path
 from huggingface_hub import HfApi, hf_hub_download
 from safetensors.torch import load_file, save_file
 import torch
 import torch.nn.functional as F
 try:
     from modelscope.hub.file_download import model_file_download as ms_file_download
     from modelscope.hub.api import HubApi as ModelScopeApi
     MODELScope_AVAILABLE = True
 except ImportError:
     MODELScope_AVAILABLE = False

+def low_rank_decomposition(weight, rank=64):
+    """Standard LoRA decomposition for 2D tensors only."""
+    if weight.ndim != 2:
         return None, None

     try:
+        U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
+        U = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank]))
+        Vh = torch.diag(torch.sqrt(S[:rank])) @ Vh[:rank, :]
+        return U.contiguous(), Vh.contiguous()
+    except Exception:
+        return None, None
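The sqrt(S) split spreads each singular value evenly across the two factors (helpful when both are stored in fp16), and the rank-r product is unchanged. A standalone check of that identity (shapes are arbitrary):

```python
import torch

torch.manual_seed(0)
t = torch.randn(256, 128)
U, S, Vh = torch.linalg.svd(t, full_matrices=False)

r = 16
A = U[:, :r] @ torch.diag(torch.sqrt(S[:r]))   # (256, r)
B = torch.diag(torch.sqrt(S[:r])) @ Vh[:r, :]  # (r, 128)

best_rank_r = U[:, :r] @ torch.diag(S[:r]) @ Vh[:r, :]
print(torch.allclose(A @ B, best_rank_r, atol=1e-5))  # True: same rank-r approximation
```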

 def extract_correction_factors(original_weight, fp8_weight):
+    """Extract per-channel/tensor correction factors (difference method)."""
     with torch.no_grad():
         orig = original_weight.float()
         quant = fp8_weight.float()
         error = orig - quant

         error_norm = torch.norm(error)
         orig_norm = torch.norm(orig)
+        if orig_norm > 1e-6 and error_norm / orig_norm < 0.01:
             return None
+
+        # For 4D tensors (VAE/conv layers)
         if orig.ndim == 4:
             channel_mean = error.mean(dim=tuple(i for i in range(1, orig.ndim)), keepdim=True)
             return channel_mean.to(original_weight.dtype)
+
+        # For 2D tensors (linear layers)
         elif orig.ndim == 2:
             row_mean = error.mean(dim=1, keepdim=True)
             return row_mean.to(original_weight.dtype)
+
+        # For 1D tensors (bias, etc.)
         else:
             return error.mean().to(original_weight.dtype)

+def convert_safetensors_to_fp8_with_recovery(safetensors_path, output_dir, fp8_format, recovery_config, progress=gr.Progress()):
+    progress(0.1, desc="Starting FP8 conversion with precision recovery...")
     try:
         def read_safetensors_metadata(path):
             with open(path, 'rb') as f:

         metadata = read_safetensors_metadata(safetensors_path)
         progress(0.2, desc="Loaded metadata.")
         state_dict = load_file(safetensors_path)
         progress(0.4, desc="Loaded weights.")

+        if fp8_format == "e5m2":
+            fp8_dtype = torch.float8_e5m2
+        else:
+            fp8_dtype = torch.float8_e4m3fn

         sd_fp8 = {}
+        recovery_weights = {}
         stats = {
             "total_layers": len(state_dict),
             "processed_layers": 0,
             "skipped_layers": [],
+            "recovery_type_counts": {"lora": 0, "diff": 0}
         }

         total = len(state_dict)
         for i, key in enumerate(state_dict):
             progress(0.4 + 0.4 * (i / total), desc=f"Processing {i+1}/{total}...")
             weight = state_dict[key]
+            lower_key = key.lower()

             if weight.dtype in [torch.float16, torch.float32, torch.bfloat16]:
+                fp8_weight = weight.to(fp8_dtype)
+                sd_fp8[key] = fp8_weight

+                # Match key against recovery config rules (first match wins)
+                recovery_method = "none"
+                lora_rank = 64

+                for rule in recovery_config:
+                    element_pattern = rule.get("element", "").lower()
+                    method = rule.get("method", "none")
+
+                    if element_pattern == "all" or element_pattern in lower_key:
+                        recovery_method = method
+                        if method == "lora":
+                            lora_rank = rule.get("rank", 64)
+                        break

+                if recovery_method == "lora" and weight.ndim == 2 and min(weight.shape) > lora_rank:
+                    try:
+                        # Decompose the quantization error, not the raw weight, so that
+                        # fp8_weight + B @ A reconstructs the original weight (see README)
+                        error = weight.float() - fp8_weight.to(torch.float32)
+                        U, V = low_rank_decomposition(error, rank=lora_rank)
+                        if U is not None and V is not None:
+                            # lora_A is the (rank, in) factor, lora_B the (out, rank)
+                            # factor, matching the B @ A reconstruction order
+                            recovery_weights[f"lora_A.{key}"] = V.to(torch.float16)
+                            recovery_weights[f"lora_B.{key}"] = U.to(torch.float16)
+                            stats["processed_layers"] += 1
+                            stats["recovery_type_counts"]["lora"] += 1
+                    except Exception:
+                        stats["skipped_layers"].append(f"{key}: lora failed")

+                elif recovery_method == "diff":
+                    try:
+                        corr = extract_correction_factors(weight, fp8_weight)
+                        if corr is not None:
+                            recovery_weights[f"diff.{key}"] = corr
+                            stats["processed_layers"] += 1
+                            stats["recovery_type_counts"]["diff"] += 1
+                    except Exception:
+                        stats["skipped_layers"].append(f"{key}: diff failed")
+
+                else:
+                    stats["skipped_layers"].append(f"{key}: {recovery_method}")

             else:
                 sd_fp8[key] = weight
                 stats["skipped_layers"].append(f"{key}: non-float dtype")

         base_name = os.path.splitext(os.path.basename(safetensors_path))[0]
         fp8_path = os.path.join(output_dir, f"{base_name}-fp8-{fp8_format}.safetensors")
+        recovery_path = os.path.join(output_dir, f"{base_name}-recovery.safetensors")

         save_file(sd_fp8, fp8_path, metadata={"format": "pt", "fp8_format": fp8_format, **metadata})

+        if recovery_weights:
+            save_file(recovery_weights, recovery_path, metadata={
                 "format": "pt",
+                "fp8_format": fp8_format,
+                "recovery_config": json.dumps(recovery_config),
+                "stats": json.dumps(stats)
+            })

+        progress(0.9, desc="Saved FP8 and recovery files.")
+        progress(1.0, desc="✅ FP8 + recovery extraction complete!")

+        stats_msg = f"FP8 ({fp8_format}) and recovery saved.\n"
+        stats_msg += f"- Total layers: {stats['total_layers']}\n"
+        stats_msg += f"- Processed: {stats['processed_layers']} ({stats['recovery_type_counts']['lora']} LoRA + {stats['recovery_type_counts']['diff']} Diff)\n"

+        if stats["processed_layers"] == 0:
+            stats_msg += "\n⚠️ No recovery weights generated. Check your rules and rank settings."

         return True, stats_msg, stats

     except Exception as e:
+        return False, str(e), None
+
+def generate_config_from_rules(rules_input):
+    """Parse multi-line rule input into config."""
+    config = []
+    for line in rules_input.strip().split('\n'):
+        line = line.strip()
+        if not line or line.startswith('#'):
+            continue
+        parts = [p.strip() for p in line.split(',')]
+        if len(parts) >= 2:
+            element = parts[0]
+            method = parts[1].lower()
+            rank = 64
+            if method == "lora" and len(parts) >= 3:
+                try:
+                    rank = int(parts[2])
+                except ValueError:
+                    pass
+            config.append({"element": element, "method": method, "rank": rank})
+    return config
 
189
  def parse_hf_url(url):
190
  url = url.strip().rstrip("/")
 
247
  repo_url,
248
  safetensors_filename,
249
  fp8_format,
250
+ recovery_rules,
 
251
  target_type,
252
  new_repo_id,
253
  hf_token,
 
261
  return None, "❌ Hugging Face token required for source.", ""
262
  if target_type == "huggingface" and not hf_token:
263
  return None, "❌ Hugging Face token required for target.", ""
264
+
265
+ recovery_config = generate_config_from_rules(recovery_rules)
266
+ if not recovery_config:
267
+ recovery_config = [{"element": "all", "method": "none"}]
268
 
269
  temp_dir = None
270
  output_dir = tempfile.mkdtemp()
 
274
  source_type, repo_url, safetensors_filename, hf_token, progress
275
  )
276
 
277
+ progress(0.25, desc="Converting to FP8 with recovery...")
278
+ success, msg, stats = convert_safetensors_to_fp8_with_recovery(
279
+ safetensors_path, output_dir, fp8_format, recovery_config, progress
280
  )
281
 
282
  if not success:
 
289
 
290
  base_name = os.path.splitext(safetensors_filename)[0]
291
  fp8_filename = f"{base_name}-fp8-{fp8_format}.safetensors"
292
+ recovery_filename = f"{base_name}-recovery.safetensors"
 
 
 
 
 
 
 
 
 
293
 
         readme = f"""---
 library_name: diffusers
 tags:
 - fp8
 - safetensors
 - precision-recovery
+- mixed-method
 - converted-by-gradio
 ---
+# FP8 Model with Custom Precision Recovery
+
 - **Source**: `{repo_url}`
 - **File**: `{safetensors_filename}`
 - **FP8 Format**: `{fp8_format.upper()}`
+- **Recovery File**: `{recovery_filename}` (contains both LoRA and Difference weights)

+## Recovery Rules Used
+```
+{recovery_rules}
+```
+
+## Usage
 ```python
 from safetensors.torch import load_file
 import torch

 fp8_state = load_file("{fp8_filename}")
+recovery_state = load_file("{recovery_filename}")

 reconstructed = {{}}
 for key in fp8_state:
+    fp8_weight = fp8_state[key].to(torch.float32)

+    # Apply LoRA if present
+    if f"lora_A.{{key}}" in recovery_state and f"lora_B.{{key}}" in recovery_state:
+        A = recovery_state[f"lora_A.{{key}}"].to(torch.float32)
+        B = recovery_state[f"lora_B.{{key}}"].to(torch.float32)
+        lora_weight = B @ A
+        fp8_weight = fp8_weight + lora_weight
+
+    # Apply Difference if present
+    if f"diff.{{key}}" in recovery_state:
+        diff = recovery_state[f"diff.{{key}}"].to(torch.float32)
+        fp8_weight = fp8_weight + diff
+
+    reconstructed[key] = fp8_weight
 ```
 """
+
         with open(os.path.join(output_dir, "README.md"), "w") as f:
             f.write(readme)

         )

         progress(1.0, desc="✅ Done!")
         result_html = f"""
         ✅ Success!
         Model uploaded to: <a href="{repo_url_final}" target="_blank">{new_repo_id}</a>
+        Includes FP8 + custom recovery weights.
         """
+        return gr.HTML(result_html), "✅ FP8 + recovery upload successful!", msg
+
     except Exception as e:
+        return None, f"❌ Error: {str(e)}", ""

     finally:
         if temp_dir:
             shutil.rmtree(temp_dir, ignore_errors=True)
         shutil.rmtree(output_dir, ignore_errors=True)

+with gr.Blocks(title="FP8 + Custom Recovery Extractor") as demo:
+    gr.Markdown("# 🔄 FP8 Quantizer with Per-Layer Recovery Control")
+    gr.Markdown("Specify the **exact recovery method per layer/tensor** using pattern matching. Supports LoRA and Difference methods simultaneously.")

     with gr.Row():
         with gr.Column():
             repo_url = gr.Textbox(label="Repo URL or ID", placeholder="https://huggingface.co/... or modelscope-id")
             safetensors_filename = gr.Textbox(label="Filename", placeholder="model.safetensors")

+            with gr.Accordion("FP8 Settings", open=True):
                 fp8_format = gr.Radio(["e4m3fn", "e5m2"], value="e5m2", label="FP8 Format")
+
+            with gr.Accordion("Recovery Rules (Layer/Tensor Level)", open=True):
+                gr.Markdown("""
+                Define recovery rules **one per line** in the format:
+                `layer_pattern, method [, rank]`
+
+                - `layer_pattern`: substring to match in the weight key (case-insensitive)
+                - `method`: `lora`, `diff`, or `none`
+                - `rank`: LoRA rank (used only by the `lora` method)
+
+                **Rules are applied in order** – first match wins.
+                """)
+                recovery_rules = gr.Textbox(
+                    value="vae, diff\nencoder, diff\ndecoder, diff\ntext, lora, 64\nattn, lora, 128\nall, none",
+                    lines=8,
+                    label="Recovery Rules"
+                )

             with gr.Accordion("Authentication", open=False):
                 hf_token = gr.Textbox(label="Hugging Face Token", type="password")
+                modelscope_token = gr.Textbox(label="ModelScope Token (optional)", type="password", visible=MODELScope_AVAILABLE)

         with gr.Column():
             target_type = gr.Radio(["huggingface", "modelscope"], value="huggingface", label="Target")
+            new_repo_id = gr.Textbox(label="New Repo ID", placeholder="user/model-fp8")
             private_repo = gr.Checkbox(label="Private Repository (HF only)", value=False)

         status_output = gr.Markdown()
 
             repo_url,
             safetensors_filename,
             fp8_format,
+            recovery_rules,
             target_type,
             new_repo_id,
             hf_token,
 

     gr.Examples(
         examples=[
+            [
+                "huggingface",
+                "https://huggingface.co/stabilityai/sdxl-vae",
+                "diffusion_pytorch_model.safetensors",
+                "e5m2",
+                "vae, diff\nencoder, diff\ndecoder, diff\nall, none",
+                "huggingface"
+            ],
+            [
+                "huggingface",
+                "https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/text_encoder",
+                "model.safetensors",
+                "e5m2",
+                "text, lora, 64\nemb, lora, 64\nall, none",
+                "huggingface"
+            ]
         ],
+        inputs=[source_type, repo_url, safetensors_filename, fp8_format, recovery_rules, target_type],
         label="Example Conversions"
     )
 

     gr.Markdown("""
+    ## 💡 Recovery Strategy Guide

+    ### **Difference Method (Recommended for VAE/Convs)**
+    - Use for: `vae`, `encoder`, `decoder`, `conv` layers
+    - Stores the mean quantization error per output channel
+    - Works with 4D tensors that LoRA cannot handle

+    ### **LoRA Method (Recommended for Attention/Linear)**
+    - Use for: `text`, `attn`, `mlp`, `transformer` layers
+    - Use rank 32-128 depending on layer importance
+    - Only works on 2D tensors

+    ### **Rule Ordering Tips**
+    - Put specific patterns first (`vae.encoder`) before general ones (`vae`)
+    - End with `all, none` to set the default behavior
+    - Layer patterns are matched **case-insensitively**

+    > This implementation restores the successful VAE difference method while adding full per-layer control.
     """)

 demo.launch()
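To see the two methods the guide recommends working side by side, here is a minimal synthetic comparison (shapes and rank are arbitrary; assumes a torch build with FP8 dtypes):

```python
import torch

torch.manual_seed(0)
conv_w = torch.randn(32, 16, 3, 3)   # 4D: difference-method territory
lin_w = torch.randn(256, 256)        # 2D: LoRA territory

# Difference method: per-output-channel mean of the quantization error
conv_fp8 = conv_w.to(torch.float8_e4m3fn)
conv_err = conv_w - conv_fp8.to(torch.float32)
conv_rec = conv_fp8.to(torch.float32) + conv_err.mean(dim=(1, 2, 3), keepdim=True)

# LoRA method: rank-32 SVD of the quantization error
lin_fp8 = lin_w.to(torch.float8_e4m3fn)
lin_err = lin_w - lin_fp8.to(torch.float32)
U, S, Vh = torch.linalg.svd(lin_err, full_matrices=False)
r = 32
lin_rec = lin_fp8.to(torch.float32) + U[:, :r] @ torch.diag(S[:r]) @ Vh[:r, :]

for name, w, rec, err in [("conv / diff", conv_w, conv_rec, conv_err),
                          ("linear / lora", lin_w, lin_rec, lin_err)]:
    print(name, torch.norm(err).item(), torch.norm(w - rec).item())
```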