import gradio as gr
import os
import tempfile
import shutil
import re
import json
from huggingface_hub import HfApi, hf_hub_download, snapshot_download, list_repo_files
from safetensors.torch import load_file, save_file
import torch
import traceback
import glob
from concurrent.futures import ThreadPoolExecutor, as_completed

try:
    from modelscope.hub.file_download import model_file_download as ms_file_download
    from modelscope.hub.api import HubApi as ModelScopeApi
    MODELSCOPE_AVAILABLE = True
except ImportError:
    MODELSCOPE_AVAILABLE = False
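
# ModelScope support is optional: installing the "modelscope" package enables
# downloading from and uploading to modelscope.cn; otherwise those options are hidden.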


def load_model_files(model_paths, model_format="safetensors", progress_callback=None):
    """
    Load model weights from one or more files, supporting sharded safetensors
    files as well as PyTorch checkpoints (pth/pt/ckpt).
    """
    state_dict = {}
    if model_format == "safetensors":
        # Single or sharded safetensors files
        for i, path in enumerate(model_paths):
            if progress_callback:
                progress_callback(f"Loading shard {i+1}/{len(model_paths)}: {os.path.basename(path)}")
            part_dict = load_file(path)
            state_dict.update(part_dict)
    elif model_format in ["pth", "pt", "ckpt"]:
        # PyTorch checkpoint files
        for i, path in enumerate(model_paths):
            if progress_callback:
                progress_callback(f"Loading checkpoint {i+1}/{len(model_paths)}: {os.path.basename(path)}")
            checkpoint = torch.load(path, map_location="cpu")
            if isinstance(checkpoint, dict):
                # Try to extract the state dict from the checkpoint
                if "state_dict" in checkpoint:
                    state_dict.update(checkpoint["state_dict"])
                elif "model_state_dict" in checkpoint:
                    state_dict.update(checkpoint["model_state_dict"])
                elif "model" in checkpoint and isinstance(checkpoint["model"], dict):
                    state_dict.update(checkpoint["model"])
                else:
                    # Assume the checkpoint itself is the state dict
                    state_dict.update(checkpoint)
    return state_dict
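
# Example usage (hypothetical local paths, for illustration only):
#   shards = sorted(glob.glob("checkpoints/model-*-of-*.safetensors"))
#   state_dict = load_model_files(shards, "safetensors", progress_callback=print)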


def read_model_metadata(model_paths, model_format="safetensors"):
    """Read metadata from model files."""
    metadata = {}
    if model_format == "safetensors":
        # Read the metadata block from the header of the first safetensors file
        if model_paths:
            with open(model_paths[0], 'rb') as f:
                header_size = int.from_bytes(f.read(8), 'little')
                header_json = f.read(header_size).decode('utf-8')
                header = json.loads(header_json)
                metadata = header.get('__metadata__', {})
    elif model_format in ["pth", "pt", "ckpt"]:
        # Try to extract metadata from checkpoint files
        if model_paths:
            checkpoint = torch.load(model_paths[0], map_location="cpu")
            if isinstance(checkpoint, dict):
                # Look for common metadata keys
                for key in ["hyperparameters", "args", "config", "metadata"]:
                    if key in checkpoint:
                        metadata[key] = checkpoint[key]
    return metadata
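
# Safetensors layout read above: the file begins with an 8-byte little-endian
# integer giving the size of a JSON header; user metadata, if present, lives
# under the "__metadata__" key of that header as string-to-string pairs.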


def extract_base_name_from_sharded_files(model_paths):
    """Extract a common base name from sharded files."""
    if not model_paths:
        return "model"
    if len(model_paths) == 1:
        # Single file: strip the extension and common precision suffixes
        base_name = os.path.splitext(os.path.basename(model_paths[0]))[0]
        for suffix in ["-fp8", "-fp16", "-bf16", "-32", "-16"]:
            if base_name.endswith(suffix):
                base_name = base_name[:-len(suffix)]
        return base_name
    # Multiple files: handle the Hugging Face sharding pattern
    # model-00001-of-00002.safetensors by stripping the shard numbering
    base_names = [os.path.splitext(os.path.basename(p))[0] for p in model_paths]
    if all("-of-" in name for name in base_names):
        common_parts = []
        for name in base_names:
            parts = name.split("-")
            # "model-00001-of-00002" -> drop the trailing "00001", "of", "00002"
            if len(parts) >= 4 and parts[-3].isdigit() and parts[-2] == "of" and parts[-1].isdigit():
                common_parts.append("-".join(parts[:-3]))
            else:
                common_parts.append(name)
        # Use the most common base name
        from collections import Counter
        base_name = Counter(common_parts).most_common(1)[0][0]
        return base_name
    # Fallback: longest common prefix of all base names
    common_prefix = ""
    for chars in zip(*base_names):
        if len(set(chars)) == 1:
            common_prefix += chars[0]
        else:
            break
    base_name = re.sub(r'[-_]+$', '', common_prefix)
    if not base_name:
        base_name = "model"
    return base_name
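
# Examples of the mapping above (illustrative names):
#   ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"] -> "model"
#   ["flux1-dev-fp16.safetensors"]                                           -> "flux1-dev"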


def convert_model_to_fp8(model_paths, output_dir, fp8_format,
                         model_format="safetensors", progress=gr.Progress()):
    """Simple and fast FP8 conversion without recovery strategies."""
    progress(0.05, desc=f"Starting FP8 conversion for {model_format}...")
    try:
        metadata = read_model_metadata(model_paths, model_format)
        progress(0.1, desc="Loaded metadata.")
        # Load model weights with progress tracking
        state_dict = load_model_files(
            model_paths,
            model_format,
            progress_callback=lambda msg: progress(0.15, desc=msg)
        )
        progress(0.25, desc=f"Loaded {len(model_paths)} model files with {len(state_dict)} tensors.")
        # Select the FP8 dtype
        fp8_dtype = torch.float8_e5m2 if fp8_format == "e5m2" else torch.float8_e4m3fn
        # Initialize outputs
        sd_fp8 = {}
        conversion_stats = {
            "total_tensors": len(state_dict),
            "converted_tensors": 0,
            "skipped_tensors": 0,
            "skipped_reasons": []
        }
        # Process each tensor
        total = len(state_dict)
        for i, key in enumerate(state_dict):
            if i % 100 == 0:  # Update progress every 100 tensors for speed
                progress(0.3 + 0.6 * (i / total), desc=f"Converting {i}/{total} tensors...")
            weight = state_dict[key]
            # Convert only float tensors to FP8
            if weight.dtype in [torch.float16, torch.float32, torch.bfloat16]:
                sd_fp8[key] = weight.to(fp8_dtype)
                conversion_stats["converted_tensors"] += 1
            else:
                # Keep non-float tensors as-is (e.g., ints, bools)
                sd_fp8[key] = weight
                conversion_stats["skipped_tensors"] += 1
                conversion_stats["skipped_reasons"].append(f"{key}: {weight.dtype}")
        # Extract a base name for the output file
        base_name = extract_base_name_from_sharded_files(model_paths)
        # Save the FP8 model
        fp8_path = os.path.join(output_dir, f"{base_name}-fp8-{fp8_format}.safetensors")
        save_file(sd_fp8, fp8_path, metadata={
            "format": model_format,
            "fp8_format": fp8_format,
            "original_files": str(len(model_paths)),
            "conversion_stats": json.dumps(conversion_stats),
            # safetensors metadata values must be strings
            **{k: str(v) for k, v in metadata.items()}
        })
        progress(0.95, desc="Saved FP8 file.")
        # Generate a summary message
        stats_msg = f"✅ FP8 ({fp8_format}) conversion complete!\n"
        stats_msg += f"- Total tensors: {conversion_stats['total_tensors']}\n"
        stats_msg += f"- Converted to FP8: {conversion_stats['converted_tensors']}\n"
        stats_msg += f"- Skipped (non-float): {conversion_stats['skipped_tensors']}\n"
        stats_msg += f"- Output file: {os.path.basename(fp8_path)}\n"
        if conversion_stats["skipped_tensors"] > 0:
            stats_msg += "\n⚠️ Some tensors were skipped (non-float types):\n"
            for reason in conversion_stats["skipped_reasons"][:5]:  # Show the first 5
                stats_msg += f"  - {reason}\n"
            if len(conversion_stats["skipped_reasons"]) > 5:
                stats_msg += f"  - ... and {len(conversion_stats['skipped_reasons']) - 5} more\n"
        progress(1.0, desc="✅ FP8 conversion complete!")
        return True, stats_msg, conversion_stats, fp8_path, None
    except Exception as e:
        traceback.print_exc()
        return False, str(e), None, None, None
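
# Rough magnitude of FP8 quantization error (standalone sketch, not used by the app):
#   w = torch.randn(1024)
#   err = (w - w.to(torch.float8_e4m3fn).to(torch.float32)).abs().max()
#   # e4m3fn keeps ~3 mantissa bits, so the relative error is typically a few percent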


def parse_hf_url(url):
    url = url.strip().rstrip("/")
    if not url.startswith("https://huggingface.co/"):
        raise ValueError("URL must start with https://huggingface.co/")
    path = url.replace("https://huggingface.co/", "")
    parts = path.split("/")
    if len(parts) < 2:
        raise ValueError("Invalid repo format")
    repo_id = "/".join(parts[:2])
    subfolder = ""
    if len(parts) > 3 and parts[2] == "tree":
        # e.g. https://huggingface.co/user/repo/tree/main/subfolder
        subfolder = "/".join(parts[4:]) if len(parts) > 4 else ""
    elif len(parts) > 2:
        subfolder = "/".join(parts[2:])
    return repo_id, subfolder
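
# Examples of the parsing above:
#   https://huggingface.co/stabilityai/sdxl-vae              -> ("stabilityai/sdxl-vae", "")
#   https://huggingface.co/user/repo/tree/main/text_encoder  -> ("user/repo", "text_encoder")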


def download_single_file(args):
    """Helper function for parallel downloads."""
    repo_id, filename, subfolder, cache_dir, token = args
    try:
        path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            subfolder=subfolder,
            cache_dir=cache_dir,
            token=token,
            resume_download=True
        )
        return path, None
    except Exception as e:
        return None, str(e)


def find_sharded_safetensors_files(repo_id, subfolder=None, hf_token=None, max_shards=50):
    """Find all sharded safetensors files in a repository."""
    try:
        # List all files in the repository
        repo_files = list_repo_files(repo_id, repo_type="model", token=hf_token)
        # Keep only safetensors files, restricted to the subfolder if one was given
        if subfolder:
            pattern = f"{subfolder}/" if not subfolder.endswith("/") else subfolder
            safetensors_files = [f for f in repo_files if f.endswith('.safetensors') and f.startswith(pattern)]
            # Strip the subfolder prefix
            safetensors_files = [f[len(pattern):] for f in safetensors_files if len(f) > len(pattern)]
        else:
            safetensors_files = [f for f in repo_files if f.endswith('.safetensors')]
        # Separate sharded files (model-XXXXX-of-YYYYY.safetensors) from single files
        sharded_files = []
        single_files = []
        for f in safetensors_files:
            if re.search(r'-\d{5}-of-\d{5}\.safetensors$', f):
                sharded_files.append(f)
            else:
                single_files.append(f)
        # If sharded files exist, return them sorted by shard number
        if sharded_files:
            def extract_shard_num(filename):
                match = re.search(r'-(\d{5})-of-\d{5}\.safetensors$', filename)
                return int(match.group(1)) if match else 0
            sharded_files.sort(key=extract_shard_num)
            # Limit the number of shards to prevent accidental downloads of huge models
            if len(sharded_files) > max_shards:
                raise ValueError(f"Too many shards found ({len(sharded_files)}). Maximum allowed is {max_shards}. "
                                 f"Please specify a more specific pattern.")
            return sharded_files
        elif single_files:
            # No shards found; return the single (non-sharded) files
            return single_files
        else:
            return []
    except ValueError:
        # Propagate the shard-limit error instead of silently returning nothing
        raise
    except Exception as e:
        print(f"Error listing repository files: {e}")
        return []
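
# The shard pattern matched above follows the Hugging Face naming convention, e.g.:
#   diffusion_pytorch_model-00001-of-00003.safetensors  (sharded, sorted by shard index)
#   diffusion_pytorch_model.safetensors                  (not sharded, returned as a single file)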


def download_model_files(source_type, repo_url, filename_pattern, model_format, hf_token=None, progress=gr.Progress()):
    temp_dir = tempfile.mkdtemp()
    try:
        if source_type == "huggingface":
            repo_id, subfolder = parse_hf_url(repo_url)
            if model_format == "safetensors":
                if filename_pattern in ("auto", ""):
                    # Auto-detect sharded files
                    progress(0.1, desc="Discovering model files...")
                    found_files = find_sharded_safetensors_files(repo_id, subfolder, hf_token)
                    if not found_files:
                        raise ValueError("No safetensors files found in repository")
                    progress(0.2, desc=f"Found {len(found_files)} shard(s). Downloading...")
                    # Download shards in parallel for better throughput
                    model_paths = []
                    download_args = [
                        (repo_id, filename, subfolder, temp_dir, hf_token)
                        for filename in found_files
                    ]
                    with ThreadPoolExecutor(max_workers=4) as executor:
                        futures = {executor.submit(download_single_file, args): args[1] for args in download_args}
                        for i, future in enumerate(as_completed(futures)):
                            filename = futures[future]
                            path, error = future.result()
                            if error:
                                raise Exception(f"Failed to download {filename}: {error}")
                            model_paths.append(path)
                            progress(0.2 + 0.6 * (i + 1) / len(futures),
                                     desc=f"Downloaded {i+1}/{len(futures)}: {filename}")
                    return model_paths, temp_dir
                elif "*" in filename_pattern:
                    # Wildcard pattern: download matching files from a repository snapshot
                    progress(0.1, desc="Downloading repository snapshot...")
                    local_dir = os.path.join(temp_dir, "download")
                    snapshot_download(
                        repo_id=repo_id,
                        allow_patterns=[f"{subfolder}/{filename_pattern}"] if subfolder else [filename_pattern],
                        local_dir=local_dir,
                        token=hf_token,
                        resume_download=True
                    )
                    # Find files matching the pattern
                    pattern_dir = os.path.join(local_dir, subfolder) if subfolder else local_dir
                    model_files = glob.glob(os.path.join(pattern_dir, filename_pattern))
                    if not model_files:
                        raise ValueError(f"No files found matching pattern: {filename_pattern}")
                    # Limit the number of files
                    if len(model_files) > 50:
                        raise ValueError(f"Too many files found ({len(model_files)}). Please use a more specific pattern.")
                    return model_files, temp_dir
                else:
                    # Single safetensors file, independent of shard discovery
                    progress(0.2, desc=f"Downloading {filename_pattern}...")
                    model_path = hf_hub_download(
                        repo_id=repo_id,
                        filename=filename_pattern,
                        subfolder=subfolder or None,
                        cache_dir=temp_dir,
                        token=hf_token,
                        resume_download=True
                    )
                    return [model_path], temp_dir
            else:
                # Non-safetensors formats (pth/pt/ckpt): single file only
                if "*" in filename_pattern:
                    raise ValueError("Wildcards only supported for safetensors format")
                progress(0.2, desc=f"Downloading {filename_pattern}...")
                model_path = hf_hub_download(
                    repo_id=repo_id,
                    filename=filename_pattern,
                    subfolder=subfolder or None,
                    cache_dir=temp_dir,
                    token=hf_token,
                    resume_download=True
                )
                return [model_path], temp_dir
        elif source_type == "modelscope":
            if not MODELSCOPE_AVAILABLE:
                raise ImportError("ModelScope not installed")
            repo_id = repo_url.strip()
            if model_format == "safetensors" and "*" in filename_pattern:
                # Listing and pattern-matching sharded files on ModelScope is not implemented here
                raise NotImplementedError("Pattern matching for ModelScope sharded files not fully implemented")
            else:
                progress(0.2, desc=f"Downloading {filename_pattern}...")
                model_path = ms_file_download(model_id=repo_id, file_path=filename_pattern)
                return [model_path], temp_dir
        else:
            raise ValueError("Unknown source")
    except Exception:
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
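
# Illustrative call (values are examples; requires network access and, for gated or
# private repos, a valid token):
#   paths, tmp_dir = download_model_files(
#       "huggingface", "https://huggingface.co/stabilityai/sdxl-vae",
#       "auto", "safetensors", hf_token=None)
#   # ...use paths..., then shutil.rmtree(tmp_dir, ignore_errors=True)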


def upload_to_target(target_type, new_repo_id, output_dir, fp8_format, hf_token=None, modelscope_token=None, private_repo=False):
    if target_type == "huggingface":
        api = HfApi(token=hf_token)
        api.create_repo(repo_id=new_repo_id, private=private_repo, repo_type="model", exist_ok=True)
        api.upload_folder(repo_id=new_repo_id, folder_path=output_dir, repo_type="model", token=hf_token)
        return f"https://huggingface.co/{new_repo_id}"
    elif target_type == "modelscope":
        api = ModelScopeApi()
        if modelscope_token:
            api.login(modelscope_token)
        api.push_model(model_id=new_repo_id, model_dir=output_dir)
        return f"https://modelscope.cn/models/{new_repo_id}"
    else:
        raise ValueError("Unknown target")


def process_and_upload_fp8(
    source_type,
    repo_url,
    filename_pattern,
    model_format,
    fp8_format,
    target_type,
    new_repo_id,
    hf_token,
    modelscope_token,
    private_repo,
    progress=gr.Progress()
):
    if not re.match(r"^[a-zA-Z0-9._-]+/[a-zA-Z0-9._-]+$", new_repo_id):
        return None, "❌ Invalid repo ID format. Use 'username/model-name'.", "", ""
    if source_type == "huggingface" and not hf_token:
        return None, "❌ Hugging Face token required for source.", "", ""
    if target_type == "huggingface" and not hf_token:
        return None, "❌ Hugging Face token required for target.", "", ""
    temp_dir = None
    output_dir = tempfile.mkdtemp()
    try:
        progress(0.05, desc="Downloading model...")
        model_paths, temp_dir = download_model_files(
            source_type, repo_url, filename_pattern, model_format, hf_token, progress
        )
        progress(0.8, desc="Converting to FP8...")
        success, msg, stats, fp8_path, _ = convert_model_to_fp8(
            model_paths, output_dir, fp8_format, model_format, progress
        )
        if not success:
            return None, f"❌ Conversion failed: {msg}", "", ""
        progress(0.9, desc="Uploading...")
        repo_url_final = upload_to_target(
            target_type, new_repo_id, output_dir, fp8_format, hf_token, modelscope_token, private_repo
        )
        # Generate a README for the target repository
        if len(model_paths) == 1:
            original_filename = os.path.basename(model_paths[0])
        else:
            original_filename = f"{len(model_paths)} sharded files"
            # Mention the pattern unless auto-discovery was used
            if filename_pattern != "auto":
                original_filename += f" matching '{filename_pattern}'"
        fp8_filename = os.path.basename(fp8_path)
        readme = f"""---
library_name: diffusers
tags:
- fp8
- safetensors
- converted-by-gradio
---

# FP8 Model Conversion

- **Source**: `{repo_url}`
- **Original File(s)**: `{original_filename}`
- **Original Format**: `{model_format}`
- **FP8 Format**: `{fp8_format.upper()}`
- **FP8 File**: `{fp8_filename}`

## Usage

```python
from safetensors.torch import load_file
import torch

# Load the FP8 state dict; tensors keep their float8 dtype
fp8_state = load_file("{fp8_filename}")

# Copying into a higher-precision model casts the weights to the model's
# parameter dtype; cast explicitly with .to() if you need another precision
model.load_state_dict(fp8_state)
```

> **Note**: Requires PyTorch ≥ 2.1 for float8 dtype support.

## Statistics

- **Total tensors**: {stats['total_tensors']}
- **Converted to FP8**: {stats['converted_tensors']}
- **Skipped (non-float)**: {stats['skipped_tensors']}
"""
        with open(os.path.join(output_dir, "README.md"), "w") as f:
            f.write(readme)
        if target_type == "huggingface":
            HfApi(token=hf_token).upload_file(
                path_or_fileobj=os.path.join(output_dir, "README.md"),
                path_in_repo="README.md",
                repo_id=new_repo_id,
                repo_type="model",
                token=hf_token
            )
        progress(1.0, desc="✅ Done!")
        # Generate the result HTML shown in the UI
        result_html = f"""
        ✅ Success!
        Model uploaded to: <a href="{repo_url_final}" target="_blank">{new_repo_id}</a>
        - FP8 model: `{fp8_filename}`
        - Converted {stats['converted_tensors']} tensors to {fp8_format.upper()}
        """
        return (gr.HTML(result_html),
                "✅ FP8 conversion successful!",
                msg,
                "")
    except Exception as e:
        traceback.print_exc()
        return None, f"❌ Error: {str(e)}", "", ""
    finally:
        if temp_dir:
            shutil.rmtree(temp_dir, ignore_errors=True)
        shutil.rmtree(output_dir, ignore_errors=True)


with gr.Blocks(title="Fast FP8 Model Converter") as demo:
    gr.Markdown("# ⚡ Fast FP8 Model Converter")
    gr.Markdown("Convert model files (safetensors, pth, ckpt) → **FP8**. Supports sharded files with auto-discovery. Simple and fast!")
    with gr.Row():
        with gr.Column():
            source_type = gr.Radio(["huggingface", "modelscope"], value="huggingface", label="Source")
            repo_url = gr.Textbox(label="Repo URL or ID", placeholder="https://huggingface.co/... or modelscope-id")
            with gr.Row():
                model_format = gr.Dropdown(
                    choices=["safetensors", "pth", "pt", "ckpt"],
                    value="safetensors",
                    label="Model Format"
                )
                filename_pattern = gr.Textbox(
                    label="Filename or Pattern",
                    placeholder="auto (detects sharded files) or model-*.safetensors",
                    value="auto"
                )
            with gr.Accordion("FP8 Settings", open=True):
                fp8_format = gr.Radio(["e4m3fn", "e5m2"], value="e5m2", label="FP8 Format")
            with gr.Accordion("Authentication", open=False):
                hf_token = gr.Textbox(label="Hugging Face Token", type="password")
                modelscope_token = gr.Textbox(label="ModelScope Token (optional)", type="password", visible=MODELSCOPE_AVAILABLE)
        with gr.Column():
            target_type = gr.Radio(["huggingface", "modelscope"], value="huggingface", label="Target")
            new_repo_id = gr.Textbox(label="New Repo ID", placeholder="user/model-fp8")
            private_repo = gr.Checkbox(label="Private Repository (HF only)", value=False)
            status_output = gr.Markdown()
            detailed_log = gr.Textbox(label="Processing Log", interactive=False, lines=10)
            recovery_summary = gr.Textbox(label="Additional Info", interactive=False, lines=3)
    convert_btn = gr.Button("🚀 Convert & Upload", variant="primary")
    repo_link_output = gr.HTML()
    convert_btn.click(
        fn=process_and_upload_fp8,
        inputs=[
            source_type,
            repo_url,
            filename_pattern,
            model_format,
            fp8_format,
            target_type,
            new_repo_id,
            hf_token,
            modelscope_token,
            private_repo
        ],
        outputs=[repo_link_output, status_output, detailed_log, recovery_summary],
        show_progress=True
    )
    gr.Examples(
        examples=[
            [
                "huggingface",
                "https://huggingface.co/stabilityai/sdxl-vae",
                "auto",
                "safetensors",
                "e4m3fn",
                "huggingface"
            ],
            [
                "huggingface",
                "https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/text_encoder",
                "auto",
                "safetensors",
                "e5m2",
                "huggingface"
            ],
            [
                "huggingface",
                "https://huggingface.co/Yabo/FramePainter/tree/main",
                "auto",
                "safetensors",
                "e5m2",
                "huggingface"
            ],
            [
                "huggingface",
                "https://huggingface.co/stabilityai/stable-diffusion-2-1",
                "model-*.safetensors",
                "safetensors",
                "e5m2",
                "huggingface"
            ],
            [
                "huggingface",
                "https://huggingface.co/CompVis/stable-diffusion-v1-4",
                "sd-v1-4.ckpt",
                "ckpt",
                "e5m2",
                "huggingface"
            ]
        ],
        inputs=[source_type, repo_url, filename_pattern, model_format, fp8_format, target_type],
        label="Example Conversions",
        cache_examples=False
    )
| gr.Markdown(""" | |
| ## π Fast FP8 Conversion Tool | |
| This tool provides **fast and simple FP8 conversion** for various model formats: | |
| ### **Supported Formats:** | |
| - **Safetensors**: Modern, secure format. Supports sharded files (e.g., `model-00001-of-00005.safetensors`) | |
| - **PTH/PT**: PyTorch checkpoint files | |
| - **CKPT**: Checkpoint files (commonly used for stable diffusion models) | |
| ### **Shard Support:** | |
| - **Unlimited Shards**: Supports any number of sharded files (2, 5, 10, 20+) | |
| - **Auto-Detection**: Automatically finds all shards when using "auto" pattern | |
| - **Parallel Downloads**: Downloads multiple shards simultaneously (up to 4 at once) | |
| - **Memory Efficient**: Processes files efficiently to manage memory | |
| ### **Performance Features:** | |
| - **Fast Conversion**: Simple dtype conversion without complex recovery strategies | |
| - **Batch Processing**: Processes tensors in batches for better performance | |
| - **Progress Tracking**: Shows detailed progress for each step | |
| ### **How It Works:** | |
| 1. **Discovery**: Automatically detects sharded files or uses your specified pattern | |
| 2. **Download**: Downloads files in parallel for maximum speed | |
| 3. **Conversion**: Converts float tensors to FP8, leaves other types unchanged | |
| 4. **Upload**: Uploads the converted model to your target repository | |
| ### **Usage Tips:** | |
| - Use "auto" pattern to automatically detect all sharded safetensors files | |
| - Use `model-*.safetensors` to match specific shard patterns | |
| - For single files, just enter the filename (e.g., `model.safetensors`) | |
| - FP8 conversion reduces model size by ~4x compared to FP32 | |
| - FP8 tensors are automatically converted to float32 when loaded in PyTorch | |
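
A minimal loading sketch (the filename below is illustrative; use the name of your converted file):

```python
from safetensors.torch import load_file
import torch

state = load_file("model-fp8-e5m2.safetensors")  # tensors load with their float8 dtype
# Cast back to a higher precision before using them with standard layers
state = {k: v.to(torch.float16) if v.dtype in (torch.float8_e4m3fn, torch.float8_e5m2) else v
         for k, v in state.items()}
```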

> **Note**: This is a simple conversion tool. For precision recovery options, use the advanced version.
""")

demo.launch()