codemichaeld committed on
Commit 840bb85 · verified · 1 Parent(s): 9de10a2

Update app.py

Files changed (1)
  1. app.py +387 -52
app.py CHANGED
@@ -5,11 +5,14 @@ import shutil
  import re
  import json
  from pathlib import Path
- from huggingface_hub import HfApi, hf_hub_download
+ from huggingface_hub import HfApi, hf_hub_download, snapshot_download, list_repo_files
  from safetensors.torch import load_file, save_file
  import torch
  import torch.nn.functional as F
  import traceback
+ import glob
+ import time
+ from concurrent.futures import ThreadPoolExecutor, as_completed
  try:
      from modelscope.hub.file_download import model_file_download as ms_file_download
      from modelscope.hub.api import HubApi as ModelScopeApi
@@ -151,24 +154,145 @@ def matches_pattern(key, tensor_info, pattern):

      return True

- def convert_safetensors_to_fp8_with_recovery(safetensors_path, output_dir, fp8_format,
-                                              recovery_rules, progress=gr.Progress()):
-     """Convert model to FP8 with customizable per-tensor recovery strategies."""
-     progress(0.1, desc="Starting FP8 conversion with precision recovery...")
-     try:
-         def read_safetensors_metadata(path):
-             with open(path, 'rb') as f:
+ def load_model_files(model_paths, model_format="safetensors", progress_callback=None):
+     """
+     Load model weights from one or more files, supporting sharded safetensors and other formats.
+     """
+     state_dict = {}
+
+     if model_format == "safetensors":
+         # Handle sharded safetensors files
+         for i, path in enumerate(model_paths):
+             if progress_callback:
+                 progress_callback(f"Loading shard {i+1}/{len(model_paths)}: {os.path.basename(path)}")
+             part_dict = load_file(path)
+             state_dict.update(part_dict)
+     elif model_format in ["pth", "pt"]:
+         # PyTorch checkpoint files
+         for i, path in enumerate(model_paths):
+             if progress_callback:
+                 progress_callback(f"Loading checkpoint {i+1}/{len(model_paths)}: {os.path.basename(path)}")
+             checkpoint = torch.load(path, map_location="cpu")
+             if isinstance(checkpoint, dict):
+                 # Try to extract state dict from checkpoint
+                 if "state_dict" in checkpoint:
+                     state_dict.update(checkpoint["state_dict"])
+                 elif "model_state_dict" in checkpoint:
+                     state_dict.update(checkpoint["model_state_dict"])
+                 elif "model" in checkpoint and isinstance(checkpoint["model"], dict):
+                     state_dict.update(checkpoint["model"])
+                 else:
+                     # Assume the checkpoint itself is the state dict
+                     state_dict.update(checkpoint)
+     elif model_format == "ckpt":
+         # Checkpoint files (similar to pth)
+         for i, path in enumerate(model_paths):
+             if progress_callback:
+                 progress_callback(f"Loading checkpoint {i+1}/{len(model_paths)}: {os.path.basename(path)}")
+             checkpoint = torch.load(path, map_location="cpu")
+             if isinstance(checkpoint, dict):
+                 if "state_dict" in checkpoint:
+                     state_dict.update(checkpoint["state_dict"])
+                 elif "model_state_dict" in checkpoint:
+                     state_dict.update(checkpoint["model_state_dict"])
+                 elif "model" in checkpoint and isinstance(checkpoint["model"], dict):
+                     state_dict.update(checkpoint["model"])
+                 else:
+                     state_dict.update(checkpoint)
+
+     return state_dict
+
+ def read_model_metadata(model_paths, model_format="safetensors"):
+     """Read metadata from model files."""
+     metadata = {}
+
+     if model_format == "safetensors":
+         # Read metadata from the first safetensors file
+         if model_paths:
+             with open(model_paths[0], 'rb') as f:
                  header_size = int.from_bytes(f.read(8), 'little')
                  header_json = f.read(header_size).decode('utf-8')
                  header = json.loads(header_json)
-                 return header.get('__metadata__', {})
+                 metadata = header.get('__metadata__', {})
+     elif model_format in ["pth", "pt", "ckpt"]:
+         # Try to extract metadata from checkpoint files
+         if model_paths:
+             checkpoint = torch.load(model_paths[0], map_location="cpu")
+             if isinstance(checkpoint, dict):
+                 # Look for common metadata keys
+                 for key in ["hyperparameters", "args", "config", "metadata"]:
+                     if key in checkpoint:
+                         metadata[key] = checkpoint[key]
+
+     return metadata
+
+ def extract_base_name_from_sharded_files(model_paths):
+     """Extract a common base name from sharded files."""
+     if not model_paths:
+         return "model"
+
+     if len(model_paths) == 1:
+         # Single file case
+         base_name = os.path.splitext(os.path.basename(model_paths[0]))[0]
+         # Remove common suffixes
+         for suffix in ["-fp8", "-fp16", "-bf16", "-32", "-16"]:
+             if base_name.endswith(suffix):
+                 base_name = base_name[:-len(suffix)]
+         return base_name
+
+     # Multiple files case - find common prefix
+     base_names = [os.path.splitext(os.path.basename(p))[0] for p in model_paths]
+
+     # Handle Hugging Face pattern: model-00001-of-00002.safetensors
+     # Extract the part before the shard numbering
+     if all("-of-" in name for name in base_names):
+         # All files follow the "model-XXXXX-of-YYYYY" pattern
+         common_parts = []
+         for name in base_names:
+             # Split at the shard numbering
+             parts = name.split("-")
+             if len(parts) >= 3 and parts[-2].isdigit() and parts[-1].startswith("of"):
+                 # Remove the last two parts (shard number and total)
+                 common_part = "-".join(parts[:-2])
+                 common_parts.append(common_part)
+             else:
+                 common_parts.append(name)

-         metadata = read_safetensors_metadata(safetensors_path)
-         progress(0.2, desc="Loaded metadata.")
+         # Use the most common base name
+         from collections import Counter
+         base_name = Counter(common_parts).most_common(1)[0][0]
+         return base_name
+
+     # Fallback: find common prefix
+     common_prefix = ""
+     for chars in zip(*base_names):
+         if len(set(chars)) == 1:
+             common_prefix += chars[0]
+         else:
+             break
+
+     # Clean up the common prefix
+     base_name = re.sub(r'[-_]+$', '', common_prefix)
+     if not base_name:
+         base_name = "model"
+
+     return base_name
+
+ def convert_model_to_fp8_with_recovery(model_paths, output_dir, fp8_format, recovery_rules,
+                                        model_format="safetensors", progress=gr.Progress()):
+     """Convert model to FP8 with customizable per-tensor recovery strategies."""
+     progress(0.05, desc=f"Starting FP8 conversion with precision recovery for {model_format}...")
+     try:
+         metadata = read_model_metadata(model_paths, model_format)
+         progress(0.1, desc="Loaded metadata.")

-         # Load model
-         state_dict = load_file(safetensors_path)
-         progress(0.3, desc="Loaded model weights.")
+         # Load model with progress tracking
+         state_dict = load_model_files(
+             model_paths,
+             model_format,
+             progress_callback=lambda msg: progress(0.15, desc=msg)
+         )
+         progress(0.25, desc=f"Loaded {len(model_paths)} model files with {len(state_dict)} tensors.")

          # Setup FP8 format
          fp8_dtype = torch.float8_e5m2 if fp8_format == "e5m2" else torch.float8_e4m3fn
@@ -234,7 +358,7 @@ def convert_safetensors_to_fp8_with_recovery(safetensors_path, output_dir, fp8_f
                      stats["recovery_counts"]["diff"] += 1
                      stats["rule_matches"][rule_idx] += 1
                      recovery_applied = True
-                     break
+                     break

              # If method is "none" or recovery failed, continue to next rule
              if recovery_method == "none":
@@ -247,17 +371,19 @@ def convert_safetensors_to_fp8_with_recovery(safetensors_path, output_dir, fp8_f
              reason = "no matching rule" if matched_rule_index == -1 else f"recovery failed with rule {matched_rule_index}"
              stats["skipped_layers"].append(f"{key}: {reason}")

+         # Extract base name for output files
+         base_name = extract_base_name_from_sharded_files(model_paths)
+
          # Save FP8 model
-         base_name = os.path.splitext(os.path.basename(safetensors_path))[0]
          fp8_path = os.path.join(output_dir, f"{base_name}-fp8-{fp8_format}.safetensors")
-         save_file(sd_fp8, fp8_path, metadata={"format": "pt", "fp8_format": fp8_format, **metadata})
+         save_file(sd_fp8, fp8_path, metadata={"format": model_format, "fp8_format": fp8_format, **metadata})

          # Save recovery weights if any were generated
          recovery_path = None
          if recovery_weights:
              recovery_path = os.path.join(output_dir, f"{base_name}-recovery.safetensors")
              recovery_metadata = {
-                 "format": "pt",
+                 "format": model_format,
                  "fp8_format": fp8_format,
                  "recovery_rules": json.dumps(recovery_rules),
                  "stats": json.dumps(stats)
@@ -309,27 +435,176 @@ def parse_hf_url(url):
      subfolder = "/".join(parts[2:])
      return repo_id, subfolder

- def download_safetensors_file(source_type, repo_url, filename, hf_token=None, progress=gr.Progress()):
+ def download_single_file(args):
+     """Helper function for parallel downloads."""
+     repo_id, filename, subfolder, cache_dir, token = args
+     try:
+         path = hf_hub_download(
+             repo_id=repo_id,
+             filename=filename,
+             subfolder=subfolder,
+             cache_dir=cache_dir,
+             token=token,
+             resume_download=True
+         )
+         return path, None
+     except Exception as e:
+         return None, str(e)
+
+ def find_sharded_safetensors_files(repo_id, subfolder=None, hf_token=None, max_shards=50):
+     """Find all sharded safetensors files in a repository."""
+     try:
+         # List all files in the repository
+         repo_files = list_repo_files(repo_id, repo_type="model", token=hf_token)
+
+         # Filter for safetensors files in the subfolder
+         if subfolder:
+             pattern = f"{subfolder}/"
+             safetensors_files = [f for f in repo_files if f.endswith('.safetensors') and f.startswith(pattern)]
+             # Remove subfolder prefix
+             safetensors_files = [f[len(pattern):] for f in safetensors_files]
+         else:
+             safetensors_files = [f for f in repo_files if f.endswith('.safetensors')]
+
+         # Check if files follow sharding pattern
+         sharded_files = []
+         single_files = []
+
+         for f in safetensors_files:
+             if "-of-" in f:
+                 sharded_files.append(f)
+             else:
+                 single_files.append(f)
+
+         # Return sharded files if found, otherwise single files
+         if sharded_files:
+             # Sort by shard number for consistent ordering
+             sharded_files.sort(key=lambda x: int(re.search(r'-(\d+)-of-', x).group(1)))
+             # Limit number of shards to prevent accidental downloads of huge models
+             if len(sharded_files) > max_shards:
+                 raise ValueError(f"Too many shards found ({len(sharded_files)}). Maximum allowed is {max_shards}. "
+                                  f"Please specify a more specific pattern.")
+             return sharded_files
+         elif single_files:
+             return single_files
+         else:
+             return []
+
+     except Exception as e:
+         print(f"Error listing repository files: {e}")
+         return []
+
+ def download_model_files(source_type, repo_url, filename_pattern, model_format, hf_token=None, progress=gr.Progress()):
      temp_dir = tempfile.mkdtemp()
      try:
          if source_type == "huggingface":
              repo_id, subfolder = parse_hf_url(repo_url)
-             safetensors_path = hf_hub_download(
-                 repo_id=repo_id,
-                 filename=filename,
-                 subfolder=subfolder or None,
-                 cache_dir=temp_dir,
-                 token=hf_token,
-                 resume_download=True
-             )
+
+             if model_format == "safetensors":
+                 # Handle different patterns for safetensors
+                 if filename_pattern == "auto" or filename_pattern == "":
+                     # Auto-detect sharded files
+                     progress(0.1, desc="Discovering model files...")
+                     found_files = find_sharded_safetensors_files(repo_id, subfolder, hf_token)
+                     if not found_files:
+                         raise ValueError("No safetensors files found in repository")
+
+                     progress(0.2, desc=f"Found {len(found_files)} shard(s). Downloading...")
+
+                     # Download files in parallel for better performance
+                     model_paths = []
+                     download_args = [
+                         (repo_id, filename, subfolder, temp_dir, hf_token)
+                         for filename in found_files
+                     ]
+
+                     with ThreadPoolExecutor(max_workers=4) as executor:
+                         futures = {executor.submit(download_single_file, args): args[1] for args in download_args}
+
+                         for i, future in enumerate(as_completed(futures)):
+                             filename = futures[future]
+                             try:
+                                 path, error = future.result()
+                                 if error:
+                                     raise Exception(f"Failed to download {filename}: {error}")
+                                 model_paths.append(path)
+                                 progress(0.2 + 0.6 * (i + 1) / len(futures),
+                                          desc=f"Downloaded {i+1}/{len(futures)}: {filename}")
+                             except Exception as e:
+                                 raise e
+
+                     return model_paths, temp_dir
+
+                 elif "*" in filename_pattern:
+                     # For wildcard patterns, download the entire directory and filter
+                     progress(0.1, desc="Downloading repository snapshot...")
+                     local_dir = os.path.join(temp_dir, "download")
+                     snapshot_download(
+                         repo_id=repo_id,
+                         subfolder=subfolder or None,
+                         local_dir=local_dir,
+                         token=hf_token,
+                         resume_download=True
+                     )
+
+                     # Find files matching the pattern
+                     if subfolder:
+                         pattern_dir = os.path.join(local_dir, subfolder)
+                     else:
+                         pattern_dir = local_dir
+
+                     model_files = glob.glob(os.path.join(pattern_dir, filename_pattern))
+                     if not model_files:
+                         raise ValueError(f"No files found matching pattern: {filename_pattern}")
+
+                     # Limit number of files
+                     if len(model_files) > 50:
+                         raise ValueError(f"Too many files found ({len(model_files)}). Please use a more specific pattern.")
+
+                     return model_files, temp_dir
+                 else:
+                     # Single file
+                     progress(0.2, desc=f"Downloading {filename_pattern}...")
+                     model_path = hf_hub_download(
+                         repo_id=repo_id,
+                         filename=filename_pattern,
+                         subfolder=subfolder or None,
+                         cache_dir=temp_dir,
+                         token=hf_token,
+                         resume_download=True
+                     )
+                     return [model_path], temp_dir
+             else:
+                 # For non-safetensors formats
+                 if "*" in filename_pattern:
+                     raise ValueError("Wildcards only supported for safetensors format")
+                 progress(0.2, desc=f"Downloading {filename_pattern}...")
+                 model_path = hf_hub_download(
+                     repo_id=repo_id,
+                     filename=filename_pattern,
+                     subfolder=subfolder or None,
+                     cache_dir=temp_dir,
+                     token=hf_token,
+                     resume_download=True
+                 )
+                 return [model_path], temp_dir
+
          elif source_type == "modelscope":
              if not MODELScope_AVAILABLE:
                  raise ImportError("ModelScope not installed")
              repo_id = repo_url.strip()
-             safetensors_path = ms_file_download(model_id=repo_id, file_path=filename)
+
+             if model_format == "safetensors" and "*" in filename_pattern:
+                 # For ModelScope, we need to handle sharded files differently
+                 # This is a simplified approach - in a real implementation, you might need to list files first
+                 raise NotImplementedError("Pattern matching for ModelScope sharded files not fully implemented")
+             else:
+                 progress(0.2, desc=f"Downloading {filename_pattern}...")
+                 model_path = ms_file_download(model_id=repo_id, file_path=filename_pattern)
+                 return [model_path], temp_dir
          else:
              raise ValueError("Unknown source")
-         return safetensors_path, temp_dir
+
      except Exception as e:
          shutil.rmtree(temp_dir, ignore_errors=True)
          raise e
@@ -506,7 +781,8 @@ def generate_default_rules(architecture="auto"):
  def process_and_upload_fp8(
      source_type,
      repo_url,
-     safetensors_filename,
+     filename_pattern,
+     model_format,
      fp8_format,
      recovery_rules_json,
      target_type,
@@ -541,13 +817,13 @@ def process_and_upload_fp8(
      output_dir = tempfile.mkdtemp()
      try:
          progress(0.05, desc="Downloading model...")
-         safetensors_path, temp_dir = download_safetensors_file(
-             source_type, repo_url, safetensors_filename, hf_token, progress
+         model_paths, temp_dir = download_model_files(
+             source_type, repo_url, filename_pattern, model_format, hf_token, progress
          )

-         progress(0.2, desc="Converting to FP8 with precision recovery...")
-         success, msg, stats, fp8_path, recovery_path = convert_safetensors_to_fp8_with_recovery(
-             safetensors_path, output_dir, fp8_format, recovery_rules, progress
+         progress(0.8, desc="Converting to FP8 with precision recovery...")
+         success, msg, stats, fp8_path, recovery_path = convert_model_to_fp8_with_recovery(
+             model_paths, output_dir, fp8_format, recovery_rules, model_format, progress
          )

          if not success:
@@ -559,7 +835,14 @@
          )

          # Generate README
-         base_name = os.path.splitext(safetensors_filename)[0]
+         if len(model_paths) == 1:
+             original_filename = os.path.basename(model_paths[0])
+         else:
+             original_filename = f"{len(model_paths)} sharded files"
+             # Add the pattern if not auto
+             if filename_pattern != "auto":
+                 original_filename += f" matching '{filename_pattern}'"
+
          fp8_filename = os.path.basename(fp8_path)
          recovery_filename = os.path.basename(recovery_path) if recovery_path else ""

@@ -574,27 +857,23 @@ tags:
  ---
  # FP8 Model with Per-Tensor Precision Recovery
  - **Source**: `{repo_url}`
- - **Original File**: `{safetensors_filename}`
+ - **Original File(s)**: `{original_filename}`
+ - **Original Format**: `{model_format}`
  - **FP8 Format**: `{fp8_format.upper()}`
  - **FP8 File**: `{fp8_filename}`
  - **Recovery File**: `{recovery_filename if recovery_filename else "None"}`
-
  ## Recovery Rules Used
  ```json
  {json.dumps(recovery_rules, indent=2)}
  ```
-
  ## Usage (Inference)
  ```python
  from safetensors.torch import load_file
  import torch
-
  # Load FP8 model
  fp8_state = load_file("{fp8_filename}")
-
  # Load recovery weights if available
  recovery_state = load_file("{recovery_filename}") if "{recovery_filename}" and os.path.exists("{recovery_filename}") else {{}}
-
  # Reconstruct high-precision weights
  reconstructed = {{}}
  for key in fp8_state:
@@ -617,14 +896,11 @@ for key in fp8_state:
          fp8_weight = fp8_weight + diff

      reconstructed[key] = fp8_weight
-
  # Use reconstructed weights in your model
  model.load_state_dict(reconstructed)
  ```
-
  > **Note**: For best results, use the same recovery configuration during inference as was used during extraction.
  > Requires PyTorch ≥ 2.1 for FP8 support.
-
  ## Statistics
  - **Total layers**: {stats['total_layers']}
  - **Layers with recovery**: {stats['processed_layers']}
@@ -676,13 +952,24 @@ Includes:

  with gr.Blocks(title="Advanced FP8 Quantizer with Per-Tensor Precision Recovery") as demo:
      gr.Markdown("# 🔄 Advanced FP8 Quantizer with Per-Tensor Precision Recovery")
-     gr.Markdown("Convert `.safetensors` → **FP8** + **customizable precision recovery**. Full control over LoRA and difference methods per tensor pattern.")
+     gr.Markdown("Convert model files (safetensors, pth, ckpt) → **FP8** + **customizable precision recovery**. Supports any number of sharded files.")

      with gr.Row():
          with gr.Column():
              source_type = gr.Radio(["huggingface", "modelscope"], value="huggingface", label="Source")
              repo_url = gr.Textbox(label="Repo URL or ID", placeholder="https://huggingface.co/... or modelscope-id")
-             safetensors_filename = gr.Textbox(label="Filename", placeholder="model.safetensors")
+
+             with gr.Row():
+                 model_format = gr.Dropdown(
+                     choices=["safetensors", "pth", "pt", "ckpt"],
+                     value="safetensors",
+                     label="Model Format"
+                 )
+                 filename_pattern = gr.Textbox(
+                     label="Filename or Pattern",
+                     placeholder="auto (detects sharded files) or model-*.safetensors",
+                     value="auto"
+                 )

              with gr.Accordion("FP8 Settings", open=True):
                  fp8_format = gr.Radio(["e4m3fn", "e5m2"], value="e5m2", label="FP8 Format")
@@ -772,7 +1059,8 @@ with gr.Blocks(title="Advanced FP8 Quantizer with Per-Tensor Precision Recovery"
          inputs=[
              source_type,
              repo_url,
-             safetensors_filename,
+             filename_pattern,
+             model_format,
              fp8_format,
              recovery_rules_json,
              target_type,
@@ -790,7 +1078,8 @@ with gr.Blocks(title="Advanced FP8 Quantizer with Per-Tensor Precision Recovery"
              [
                  "huggingface",
                  "https://huggingface.co/stabilityai/sdxl-vae",
-                 "diffusion_pytorch_model.safetensors",
+                 "auto",
+                 "safetensors",
                  "e4m3fn",
                  generate_default_rules("vae"),
                  "huggingface"
@@ -798,7 +1087,8 @@ with gr.Blocks(title="Advanced FP8 Quantizer with Per-Tensor Precision Recovery"
              [
                  "huggingface",
                  "https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/text_encoder",
-                 "model.safetensors",
+                 "auto",
+                 "safetensors",
                  "e5m2",
                  generate_default_rules("text_encoder"),
                  "huggingface"
@@ -806,13 +1096,32 @@ with gr.Blocks(title="Advanced FP8 Quantizer with Per-Tensor Precision Recovery"
              [
                  "huggingface",
                  "https://huggingface.co/Yabo/FramePainter/tree/main",
-                 "unet_diffusion_pytorch_model.safetensors",
+                 "auto",
+                 "safetensors",
                  "e5m2",
                  generate_default_rules("unet_transformer"),
                  "huggingface"
+             ],
+             [
+                 "huggingface",
+                 "https://huggingface.co/stabilityai/stable-diffusion-2-1",
+                 "model-*.safetensors",
+                 "safetensors",
+                 "e5m2",
+                 generate_default_rules("all"),
+                 "huggingface"
+             ],
+             [
+                 "huggingface",
+                 "https://huggingface.co/CompVis/stable-diffusion-v1-4",
+                 "sd-v1-4.ckpt",
+                 "ckpt",
+                 "e5m2",
+                 generate_default_rules("all"),
+                 "huggingface"
              ]
          ],
-         inputs=[source_type, repo_url, safetensors_filename, fp8_format, recovery_rules_json, target_type],
+         inputs=[source_type, repo_url, filename_pattern, model_format, fp8_format, recovery_rules_json, target_type],
          label="Example Conversions",
          cache_examples=False
      )
@@ -846,6 +1155,32 @@ with gr.Blocks(title="Advanced FP8 Quantizer with Per-Tensor Precision Recovery"
      - Always include a catch-all rule at the end

      > **Pro Tip for VAE**: Use `"dim": 4` combined with `"key_pattern": "vae"` to reliably target VAE convolutional layers with difference recovery.
+
+     ## 📁 File Format Support
+
+     This tool supports multiple model formats:
+
+     - **Safetensors**: Modern, secure format for storing tensors. Supports sharded files (e.g., `model-00001-of-00005.safetensors`).
+     - **PTH/PT**: PyTorch checkpoint files. Can contain state dicts or full model objects.
+     - **CKPT**: Checkpoint files, commonly used for stable diffusion models.
+
+     ### Shard Support:
+     - **Unlimited Shards**: Supports any number of sharded files (2, 5, 10, 20+)
+     - **Auto-Detection**: Automatically finds all shards when using "auto" pattern
+     - **Parallel Downloads**: Downloads multiple shards simultaneously for faster processing
+     - **Memory Efficient**: Processes shards one at a time to manage memory usage
+     - **Progress Tracking**: Shows detailed progress for each shard download and processing
+
+     ### Filename Patterns:
+     - **Auto-detection**: Use "auto" to automatically find all sharded safetensors files
+     - **Wildcard patterns**: Use `model-*.safetensors` to match sharded files
+     - **Specific file**: Use exact filename for single files
+
+     For models with many shards (e.g., 5+ files), the tool will:
+     1. Automatically detect all shards
+     2. Download them in parallel (up to 4 simultaneous downloads)
+     3. Load them sequentially to manage memory
+     4. Merge them into a single FP8 model
      """)

  demo.launch()
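
Two short sketches may help readers exercise what this commit adds without launching the Space. First, the shard-discovery flow: the snippet below mirrors `find_sharded_safetensors_files` and the merge loop in `load_model_files`, but downloads sequentially rather than through the `ThreadPoolExecutor`; the repo id is a hypothetical placeholder, not a repository this Space references.

```python
import re
from huggingface_hub import list_repo_files, hf_hub_download
from safetensors.torch import load_file

repo_id = "some-org/some-sharded-model"  # hypothetical placeholder

# Prefer "-of-" shards, sorted by shard index, as find_sharded_safetensors_files does;
# fall back to whatever single safetensors files the repo holds.
files = [f for f in list_repo_files(repo_id) if f.endswith(".safetensors")]
shards = sorted(
    (f for f in files if "-of-" in f),
    key=lambda f: int(re.search(r"-(\d+)-of-", f).group(1)),
) or files

# Download and merge one shard at a time, as load_model_files does for safetensors.
state_dict = {}
for shard in shards:
    state_dict.update(load_file(hf_hub_download(repo_id=repo_id, filename=shard)))

print(f"Merged {len(state_dict)} tensors from {len(shards)} file(s)")
```

Second, the FP8 "diff" recovery round-trip that the conversion and the generated README's reconstruction loop rely on, reduced to a single tensor (requires PyTorch ≥ 2.1; the app's LoRA-style recovery method is not shown here):

```python
import torch

w = torch.randn(4, 4)                      # stand-in for one model tensor
w_fp8 = w.to(torch.float8_e4m3fn)          # quantize (torch.float8_e5m2 also works)
diff = w - w_fp8.to(torch.float32)         # what a "diff" recovery rule stores
restored = w_fp8.to(torch.float32) + diff  # reconstruction, as in the README snippet
assert torch.allclose(restored, w)
```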