codemichaeld committed
Commit 9de10a2 · verified · 1 Parent(s): 4f1244e

Update app.py

Files changed (1)
  1. app.py +414 -183
app.py CHANGED
@@ -9,6 +9,7 @@ from huggingface_hub import HfApi, hf_hub_download
9
  from safetensors.torch import load_file, save_file
10
  import torch
11
  import torch.nn.functional as F
 
12
  try:
13
  from modelscope.hub.file_download import model_file_download as ms_file_download
14
  from modelscope.hub.api import HubApi as ModelScopeApi
@@ -17,25 +18,44 @@ except ImportError:
17
  MODELScope_AVAILABLE = False
18
 
19
  def low_rank_decomposition(weight, rank=64):
20
- """Standard LoRA decomposition for 2D tensors."""
21
- if weight.ndim != 2:
22
- return None, None
23
-
24
  try:
25
- weight_f32 = weight.float()
26
- U, S, Vh = torch.linalg.svd(weight_f32, full_matrices=False)
27
-
28
- actual_rank = min(rank, len(S))
29
- if actual_rank < 4:
30
- return None, None
31
-
32
- # Standard LoRA factorization: W = W_B @ W_A
33
- W_A = Vh[:actual_rank, :].contiguous() # [rank, in_features]
34
- W_B = U[:, :actual_rank] @ torch.diag(S[:actual_rank]) # [out_features, rank]
35
-
36
- return W_A.to(torch.float16), W_B.to(torch.float16)
37
  except Exception as e:
38
- print(f"Decomposition error: {e}")
 
39
  return None, None
40
 
41
  def extract_correction_factors(original_weight, fp8_weight):
@@ -72,36 +92,68 @@ def extract_correction_factors(original_weight, fp8_weight):
72
  else:
73
  return error.mean().to(original_weight.dtype)
74
 
75
- def analyze_model_architecture(state_dict):
76
- """Auto-detect model architecture and components."""
77
- keys = " ".join(state_dict.keys()).lower()
78
- components = {
79
- "text_encoder": False,
80
- "unet": False,
81
- "vae": False,
82
- "clip": False,
83
- "transformer": False
84
  }
85
 
86
- # Detect components based on key patterns
87
- if "text" in keys or "emb" in keys or ("encoder" in keys and "vae" not in keys):
88
- components["text_encoder"] = True
89
- if "clip" in keys or "vision" in keys:
90
- components["clip"] = True
91
 
92
- if "unet" in keys or ("down_blocks" in keys and "up_blocks" in keys) or ("input_blocks" in keys and "output_blocks" in keys):
93
- components["unet"] = True
94
- if "transformer" in keys or "attn" in keys:
95
- components["transformer"] = True
96
 
97
- if "vae" in keys or ("encoder" in keys and "decoder" in keys) or "quant_conv" in keys or "post_quant" in keys:
98
- components["vae"] = True
99
 
100
- return components
101
 
102
  def convert_safetensors_to_fp8_with_recovery(safetensors_path, output_dir, fp8_format,
103
- recovery_configs, progress=gr.Progress()):
104
- """Convert model to FP8 with customizable per-element recovery strategies."""
105
  progress(0.1, desc="Starting FP8 conversion with precision recovery...")
106
  try:
107
  def read_safetensors_metadata(path):
@@ -118,10 +170,6 @@ def convert_safetensors_to_fp8_with_recovery(safetensors_path, output_dir, fp8_f
118
  state_dict = load_file(safetensors_path)
119
  progress(0.3, desc="Loaded model weights.")
120
 
121
- # Auto-detect architecture
122
- detected_components = analyze_model_architecture(state_dict)
123
- print(f"Detected components: {detected_components}")
124
-
125
  # Setup FP8 format
126
  fp8_dtype = torch.float8_e5m2 if fp8_format == "e5m2" else torch.float8_e4m3fn
127
 
@@ -132,27 +180,17 @@ def convert_safetensors_to_fp8_with_recovery(safetensors_path, output_dir, fp8_f
132
  "total_layers": len(state_dict),
133
  "processed_layers": 0,
134
  "skipped_layers": [],
135
- "detected_components": detected_components,
136
- "recovery_counts": {"lora": 0, "diff": 0}
137
  }
138
 
139
- # Create a mapping from layer keys to recovery config
140
- layer_recovery_map = {}
141
- for config in recovery_configs:
142
- element_pattern = config["element"].lower()
143
- for key in state_dict:
144
- if element_pattern == "all" or element_pattern in key.lower():
145
- # Only set if not already set (first match wins)
146
- if key not in layer_recovery_map:
147
- layer_recovery_map[key] = config
148
-
149
  # Process each tensor
150
  total = len(state_dict)
151
  for i, key in enumerate(state_dict):
152
  progress(0.3 + 0.5 * (i / total), desc=f"Processing {i+1}/{total}: {key.split('.')[-1]}")
153
  weight = state_dict[key]
 
154
 
155
- # Convert to FP8
156
  if weight.dtype in [torch.float16, torch.float32, torch.bfloat16]:
157
  fp8_weight = weight.to(fp8_dtype)
158
  sd_fp8[key] = fp8_weight
@@ -161,43 +199,53 @@ def convert_safetensors_to_fp8_with_recovery(safetensors_path, output_dir, fp8_f
161
  stats["skipped_layers"].append(f"{key}: non-float dtype")
162
  continue
163
 
164
- # Get recovery config for this layer
165
- recovery_config = layer_recovery_map.get(key)
166
- if not recovery_config or recovery_config["method"] == "none":
167
- stats["skipped_layers"].append(f"{key}: no recovery configured")
168
- continue
169
 
170
- try:
171
- method = recovery_config["method"]
172
- if method == "lora" and weight.ndim == 2:
173
- # LoRA recovery for 2D tensors only
174
- rank = recovery_config.get("rank", 64)
175
- # Adjust rank for smaller matrices
176
- adjusted_rank = min(rank, min(weight.shape) // 2)
177
- if adjusted_rank >= 4:
178
- A, B = low_rank_decomposition(weight, rank=adjusted_rank)
179
- if A is not None and B is not None:
180
- recovery_weights[f"lora_A.{key}"] = A
181
- recovery_weights[f"lora_B.{key}"] = B
182
- stats["processed_layers"] += 1
183
- stats["recovery_counts"]["lora"] += 1
184
- continue
185
-
186
- if method == "diff":
187
- # Difference/correction recovery for any tensor type
188
- corr = extract_correction_factors(weight, fp8_weight)
189
- if corr is not None:
190
- recovery_weights[f"diff.{key}"] = corr
191
- stats["processed_layers"] += 1
192
- stats["recovery_counts"]["diff"] += 1
193
- continue
194
-
195
- # If we get here, recovery was configured but couldn't be applied
196
- reason = "2D tensor required" if method == "lora" and weight.ndim != 2 else "decomposition failed"
197
- stats["skipped_layers"].append(f"{key}: {method} recovery failed ({reason})")
198
 
199
- except Exception as e:
200
- stats["skipped_layers"].append(f"{key}: error - {str(e)}")
 
201
 
202
  # Save FP8 model
203
  base_name = os.path.splitext(os.path.basename(safetensors_path))[0]
@@ -211,7 +259,7 @@ def convert_safetensors_to_fp8_with_recovery(safetensors_path, output_dir, fp8_f
211
  recovery_metadata = {
212
  "format": "pt",
213
  "fp8_format": fp8_format,
214
- "recovery_config": json.dumps(recovery_configs),
215
  "stats": json.dumps(stats)
216
  }
217
  save_file(recovery_weights, recovery_path, metadata=recovery_metadata)
@@ -225,6 +273,16 @@ def convert_safetensors_to_fp8_with_recovery(safetensors_path, output_dir, fp8_f
225
  stats_msg += f" - LoRA recovery: {stats['recovery_counts']['lora']}\n"
226
  stats_msg += f" - Difference recovery: {stats['recovery_counts']['diff']}\n"
227
 
228
  if not recovery_weights:
229
  stats_msg += "\n⚠️ No recovery weights were generated. All layers use pure FP8."
230
 
@@ -232,9 +290,8 @@ def convert_safetensors_to_fp8_with_recovery(safetensors_path, output_dir, fp8_f
232
  return True, stats_msg, stats, fp8_path, recovery_path
233
 
234
  except Exception as e:
235
- import traceback
236
- error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
237
- return False, error_msg, None, None, None
238
 
239
  def parse_hf_url(url):
240
  url = url.strip().rstrip("/")
@@ -292,12 +349,166 @@ def upload_to_target(target_type, new_repo_id, output_dir, fp8_format, hf_token=
292
  else:
293
  raise ValueError("Unknown target")
294
 
295
  def process_and_upload_fp8(
296
  source_type,
297
  repo_url,
298
  safetensors_filename,
299
  fp8_format,
300
- recovery_configs_json,
301
  target_type,
302
  new_repo_id,
303
  hf_token,
@@ -312,20 +523,18 @@ def process_and_upload_fp8(
312
  if target_type == "huggingface" and not hf_token:
313
  return None, "❌ Hugging Face token required for target.", "", ""
314
 
315
- # Parse recovery configs
316
  try:
317
- recovery_configs = json.loads(recovery_configs_json)
318
  except json.JSONDecodeError:
319
- return None, "❌ Invalid recovery configuration JSON.", "", ""
320
 
321
- # Validate config format
322
  valid_methods = ["none", "lora", "diff"]
323
- for config in recovery_configs:
324
- if "element" not in config or "method" not in config:
325
- return None, "❌ Invalid config format: each config needs 'element' and 'method'", "", ""
326
- if config["method"] not in valid_methods:
327
- return None, f"❌ Invalid method: {config['method']}. Use 'none', 'lora', or 'diff'", "", ""
328
- if config["method"] == "lora" and "rank" not in config:
329
  return None, "❌ LoRA method requires 'rank' parameter", "", ""
330
 
331
  temp_dir = None
@@ -338,7 +547,7 @@ def process_and_upload_fp8(
338
 
339
  progress(0.2, desc="Converting to FP8 with precision recovery...")
340
  success, msg, stats, fp8_path, recovery_path = convert_safetensors_to_fp8_with_recovery(
341
- safetensors_path, output_dir, fp8_format, recovery_configs, progress
342
  )
343
 
344
  if not success:
@@ -363,16 +572,16 @@ tags:
363
  - mixed-method
364
  - converted-by-gradio
365
  ---
366
- # FP8 Model with Mixed Precision Recovery
367
  - **Source**: `{repo_url}`
368
  - **Original File**: `{safetensors_filename}`
369
  - **FP8 Format**: `{fp8_format.upper()}`
370
  - **FP8 File**: `{fp8_filename}`
371
  - **Recovery File**: `{recovery_filename if recovery_filename else "None"}`
372
 
373
- ## Recovery Configuration
374
  ```json
375
- {json.dumps(recovery_configs, indent=2)}
376
  ```
377
 
378
  ## Usage (Inference)
@@ -392,16 +601,19 @@ for key in fp8_state:
392
  fp8_weight = fp8_state[key].to(torch.float32) # Convert to float32 for computation
393
 
394
  # Apply LoRA recovery if available
395
- if f"lora_A.{{key}}" in recovery_state and f"lora_B.{{key}}" in recovery_state:
396
- A = recovery_state[f"lora_A.{{key}}"].to(torch.float32)
397
- B = recovery_state[f"lora_B.{{key}}"].to(torch.float32)
 
 
398
  # Reconstruct the low-rank approximation
399
  lora_weight = B @ A
400
  fp8_weight = fp8_weight + lora_weight
401
 
402
  # Apply difference recovery if available
403
- if f"diff.{{key}}" in recovery_state:
404
- diff = recovery_state[f"diff.{{key}}"].to(torch.float32)
 
405
  fp8_weight = fp8_weight + diff
406
 
407
  reconstructed[key] = fp8_weight
@@ -454,18 +666,17 @@ Includes:
454
  recovery_details)
455
 
456
  except Exception as e:
457
- import traceback
458
- error_details = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
459
- return None, error_details, "", ""
460
 
461
  finally:
462
  if temp_dir:
463
  shutil.rmtree(temp_dir, ignore_errors=True)
464
  shutil.rmtree(output_dir, ignore_errors=True)
465
 
466
- with gr.Blocks(title="Advanced FP8 Quantizer with Mixed Precision Recovery") as demo:
467
- gr.Markdown("# 🔄 Advanced FP8 Quantizer with Per-Layer Precision Recovery")
468
- gr.Markdown("Convert `.safetensors` → **FP8** + **customizable precision recovery**. Full control over LoRA and difference methods per layer.")
469
 
470
  with gr.Row():
471
  with gr.Column():
@@ -476,40 +687,69 @@ with gr.Blocks(title="Advanced FP8 Quantizer with Mixed Precision Recovery") as
476
  with gr.Accordion("FP8 Settings", open=True):
477
  fp8_format = gr.Radio(["e4m3fn", "e5m2"], value="e5m2", label="FP8 Format")
478
 
479
- with gr.Accordion("Per-Layer Recovery Configuration", open=True):
480
  gr.Markdown("""
481
- ### Configure recovery strategy for each layer type
482
 
483
- Format: JSON array of configuration objects:
484
  ```json
485
  [
486
- {"element": "pattern1", "method": "lora", "rank": 64},
487
- {"element": "pattern2", "method": "diff"},
488
- {"element": "all", "method": "none"}
489
  ]
490
  ```
491
 
492
- - `element`: Substring to match in weight keys (case-insensitive). Use "all" for default.
493
  - `method`: "none" (pure FP8), "lora" (low-rank adaptation), or "diff" (difference/correction)
494
- - `rank`: Required for "lora" method. Higher = better quality but larger file.
495
 
496
- **Rules are applied in order** - first match wins. Always end with an "all" rule.
497
  """)
498
 
499
- recovery_configs_json = gr.Textbox(
500
- value="""[
501
- {"element": "vae", "method": "diff"},
502
- {"element": "encoder", "method": "diff"},
503
- {"element": "decoder", "method": "diff"},
504
- {"element": "text", "method": "lora", "rank": 64},
505
- {"element": "emb", "method": "lora", "rank": 64},
506
- {"element": "attn", "method": "lora", "rank": 128},
507
- {"element": "all", "method": "none"}
508
- ]""",
509
- lines=10,
510
- label="Recovery Configuration (JSON)",
511
  interactive=True
512
  )
513
 
514
  with gr.Accordion("Authentication", open=False):
515
  hf_token = gr.Textbox(label="Hugging Face Token", type="password")
@@ -534,7 +774,7 @@ with gr.Blocks(title="Advanced FP8 Quantizer with Mixed Precision Recovery") as
534
  repo_url,
535
  safetensors_filename,
536
  fp8_format,
537
- recovery_configs_json,
538
  target_type,
539
  new_repo_id,
540
  hf_token,
@@ -552,12 +792,7 @@ with gr.Blocks(title="Advanced FP8 Quantizer with Mixed Precision Recovery") as
552
  "https://huggingface.co/stabilityai/sdxl-vae",
553
  "diffusion_pytorch_model.safetensors",
554
  "e4m3fn",
555
- """[
556
- {"element": "vae", "method": "diff"},
557
- {"element": "encoder", "method": "diff"},
558
- {"element": "decoder", "method": "diff"},
559
- {"element": "all", "method": "none"}
560
- ]""",
561
  "huggingface"
562
  ],
563
  [
@@ -565,11 +800,7 @@ with gr.Blocks(title="Advanced FP8 Quantizer with Mixed Precision Recovery") as
565
  "https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/text_encoder",
566
  "model.safetensors",
567
  "e5m2",
568
- """[
569
- {"element": "text", "method": "lora", "rank": 64},
570
- {"element": "emb", "method": "lora", "rank": 64},
571
- {"element": "all", "method": "none"}
572
- ]""",
573
  "huggingface"
574
  ],
575
  [
@@ -577,44 +808,44 @@ with gr.Blocks(title="Advanced FP8 Quantizer with Mixed Precision Recovery") as
577
  "https://huggingface.co/Yabo/FramePainter/tree/main",
578
  "unet_diffusion_pytorch_model.safetensors",
579
  "e5m2",
580
- """[
581
- {"element": "attn", "method": "lora", "rank": 128},
582
- {"element": "transformer", "method": "lora", "rank": 96},
583
- {"element": "conv", "method": "diff"},
584
- {"element": "resnet", "method": "diff"},
585
- {"element": "all", "method": "none"}
586
- ]""",
587
  "huggingface"
588
  ]
589
  ],
590
- inputs=[source_type, repo_url, safetensors_filename, fp8_format, recovery_configs_json, target_type],
591
- label="Example Conversions"
 
592
  )
593
 
594
  gr.Markdown("""
595
- ## 💡 Precision Recovery Strategy Guide
596
 
597
- ### **LoRA Method** (best for attention/linear layers)
598
- - **Use for**: `text`, `attn`, `transformer`, `emb`, `mlp` layers
599
- - **Rank selection**:
600
- - Text encoders: 64-128
601
- - Attention blocks: 64-128
602
- - Other linear layers: 32-64
603
- - **Benefits**: Captures weight matrix structure, better for semantic understanding
604
- - **Limitations**: Only works on 2D tensors, not suitable for convolutions
605
 
606
- ### **Difference Method** (best for convolutional layers)
607
- - **Use for**: `vae`, `encoder`, `decoder`, `conv`, `resnet` layers
608
- - **How it works**: Stores the exact difference between FP8 and original weights
609
- - **Benefits**: Works with any tensor shape, more accurate for spatial features
610
- - **Limitations**: Larger file size than LoRA for equivalent quality
611
 
612
- ### **Rule Ordering Tips**
613
- - Put specific patterns first (`vae.encoder`), general patterns last (`all`)
614
- - Always end with an `{"element": "all", "method": "none"}` rule as fallback
615
- - Layer names are **case-insensitive** - use lowercase patterns for matching
616
 
617
- > **Pro Tip**: For diffusion models, use Difference for VAE/convolutional components and LoRA for text/attention components for optimal quality/size tradeoff.
618
  """)
619
 
620
  demo.launch()
 
9
  from safetensors.torch import load_file, save_file
10
  import torch
11
  import torch.nn.functional as F
12
+ import traceback
13
  try:
14
  from modelscope.hub.file_download import model_file_download as ms_file_download
15
  from modelscope.hub.api import HubApi as ModelScopeApi
 
18
  MODELScope_AVAILABLE = False
19
 
20
  def low_rank_decomposition(weight, rank=64):
21
+ """
22
+ Correct LoRA decomposition supporting 2D and 4D tensors.
23
+ Returns (lora_A, lora_B) such that weight ≈ lora_B @ lora_A for 2D,
24
+ or appropriate conv form for 4D.
25
+ """
26
+ original_shape = weight.shape
27
+ original_dtype = weight.dtype
28
  try:
29
+ if weight.ndim == 2:
30
+ actual_rank = min(rank, min(weight.shape) // 2)
31
+ if actual_rank < 4:
32
+ return None, None
33
+ U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
34
+ S_sqrt = torch.sqrt(S[:actual_rank])
35
+ # Standard LoRA factorization: W ≈ W_B @ W_A
36
+ W_A = (Vh[:actual_rank, :] * S_sqrt.unsqueeze(1)).contiguous() # [rank, in_features]
37
+ W_B = (U[:, :actual_rank] * S_sqrt.unsqueeze(0)).contiguous() # [out_features, rank]
38
+ return W_A.to(original_dtype), W_B.to(original_dtype)
39
+ elif weight.ndim == 4:
40
+ out_ch, in_ch, k_h, k_w = weight.shape
41
+ if k_h * k_w <= 9: # small conv kernels (e.g., 3x3)
42
+ # Reshape to 2D: [out_ch, in_ch * k_h * k_w]
43
+ weight_2d = weight.view(out_ch, -1)
44
+ actual_rank = min(rank, min(weight_2d.shape) // 2)
45
+ if actual_rank < 4:
46
+ return None, None
47
+ U, S, Vh = torch.linalg.svd(weight_2d.float(), full_matrices=False)
48
+ S_sqrt = torch.sqrt(S[:actual_rank])
49
+ W_A_2d = (Vh[:actual_rank, :] * S_sqrt.unsqueeze(1)).contiguous()
50
+ W_B_2d = (U[:, :actual_rank] * S_sqrt.unsqueeze(0)).contiguous()
51
+ # Reshape back to conv format
52
+ W_A = W_A_2d.view(actual_rank, in_ch, k_h, k_w).contiguous()
53
+ W_B = W_B_2d.view(out_ch, actual_rank, 1, 1).contiguous()
54
+ return W_A.to(original_dtype), W_B.to(original_dtype)
55
+ return None, None
56
  except Exception as e:
57
+ print(f"Decomposition error for {original_shape}: {e}")
58
+ traceback.print_exc()
59
  return None, None
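The factors returned above split the singular values as square roots between W_A and W_B, so for a 2D weight the product W_B @ W_A is the best rank-r approximation of the original matrix. Below is a minimal sketch of how that approximation can be sanity-checked on a standalone tensor; the helper reimplements the same math, and the tensor size and rank values are illustrative rather than taken from app.py:

```python
import torch

def lora_reconstruction_error(weight: torch.Tensor, rank: int) -> float:
    """Relative Frobenius error of the rank-`rank` SVD factorization used above."""
    U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
    r = min(rank, S.numel())
    s_sqrt = torch.sqrt(S[:r])
    W_A = Vh[:r, :] * s_sqrt.unsqueeze(1)   # [rank, in_features]
    W_B = U[:, :r] * s_sqrt.unsqueeze(0)    # [out_features, rank]
    approx = W_B @ W_A
    return (torch.linalg.norm(weight.float() - approx) /
            torch.linalg.norm(weight.float())).item()

# Illustrative check on a random 768x768 matrix at two ranks.
w = torch.randn(768, 768)
print(lora_reconstruction_error(w, rank=64))    # coarse approximation, larger error
print(lora_reconstruction_error(w, rank=512))   # finer approximation, smaller error
```

Higher ranks push the relative error toward zero at the cost of a larger recovery file, which is the trade-off the `rank` field in the recovery rules controls.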
60
 
61
  def extract_correction_factors(original_weight, fp8_weight):
 
92
  else:
93
  return error.mean().to(original_weight.dtype)
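Only the tail of `extract_correction_factors` is visible in this hunk, but it shows that the stored correction is either a per-element error tensor or a single mean error cast back to the original dtype. A small sketch, assuming the full-difference case, of how such a correction is added back onto a dequantized FP8 weight; the names are illustrative and the cast requires a PyTorch build that exposes float8 dtypes:

```python
import torch

def apply_diff_correction(fp8_weight: torch.Tensor, correction: torch.Tensor) -> torch.Tensor:
    """Add a stored correction back onto a dequantized FP8 weight.

    `correction` may be a full-shape difference tensor or a 0-dim mean error;
    broadcasting handles both cases.
    """
    return fp8_weight.to(torch.float32) + correction.to(torch.float32)

# Illustrative round trip (requires torch.float8_e4m3fn support in this build).
orig = torch.randn(16, 16, dtype=torch.float32)
fp8 = orig.to(torch.float8_e4m3fn)
diff = orig - fp8.to(torch.float32)          # full-shape correction
restored = apply_diff_correction(fp8, diff)
print(torch.allclose(restored, orig))        # True up to float32 rounding
```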
94
 
95
+ def get_tensor_info(tensor):
96
+ """Get detailed tensor information for pattern matching."""
97
+ shape = list(tensor.shape)
98
+ dim = tensor.dim()
99
+ numel = tensor.numel()
100
+ dtype = str(tensor.dtype)
101
+
102
+ # Determine tensor type based on shape
103
+ tensor_type = "other"
104
+ if dim == 4 and shape[2] == shape[3]: # Convolutional layer with square kernel
105
+ tensor_type = "conv"
106
+ elif dim == 2:
107
+ if shape[0] > shape[1] * 4: # More likely to be output projection
108
+ tensor_type = "output_proj"
109
+ elif shape[1] > shape[0] * 4: # More likely to be input projection
110
+ tensor_type = "input_proj"
111
+ else:
112
+ tensor_type = "linear"
113
+ elif dim == 1:
114
+ tensor_type = "bias"
115
+
116
+ return {
117
+ "shape": shape,
118
+ "dim": dim,
119
+ "numel": numel,
120
+ "type": tensor_type,
121
+ "dtype": dtype
122
  }
123
+
124
+ def matches_pattern(key, tensor_info, pattern):
125
+ """Check if a tensor matches a pattern definition."""
126
+ key_lower = key.lower()
127
+
128
+ # Match by key name pattern
129
+ if "key_pattern" in pattern:
130
+ key_pattern = pattern["key_pattern"].lower()
131
+ if key_pattern != "all" and key_pattern not in key_lower:
132
+ return False
133
+
134
+ # Match by tensor dimension
135
+ if "dim" in pattern and tensor_info["dim"] != pattern["dim"]:
136
+ return False
137
 
138
+ # Match by tensor type
139
+ if "type" in pattern and tensor_info["type"] != pattern["type"]:
140
+ return False
 
 
141
 
142
+ # Match by minimum tensor size
143
+ if "min_size" in pattern and tensor_info["numel"] < pattern["min_size"]:
144
+ return False
 
145
 
146
+ # Match by shape constraints
147
+ if "shape_contains" in pattern:
148
+ shape_contains = pattern["shape_contains"]
149
+ if not any(shape_contains == dim for dim in tensor_info["shape"]):
150
+ return False
151
 
152
+ return True
153
 
154
  def convert_safetensors_to_fp8_with_recovery(safetensors_path, output_dir, fp8_format,
155
+ recovery_rules, progress=gr.Progress()):
156
+ """Convert model to FP8 with customizable per-tensor recovery strategies."""
157
  progress(0.1, desc="Starting FP8 conversion with precision recovery...")
158
  try:
159
  def read_safetensors_metadata(path):
 
170
  state_dict = load_file(safetensors_path)
171
  progress(0.3, desc="Loaded model weights.")
172
 
 
173
  # Setup FP8 format
174
  fp8_dtype = torch.float8_e5m2 if fp8_format == "e5m2" else torch.float8_e4m3fn
175
 
 
180
  "total_layers": len(state_dict),
181
  "processed_layers": 0,
182
  "skipped_layers": [],
183
+ "recovery_counts": {"lora": 0, "diff": 0},
184
+ "rule_matches": {i: 0 for i in range(len(recovery_rules))}
185
  }
186
 
187
  # Process each tensor
188
  total = len(state_dict)
189
  for i, key in enumerate(state_dict):
190
  progress(0.3 + 0.5 * (i / total), desc=f"Processing {i+1}/{total}: {key.split('.')[-1]}")
191
  weight = state_dict[key]
192
+ tensor_info = get_tensor_info(weight)
193
 
 
194
  if weight.dtype in [torch.float16, torch.float32, torch.bfloat16]:
195
  fp8_weight = weight.to(fp8_dtype)
196
  sd_fp8[key] = fp8_weight
 
199
  stats["skipped_layers"].append(f"{key}: non-float dtype")
200
  continue
201
 
202
+ # Find matching rule for this tensor
203
+ recovery_applied = False
204
+ matched_rule_index = -1
 
 
205
 
206
+ for rule_idx, rule in enumerate(recovery_rules):
207
+ if matches_pattern(key, tensor_info, rule):
208
+ matched_rule_index = rule_idx
209
+ recovery_method = rule["method"]
210
+
211
+ try:
212
+ if recovery_method == "lora" and weight.ndim == 2:
213
+ # LoRA recovery for 2D tensors only
214
+ rank = rule.get("rank", 64)
215
+ # Adjust rank for smaller matrices
216
+ adjusted_rank = min(rank, min(weight.shape) // 2)
217
+ if adjusted_rank >= 4:
218
+ A, B = low_rank_decomposition(weight, rank=adjusted_rank)
219
+ if A is not None and B is not None:
220
+ recovery_weights[f"lora_A.{key}"] = A
221
+ recovery_weights[f"lora_B.{key}"] = B
222
+ stats["processed_layers"] += 1
223
+ stats["recovery_counts"]["lora"] += 1
224
+ stats["rule_matches"][rule_idx] += 1
225
+ recovery_applied = True
226
+ break
227
+
228
+ elif recovery_method == "diff":
229
+ # Difference/correction recovery for any tensor type
230
+ corr = extract_correction_factors(weight, fp8_weight)
231
+ if corr is not None:
232
+ recovery_weights[f"diff.{key}"] = corr
233
+ stats["processed_layers"] += 1
234
+ stats["recovery_counts"]["diff"] += 1
235
+ stats["rule_matches"][rule_idx] += 1
236
+ recovery_applied = True
237
+ break
238
+
239
+ # If method is "none" or recovery failed, continue to next rule
240
+ if recovery_method == "none":
241
+ break
242
+
243
+ except Exception as e:
244
+ stats["skipped_layers"].append(f"{key}: error with rule {rule_idx} - {str(e)}")
245
 
246
+ if not recovery_applied:
247
+ reason = "no matching rule" if matched_rule_index == -1 else f"recovery failed with rule {matched_rule_index}"
248
+ stats["skipped_layers"].append(f"{key}: {reason}")
249
 
250
  # Save FP8 model
251
  base_name = os.path.splitext(os.path.basename(safetensors_path))[0]
 
259
  recovery_metadata = {
260
  "format": "pt",
261
  "fp8_format": fp8_format,
262
+ "recovery_rules": json.dumps(recovery_rules),
263
  "stats": json.dumps(stats)
264
  }
265
  save_file(recovery_weights, recovery_path, metadata=recovery_metadata)
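The recovery file carries its rule set and conversion stats as JSON strings in the safetensors header, so they can be inspected without loading any tensors. A sketch of reading them back with `safetensors.safe_open`, assuming `recovery_path` points at a recovery file written as above:

```python
import json
from safetensors import safe_open

# recovery_path is assumed to point at a recovery .safetensors file written above.
with safe_open(recovery_path, framework="pt") as f:
    meta = f.metadata() or {}
    rules = json.loads(meta.get("recovery_rules", "[]"))
    stats = json.loads(meta.get("stats", "{}"))

print("FP8 format:", meta.get("fp8_format"))
print("Rule methods:", [r["method"] for r in rules])
print("Layers with recovery:", stats.get("processed_layers"))
```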
 
273
  stats_msg += f" - LoRA recovery: {stats['recovery_counts']['lora']}\n"
274
  stats_msg += f" - Difference recovery: {stats['recovery_counts']['diff']}\n"
275
 
276
+ # Show rule effectiveness
277
+ stats_msg += "\nRule effectiveness:\n"
278
+ for rule_idx, rule in enumerate(recovery_rules):
279
+ matches = stats["rule_matches"][rule_idx]
280
+ if matches > 0:
281
+ method = rule["method"]
282
+ pattern = rule.get("key_pattern", "no pattern")
283
+ rank_info = f" (rank {rule.get('rank', 'N/A')})" if method == "lora" else ""
284
+ stats_msg += f"- Rule {rule_idx}: {matches} layers matched pattern '{pattern}' with {method}{rank_info}\n"
285
+
286
  if not recovery_weights:
287
  stats_msg += "\n⚠️ No recovery weights were generated. All layers use pure FP8."
288
 
 
290
  return True, stats_msg, stats, fp8_path, recovery_path
291
 
292
  except Exception as e:
293
+ traceback.print_exc()
294
+ return False, str(e), None, None, None
 
295
 
296
  def parse_hf_url(url):
297
  url = url.strip().rstrip("/")
 
349
  else:
350
  raise ValueError("Unknown target")
351
 
352
+ def generate_default_rules(architecture="auto"):
353
+ """Generate default recovery rules based on architecture."""
354
+ if architecture == "vae":
355
+ return """[
356
+ {
357
+ "key_pattern": "vae",
358
+ "dim": 4,
359
+ "method": "diff"
360
+ },
361
+ {
362
+ "key_pattern": "encoder",
363
+ "dim": 4,
364
+ "method": "diff"
365
+ },
366
+ {
367
+ "key_pattern": "decoder",
368
+ "dim": 4,
369
+ "method": "diff"
370
+ },
371
+ {
372
+ "key_pattern": "all",
373
+ "method": "none"
374
+ }
375
+ ]"""
376
+ elif architecture == "text_encoder":
377
+ return """[
378
+ {
379
+ "key_pattern": "text",
380
+ "dim": 2,
381
+ "min_size": 10000,
382
+ "method": "lora",
383
+ "rank": 64
384
+ },
385
+ {
386
+ "key_pattern": "emb",
387
+ "dim": 2,
388
+ "min_size": 10000,
389
+ "method": "lora",
390
+ "rank": 64
391
+ },
392
+ {
393
+ "key_pattern": "attn",
394
+ "dim": 2,
395
+ "min_size": 10000,
396
+ "method": "lora",
397
+ "rank": 128
398
+ },
399
+ {
400
+ "key_pattern": "all",
401
+ "method": "none"
402
+ }
403
+ ]"""
404
+ elif architecture == "unet_transformer":
405
+ return """[
406
+ {
407
+ "key_pattern": "attn",
408
+ "dim": 2,
409
+ "min_size": 10000,
410
+ "method": "lora",
411
+ "rank": 128
412
+ },
413
+ {
414
+ "key_pattern": "transformer",
415
+ "dim": 2,
416
+ "min_size": 10000,
417
+ "method": "lora",
418
+ "rank": 96
419
+ },
420
+ {
421
+ "key_pattern": "all",
422
+ "method": "none"
423
+ }
424
+ ]"""
425
+ elif architecture == "unet_conv":
426
+ return """[
427
+ {
428
+ "key_pattern": "conv",
429
+ "dim": 4,
430
+ "method": "diff"
431
+ },
432
+ {
433
+ "key_pattern": "resnet",
434
+ "dim": 4,
435
+ "method": "diff"
436
+ },
437
+ {
438
+ "key_pattern": "down",
439
+ "dim": 4,
440
+ "method": "diff"
441
+ },
442
+ {
443
+ "key_pattern": "up",
444
+ "dim": 4,
445
+ "method": "diff"
446
+ },
447
+ {
448
+ "key_pattern": "all",
449
+ "method": "none"
450
+ }
451
+ ]"""
452
+ else: # "all" or "auto"
453
+ return """[
454
+ {
455
+ "key_pattern": "vae",
456
+ "dim": 4,
457
+ "method": "diff"
458
+ },
459
+ {
460
+ "key_pattern": "encoder",
461
+ "dim": 4,
462
+ "method": "diff"
463
+ },
464
+ {
465
+ "key_pattern": "decoder",
466
+ "dim": 4,
467
+ "method": "diff"
468
+ },
469
+ {
470
+ "key_pattern": "text",
471
+ "dim": 2,
472
+ "min_size": 10000,
473
+ "method": "lora",
474
+ "rank": 64
475
+ },
476
+ {
477
+ "key_pattern": "emb",
478
+ "dim": 2,
479
+ "min_size": 10000,
480
+ "method": "lora",
481
+ "rank": 64
482
+ },
483
+ {
484
+ "key_pattern": "attn",
485
+ "dim": 2,
486
+ "min_size": 10000,
487
+ "method": "lora",
488
+ "rank": 128
489
+ },
490
+ {
491
+ "key_pattern": "conv",
492
+ "dim": 4,
493
+ "method": "diff"
494
+ },
495
+ {
496
+ "key_pattern": "resnet",
497
+ "dim": 4,
498
+ "method": "diff"
499
+ },
500
+ {
501
+ "key_pattern": "all",
502
+ "method": "none"
503
+ }
504
+ ]"""
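`generate_default_rules` returns each preset as a JSON string so it can populate the Gradio textbox directly, which means callers must `json.loads` it before use. A quick sketch of that round trip using the function defined above:

```python
import json

rules_json = generate_default_rules("text_encoder")  # defined above; returns a JSON string
rules = json.loads(rules_json)

assert isinstance(rules, list)
assert rules[-1] == {"key_pattern": "all", "method": "none"}   # catch-all terminator
for rule in rules:
    if rule["method"] == "lora":
        assert "rank" in rule, "LoRA rules must carry a rank"
print(len(rules), "rules:", [r["method"] for r in rules])
```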
505
+
506
  def process_and_upload_fp8(
507
  source_type,
508
  repo_url,
509
  safetensors_filename,
510
  fp8_format,
511
+ recovery_rules_json,
512
  target_type,
513
  new_repo_id,
514
  hf_token,
 
523
  if target_type == "huggingface" and not hf_token:
524
  return None, "❌ Hugging Face token required for target.", "", ""
525
 
526
+ # Parse recovery rules
527
  try:
528
+ recovery_rules = json.loads(recovery_rules_json)
529
  except json.JSONDecodeError:
530
+ return None, "❌ Invalid recovery rules JSON.", "", ""
531
 
532
+ # Validate rules
533
  valid_methods = ["none", "lora", "diff"]
534
+ for rule in recovery_rules:
535
+ if "method" not in rule or rule["method"] not in valid_methods:
536
+ return None, f"❌ Invalid method: {rule.get('method')}. Use 'none', 'lora', or 'diff'", "", ""
537
+ if rule["method"] == "lora" and "rank" not in rule:
 
 
538
  return None, "❌ LoRA method requires 'rank' parameter", "", ""
539
 
540
  temp_dir = None
 
547
 
548
  progress(0.2, desc="Converting to FP8 with precision recovery...")
549
  success, msg, stats, fp8_path, recovery_path = convert_safetensors_to_fp8_with_recovery(
550
+ safetensors_path, output_dir, fp8_format, recovery_rules, progress
551
  )
552
 
553
  if not success:
 
572
  - mixed-method
573
  - converted-by-gradio
574
  ---
575
+ # FP8 Model with Per-Tensor Precision Recovery
576
  - **Source**: `{repo_url}`
577
  - **Original File**: `{safetensors_filename}`
578
  - **FP8 Format**: `{fp8_format.upper()}`
579
  - **FP8 File**: `{fp8_filename}`
580
  - **Recovery File**: `{recovery_filename if recovery_filename else "None"}`
581
 
582
+ ## Recovery Rules Used
583
  ```json
584
+ {json.dumps(recovery_rules, indent=2)}
585
  ```
586
 
587
  ## Usage (Inference)
 
601
  fp8_weight = fp8_state[key].to(torch.float32) # Convert to float32 for computation
602
 
603
  # Apply LoRA recovery if available
604
+ lora_a_key = f"lora_A.{{key}}"
605
+ lora_b_key = f"lora_B.{{key}}"
606
+ if lora_a_key in recovery_state and lora_b_key in recovery_state:
607
+ A = recovery_state[lora_a_key].to(torch.float32)
608
+ B = recovery_state[lora_b_key].to(torch.float32)
609
  # Reconstruct the low-rank approximation
610
  lora_weight = B @ A
611
  fp8_weight = fp8_weight + lora_weight
612
 
613
  # Apply difference recovery if available
614
+ diff_key = f"diff.{{key}}"
615
+ if diff_key in recovery_state:
616
+ diff = recovery_state[diff_key].to(torch.float32)
617
  fp8_weight = fp8_weight + diff
618
 
619
  reconstructed[key] = fp8_weight
 
666
  recovery_details)
667
 
668
  except Exception as e:
669
+ traceback.print_exc()
670
+ return None, f"❌ Error: {str(e)}", "", ""
 
671
 
672
  finally:
673
  if temp_dir:
674
  shutil.rmtree(temp_dir, ignore_errors=True)
675
  shutil.rmtree(output_dir, ignore_errors=True)
676
 
677
+ with gr.Blocks(title="Advanced FP8 Quantizer with Per-Tensor Precision Recovery") as demo:
678
+ gr.Markdown("# 🔄 Advanced FP8 Quantizer with Per-Tensor Precision Recovery")
679
+ gr.Markdown("Convert `.safetensors` → **FP8** + **customizable precision recovery**. Full control over LoRA and difference methods per tensor pattern.")
680
 
681
  with gr.Row():
682
  with gr.Column():
 
687
  with gr.Accordion("FP8 Settings", open=True):
688
  fp8_format = gr.Radio(["e4m3fn", "e5m2"], value="e5m2", label="FP8 Format")
689
 
690
+ with gr.Accordion("Per-Tensor Recovery Rules", open=True):
691
  gr.Markdown("""
692
+ ### Configure recovery strategy for each tensor pattern
693
 
694
+ Format: JSON array of rule objects:
695
  ```json
696
  [
697
+ {
698
+ "key_pattern": "vae",
699
+ "dim": 4,
700
+ "method": "diff"
701
+ },
702
+ {
703
+ "key_pattern": "attn",
704
+ "dim": 2,
705
+ "min_size": 10000,
706
+ "method": "lora",
707
+ "rank": 64
708
+ },
709
+ {
710
+ "key_pattern": "all",
711
+ "method": "none"
712
+ }
713
  ]
714
  ```
715
 
716
+ ### Rule Fields (all optional except "method"):
717
+ - `key_pattern`: Substring to match in weight keys (case-insensitive). Use "all" to match everything.
718
+ - `dim`: Tensor dimension (e.g., 2 for linear layers, 4 for convolutions)
719
+ - `type`: Tensor type ("conv", "linear", "bias", "input_proj", "output_proj")
720
+ - `min_size`: Minimum number of elements in tensor
721
+ - `shape_contains`: Specific dimension size that must be present in shape
722
  - `method`: "none" (pure FP8), "lora" (low-rank adaptation), or "diff" (difference/correction)
723
+ - `rank`: Required for "lora" method (higher = better quality but larger file)
724
 
725
+ **Rules are applied in order** - first match wins. Always end with a catch-all rule.
726
  """)
727
 
728
+ recovery_rules_json = gr.Textbox(
729
+ value=generate_default_rules("all"),
730
+ lines=15,
731
+ label="Recovery Rules (JSON)",
732
  interactive=True
733
  )
734
+
735
+ architecture_preset = gr.Dropdown(
736
+ choices=[
737
+ ("Auto-detect architecture", "auto"),
738
+ ("VAE (Difference method)", "vae"),
739
+ ("Text Encoder (LoRA)", "text_encoder"),
740
+ ("UNet Transformers (LoRA)", "unet_transformer"),
741
+ ("UNet Convolutions (Difference)", "unet_conv"),
742
+ ("All Components (Mixed)", "all")
743
+ ],
744
+ value="auto",
745
+ label="Architecture Preset"
746
+ )
747
+
748
+ architecture_preset.change(
749
+ fn=generate_default_rules,
750
+ inputs=architecture_preset,
751
+ outputs=recovery_rules_json
752
+ )
753
 
754
  with gr.Accordion("Authentication", open=False):
755
  hf_token = gr.Textbox(label="Hugging Face Token", type="password")
 
774
  repo_url,
775
  safetensors_filename,
776
  fp8_format,
777
+ recovery_rules_json,
778
  target_type,
779
  new_repo_id,
780
  hf_token,
 
792
  "https://huggingface.co/stabilityai/sdxl-vae",
793
  "diffusion_pytorch_model.safetensors",
794
  "e4m3fn",
795
+ generate_default_rules("vae"),
796
  "huggingface"
797
  ],
798
  [
 
800
  "https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/text_encoder",
801
  "model.safetensors",
802
  "e5m2",
803
+ generate_default_rules("text_encoder"),
804
  "huggingface"
805
  ],
806
  [
 
808
  "https://huggingface.co/Yabo/FramePainter/tree/main",
809
  "unet_diffusion_pytorch_model.safetensors",
810
  "e5m2",
811
+ generate_default_rules("unet_transformer"),
812
  "huggingface"
813
  ]
814
  ],
815
+ inputs=[source_type, repo_url, safetensors_filename, fp8_format, recovery_rules_json, target_type],
816
+ label="Example Conversions",
817
+ cache_examples=False
818
  )
819
 
820
  gr.Markdown("""
821
+ ## 💡 Tensor Pattern Matching Guide
822
+
823
+ This tool uses **advanced tensor pattern matching** to determine which recovery method to apply to each layer:
824
+
825
+ ### **Key Patterns**
826
+ - Match by substring in weight key name
827
+ - Case-insensitive matching
828
+ - Special keyword "all" matches everything
829
 
830
+ ### **Tensor Properties**
831
+ - **Dimension (dim)**: Use `dim: 2` for linear layers, `dim: 4` for convolutions
832
+ - **Type**: Automatic classification based on shape:
833
+ - `conv`: 4D tensors with equal spatial dimensions
834
+ - `linear`: 2D tensors without extreme aspect ratio
835
+ - `input_proj`: 2D tensors with much larger second dimension
836
+ - `output_proj`: 2D tensors with much larger first dimension
837
+ - `bias`: 1D tensors
838
 
839
+ ### **Size Constraints**
840
+ - **min_size**: Only apply to tensors with at least N elements
841
+ - **shape_contains**: Match tensors containing a specific dimension size
 
 
842
 
843
+ ### **Rule Processing**
844
+ - Rules are evaluated **in order**
845
+ - First matching rule wins
846
+ - Always include a catch-all rule at the end
847
 
848
+ > **Pro Tip for VAE**: Use `"dim": 4` combined with `"key_pattern": "vae"` to reliably target VAE convolutional layers with difference recovery.
849
  """)
850
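The shape-based type classification described in the guide can be exercised directly on a few representative tensors. This sketch reuses `get_tensor_info` from earlier in the file; the weight names and shapes are illustrative:

```python
import torch

# Representative tensors and the bucket the classifier assigns to each shape.
examples = {
    "vae.encoder.conv_in.weight": torch.randn(128, 3, 3, 3),  # 4D, square kernel -> "conv"
    "mlp.fc1.weight":             torch.randn(4096, 768),     # dim0 > 4 * dim1   -> "output_proj"
    "mlp.fc2.weight":             torch.randn(768, 4096),     # dim1 > 4 * dim0   -> "input_proj"
    "attn.to_q.weight":           torch.randn(768, 768),      # balanced 2D       -> "linear"
    "attn.to_q.bias":             torch.randn(768),           # 1D                -> "bias"
}

for key, tensor in examples.items():
    info = get_tensor_info(tensor)  # defined earlier in app.py
    print(f"{key:28s} {tuple(tensor.shape)} -> {info['type']}")
```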
 
851
  demo.launch()