import gradio as gr
import os
import tempfile
import shutil
import re
import json
from pathlib import Path
from huggingface_hub import HfApi, hf_hub_download, snapshot_download, list_repo_files
from safetensors.torch import load_file, save_file
import torch
import torch.nn.functional as F
import traceback
import glob
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

try:
    from modelscope.hub.file_download import model_file_download as ms_file_download
    from modelscope.hub.api import HubApi as ModelScopeApi
    MODELScope_AVAILABLE = True
except ImportError:
    MODELScope_AVAILABLE = False


def load_model_files(model_paths, model_format="safetensors", progress_callback=None):
    """
    Load model weights from one or more files, supporting sharded safetensors and other formats.
    """
    state_dict = {}

    if model_format == "safetensors":
        # Handle (possibly sharded) safetensors files
        for i, path in enumerate(model_paths):
            if progress_callback:
                progress_callback(f"Loading shard {i+1}/{len(model_paths)}: {os.path.basename(path)}")
            part_dict = load_file(path)
            state_dict.update(part_dict)

    elif model_format in ["pth", "pt", "ckpt"]:
        # PyTorch checkpoint files
        for i, path in enumerate(model_paths):
            if progress_callback:
                progress_callback(f"Loading checkpoint {i+1}/{len(model_paths)}: {os.path.basename(path)}")
            checkpoint = torch.load(path, map_location="cpu")
            if isinstance(checkpoint, dict):
                # Try to extract the state dict from common checkpoint layouts
                if "state_dict" in checkpoint:
                    state_dict.update(checkpoint["state_dict"])
                elif "model_state_dict" in checkpoint:
                    state_dict.update(checkpoint["model_state_dict"])
                elif "model" in checkpoint and isinstance(checkpoint["model"], dict):
                    state_dict.update(checkpoint["model"])
                else:
                    # Assume the checkpoint itself is the state dict
                    state_dict.update(checkpoint)

    return state_dict


def read_model_metadata(model_paths, model_format="safetensors"):
    """Read metadata from model files."""
    metadata = {}

    if model_format == "safetensors":
        # Read metadata from the header of the first safetensors file
        if model_paths:
            with open(model_paths[0], 'rb') as f:
                header_size = int.from_bytes(f.read(8), 'little')
                header_json = f.read(header_size).decode('utf-8')
                header = json.loads(header_json)
                metadata = header.get('__metadata__', {})

    elif model_format in ["pth", "pt", "ckpt"]:
        # Try to extract metadata from checkpoint files
        if model_paths:
            checkpoint = torch.load(model_paths[0], map_location="cpu")
            if isinstance(checkpoint, dict):
                # Look for common metadata keys
                for key in ["hyperparameters", "args", "config", "metadata"]:
                    if key in checkpoint:
                        metadata[key] = checkpoint[key]

    return metadata

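
# Illustration (not executed): read_model_metadata() relies on the safetensors layout,
# where a file starts with an 8-byte little-endian header length followed by a JSON
# header. A header might look roughly like:
#
#   {"__metadata__": {"format": "pt"},
#    "decoder.conv_in.weight": {"dtype": "F16", "shape": [512, 4, 3, 3], "data_offsets": [0, 36864]}}
#
# so the returned metadata is the string-to-string "__metadata__" dict, or {} if absent.
# The tensor name and shape shown here are hypothetical.
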
def extract_base_name_from_sharded_files(model_paths):
    """Extract a common base name from sharded files."""
    if not model_paths:
        return "model"

    if len(model_paths) == 1:
        # Single file case
        base_name = os.path.splitext(os.path.basename(model_paths[0]))[0]
        # Remove common precision suffixes
        for suffix in ["-fp8", "-fp16", "-bf16", "-32", "-16"]:
            if base_name.endswith(suffix):
                base_name = base_name[:-len(suffix)]
        return base_name

    # Multiple files case - find a common base name
    base_names = [os.path.splitext(os.path.basename(p))[0] for p in model_paths]

    # Handle the Hugging Face pattern model-00001-of-00002.safetensors
    # by extracting the part before the shard numbering
    if all("-of-" in name for name in base_names):
        common_parts = []
        for name in base_names:
            # Split at the shard numbering: <base>-<shard>-of-<total>
            parts = name.split("-")
            if len(parts) >= 4 and parts[-3].isdigit() and parts[-2] == "of" and parts[-1].isdigit():
                # Drop the last three parts (shard number, "of", total count)
                common_parts.append("-".join(parts[:-3]))
            else:
                common_parts.append(name)

        # Use the most common base name
        from collections import Counter
        base_name = Counter(common_parts).most_common(1)[0][0]
        return base_name

    # Fallback: find the longest common prefix
    common_prefix = ""
    for chars in zip(*base_names):
        if len(set(chars)) == 1:
            common_prefix += chars[0]
        else:
            break

    # Clean up the common prefix
    base_name = re.sub(r'[-_]+$', '', common_prefix)
    if not base_name:
        base_name = "model"
    return base_name

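
# Examples of the mapping above (hypothetical filenames):
#   ["unet-00001-of-00003.safetensors", ..., "unet-00003-of-00003.safetensors"] -> "unet"
#   ["vae-fp16.safetensors"]                                                    -> "vae"
#   ["encoder.safetensors", "encoder_extra.safetensors"]                        -> "encoder"  (common-prefix fallback)
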
def convert_model_to_fp8(model_paths, output_dir, fp8_format, model_format="safetensors", progress=gr.Progress()):
    """Simple and fast FP8 conversion without recovery strategies."""
    progress(0.05, desc=f"Starting FP8 conversion for {model_format}...")

    try:
        metadata = read_model_metadata(model_paths, model_format)
        progress(0.1, desc="Loaded metadata.")

        # Load model with progress tracking
        state_dict = load_model_files(
            model_paths, model_format,
            progress_callback=lambda msg: progress(0.15, desc=msg)
        )
        progress(0.25, desc=f"Loaded {len(model_paths)} model files with {len(state_dict)} tensors.")

        # Select the FP8 dtype (requires PyTorch >= 2.1):
        #   float8_e4m3fn - 3 mantissa bits, max ~448   (finer precision, common for weights)
        #   float8_e5m2   - 2 mantissa bits, max ~57344 (wider dynamic range)
        fp8_dtype = torch.float8_e5m2 if fp8_format == "e5m2" else torch.float8_e4m3fn

        # Initialize outputs
        sd_fp8 = {}
        conversion_stats = {
            "total_tensors": len(state_dict),
            "converted_tensors": 0,
            "skipped_tensors": 0,
            "skipped_reasons": []
        }

        # Process each tensor
        total = len(state_dict)
        for i, key in enumerate(state_dict):
            if i % 100 == 0:  # Update progress every 100 tensors for speed
                progress(0.3 + 0.6 * (i / total), desc=f"Converting {i}/{total} tensors...")

            weight = state_dict[key]

            # Convert only float tensors to FP8
            if weight.dtype in [torch.float16, torch.float32, torch.bfloat16]:
                sd_fp8[key] = weight.to(fp8_dtype)
                conversion_stats["converted_tensors"] += 1
            else:
                # Keep non-float tensors as-is (e.g., ints, bools)
                sd_fp8[key] = weight
                conversion_stats["skipped_tensors"] += 1
                conversion_stats["skipped_reasons"].append(f"{key}: {weight.dtype}")

        # Extract base name for the output file
        base_name = extract_base_name_from_sharded_files(model_paths)

        # Save FP8 model (safetensors metadata values must be strings)
        fp8_path = os.path.join(output_dir, f"{base_name}-fp8-{fp8_format}.safetensors")
        save_file(sd_fp8, fp8_path, metadata={
            "format": model_format,
            "fp8_format": fp8_format,
            "original_files": str(len(model_paths)),
            "conversion_stats": json.dumps(conversion_stats),
            **{k: str(v) for k, v in metadata.items()}
        })
        progress(0.95, desc="Saved FP8 file.")

        # Generate stats message
        stats_msg = f"✅ FP8 ({fp8_format}) conversion complete!\n"
        stats_msg += f"- Total tensors: {conversion_stats['total_tensors']}\n"
        stats_msg += f"- Converted to FP8: {conversion_stats['converted_tensors']}\n"
        stats_msg += f"- Skipped (non-float): {conversion_stats['skipped_tensors']}\n"
        stats_msg += f"- Output file: {os.path.basename(fp8_path)}\n"

        if conversion_stats["skipped_tensors"] > 0:
            stats_msg += "\n⚠️ Some tensors were skipped (non-float types):\n"
            for reason in conversion_stats["skipped_reasons"][:5]:  # Show first 5
                stats_msg += f"  - {reason}\n"
            if len(conversion_stats["skipped_reasons"]) > 5:
                stats_msg += f"  - ... and {len(conversion_stats['skipped_reasons']) - 5} more\n"

        progress(1.0, desc="✅ FP8 conversion complete!")
        return True, stats_msg, conversion_stats, fp8_path, None

    except Exception as e:
        traceback.print_exc()
        return False, str(e), None, None, None


def parse_hf_url(url):
    url = url.strip().rstrip("/")
    if not url.startswith("https://huggingface.co/"):
        raise ValueError("URL must start with https://huggingface.co/")

    path = url.replace("https://huggingface.co/", "")
    parts = path.split("/")
    if len(parts) < 2:
        raise ValueError("Invalid repo format")

    repo_id = "/".join(parts[:2])
    subfolder = ""
    if len(parts) > 3 and parts[2] == "tree":
        # https://huggingface.co/<user>/<repo>/tree/<branch>/<subfolder...>
        subfolder = "/".join(parts[4:]) if len(parts) > 4 else ""
    elif len(parts) > 2:
        subfolder = "/".join(parts[2:])
    return repo_id, subfolder


def download_single_file(args):
    """Helper function for parallel downloads."""
    repo_id, filename, subfolder, cache_dir, token = args
    try:
        path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            subfolder=subfolder or None,
            cache_dir=cache_dir,
            token=token,
            resume_download=True
        )
        return path, None
    except Exception as e:
        return None, str(e)


def find_sharded_safetensors_files(repo_id, subfolder=None, hf_token=None, max_shards=50):
    """Find all sharded safetensors files in a repository."""
    try:
        # List all files in the repository
        repo_files = list_repo_files(repo_id, repo_type="model", token=hf_token)

        # Filter for safetensors files in the subfolder
        if subfolder:
            pattern = f"{subfolder}/" if not subfolder.endswith("/") else subfolder
            safetensors_files = [f for f in repo_files if f.endswith('.safetensors') and f.startswith(pattern)]
            # Remove the subfolder prefix
            safetensors_files = [f[len(pattern):] for f in safetensors_files if len(f) > len(pattern)]
        else:
            safetensors_files = [f for f in repo_files if f.endswith('.safetensors')]

        # Check which files follow the sharding pattern
        sharded_files = []
        single_files = []
        for f in safetensors_files:
            # Sharding pattern: model-XXXXX-of-YYYYY.safetensors,
            # e.g. "diffusion_pytorch_model-00001-of-00002.safetensors"
            if re.search(r'-\d{5}-of-\d{5}\.safetensors$', f):
                sharded_files.append(f)
            else:
                single_files.append(f)

        # If we have sharded files, return them sorted by shard number
        if sharded_files:
            def extract_shard_num(filename):
                match = re.search(r'-(\d{5})-of-\d{5}\.safetensors$', filename)
                return int(match.group(1)) if match else 0

            sharded_files.sort(key=extract_shard_num)

            # Limit the number of shards to prevent accidental downloads of huge models
            if len(sharded_files) > max_shards:
                raise ValueError(f"Too many shards found ({len(sharded_files)}). Maximum allowed is {max_shards}. "
                                 f"Please specify a more specific pattern.")
            return sharded_files
        elif single_files:
            # Return single (non-sharded) files
            return single_files
        else:
            return []

    except ValueError:
        # Propagate the max_shards error instead of swallowing it below
        raise
    except Exception as e:
        print(f"Error listing repository files: {e}")
        return []

def download_model_files(source_type, repo_url, filename_pattern, model_format, hf_token=None, progress=gr.Progress()):
    temp_dir = tempfile.mkdtemp()
    try:
        if source_type == "huggingface":
            repo_id, subfolder = parse_hf_url(repo_url)

            if model_format == "safetensors":
                if filename_pattern == "auto" or filename_pattern == "":
                    # Auto-detect sharded files
                    progress(0.1, desc="Discovering model files...")
                    found_files = find_sharded_safetensors_files(repo_id, subfolder, hf_token)
                    if not found_files:
                        raise ValueError("No safetensors files found in repository")

                    progress(0.2, desc=f"Found {len(found_files)} shard(s). Downloading...")

                    # Download files in parallel for better performance
                    model_paths = []
                    download_args = [
                        (repo_id, filename, subfolder, temp_dir, hf_token)
                        for filename in found_files
                    ]
                    with ThreadPoolExecutor(max_workers=4) as executor:
                        futures = {executor.submit(download_single_file, args): args[1] for args in download_args}
                        for i, future in enumerate(as_completed(futures)):
                            filename = futures[future]
                            path, error = future.result()
                            if error:
                                raise Exception(f"Failed to download {filename}: {error}")
                            model_paths.append(path)
                            progress(0.2 + 0.6 * (i + 1) / len(futures),
                                     desc=f"Downloaded {i+1}/{len(futures)}: {filename}")

                    return model_paths, temp_dir

                elif "*" in filename_pattern:
                    # For wildcard patterns, download a snapshot restricted to the subfolder and filter.
                    # (snapshot_download has no `subfolder` argument, so use allow_patterns instead.)
                    progress(0.1, desc="Downloading repository snapshot...")
                    local_dir = os.path.join(temp_dir, "download")
                    snapshot_download(
                        repo_id=repo_id,
                        local_dir=local_dir,
                        token=hf_token,
                        resume_download=True,
                        allow_patterns=[f"{subfolder}/*"] if subfolder else None
                    )

                    # Find files matching the pattern
                    pattern_dir = os.path.join(local_dir, subfolder) if subfolder else local_dir
                    model_files = glob.glob(os.path.join(pattern_dir, filename_pattern))
                    if not model_files:
                        raise ValueError(f"No files found matching pattern: {filename_pattern}")

                    # Limit the number of files
                    if len(model_files) > 50:
                        raise ValueError(f"Too many files found ({len(model_files)}). "
                                         f"Please use a more specific pattern.")

                    return model_files, temp_dir

                else:
                    # Single safetensors file - no shard discovery needed
                    progress(0.2, desc=f"Downloading {filename_pattern}...")
                    model_path = hf_hub_download(
                        repo_id=repo_id,
                        filename=filename_pattern,
                        subfolder=subfolder or None,
                        cache_dir=temp_dir,
                        token=hf_token,
                        resume_download=True
                    )
                    return [model_path], temp_dir

            else:
                # Non-safetensors formats
                if "*" in filename_pattern:
                    raise ValueError("Wildcards only supported for safetensors format")

                progress(0.2, desc=f"Downloading {filename_pattern}...")
                model_path = hf_hub_download(
                    repo_id=repo_id,
                    filename=filename_pattern,
                    subfolder=subfolder or None,
                    cache_dir=temp_dir,
                    token=hf_token,
                    resume_download=True
                )
                return [model_path], temp_dir

        elif source_type == "modelscope":
            if not MODELScope_AVAILABLE:
                raise ImportError("ModelScope not installed")

            repo_id = repo_url.strip()
            if model_format == "safetensors" and "*" in filename_pattern:
                # Sharded ModelScope repos would need a file listing first; not implemented here
                raise NotImplementedError("Pattern matching for ModelScope sharded files not fully implemented")
            else:
                progress(0.2, desc=f"Downloading {filename_pattern}...")
                model_path = ms_file_download(model_id=repo_id, file_path=filename_pattern)
                return [model_path], temp_dir

        else:
            raise ValueError("Unknown source")

    except Exception:
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise


def upload_to_target(target_type, new_repo_id, output_dir, fp8_format, hf_token=None, modelscope_token=None, private_repo=False):
    if target_type == "huggingface":
        api = HfApi(token=hf_token)
        api.create_repo(repo_id=new_repo_id, private=private_repo, repo_type="model", exist_ok=True)
        api.upload_folder(repo_id=new_repo_id, folder_path=output_dir, repo_type="model", token=hf_token)
        return f"https://huggingface.co/{new_repo_id}"
    elif target_type == "modelscope":
        api = ModelScopeApi()
        if modelscope_token:
            api.login(modelscope_token)
        api.push_model(model_id=new_repo_id, model_dir=output_dir)
        return f"https://modelscope.cn/models/{new_repo_id}"
    else:
        raise ValueError("Unknown target")

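
# Note on the repo-ID check in process_and_upload_fp8() below: the target must look like
# "username/model-name". For example (hypothetical IDs), "alice/sdxl-vae-fp8" passes the
# regex, while a bare "sdxl-vae-fp8" (no namespace) or "alice/my model" (space) is rejected.
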
def process_and_upload_fp8(
    source_type, repo_url, filename_pattern, model_format, fp8_format,
    target_type, new_repo_id, hf_token, modelscope_token, private_repo,
    progress=gr.Progress()
):
    if not re.match(r"^[a-zA-Z0-9._-]+/[a-zA-Z0-9._-]+$", new_repo_id):
        return None, "❌ Invalid repo ID format. Use 'username/model-name'.", "", ""
    if source_type == "huggingface" and not hf_token:
        return None, "❌ Hugging Face token required for source.", "", ""
    if target_type == "huggingface" and not hf_token:
        return None, "❌ Hugging Face token required for target.", "", ""

    temp_dir = None
    output_dir = tempfile.mkdtemp()
    try:
        progress(0.05, desc="Downloading model...")
        model_paths, temp_dir = download_model_files(
            source_type, repo_url, filename_pattern, model_format, hf_token, progress
        )

        progress(0.8, desc="Converting to FP8...")
        success, msg, stats, fp8_path, _ = convert_model_to_fp8(
            model_paths, output_dir, fp8_format, model_format, progress
        )
        if not success:
            return None, f"❌ Conversion failed: {msg}", "", ""

        progress(0.9, desc="Uploading...")
        repo_url_final = upload_to_target(
            target_type, new_repo_id, output_dir, fp8_format,
            hf_token, modelscope_token, private_repo
        )

        # Generate README
        if len(model_paths) == 1:
            original_filename = os.path.basename(model_paths[0])
        else:
            original_filename = f"{len(model_paths)} sharded files"
            # Add the pattern if not auto
            if filename_pattern != "auto":
                original_filename += f" matching '{filename_pattern}'"

        fp8_filename = os.path.basename(fp8_path)
        readme = f"""---
library_name: diffusers
tags:
- fp8
- safetensors
- converted-by-gradio
---

# FP8 Model Conversion

- **Source**: `{repo_url}`
- **Original File(s)**: `{original_filename}`
- **Original Format**: `{model_format}`
- **FP8 Format**: `{fp8_format.upper()}`
- **FP8 File**: `{fp8_filename}`

## Usage

```python
from safetensors.torch import load_file
import torch

# Load the FP8 state dict (tensors keep their float8 dtype)
fp8_state = load_file("{fp8_filename}")

# Copy into an instantiated model with matching keys; tensors are upcast
# to each parameter's dtype during the copy
model.load_state_dict(fp8_state)
```

> **Note**: FP8 tensors are stored as `float8` and upcast to the target dtype when copied into a model.
> Requires PyTorch ≥ 2.1 for FP8 support.

## Statistics

- **Total tensors**: {stats['total_tensors']}
- **Converted to FP8**: {stats['converted_tensors']}
- **Skipped (non-float)**: {stats['skipped_tensors']}
"""
        with open(os.path.join(output_dir, "README.md"), "w") as f:
            f.write(readme)

        if target_type == "huggingface":
            HfApi(token=hf_token).upload_file(
                path_or_fileobj=os.path.join(output_dir, "README.md"),
                path_in_repo="README.md",
                repo_id=new_repo_id,
                repo_type="model",
                token=hf_token
            )

        progress(1.0, desc="✅ Done!")

        # Generate result HTML
        result_html = f"""
        ✅ Success! Model uploaded to: <a href="{repo_url_final}" target="_blank">{new_repo_id}</a><br>
        FP8 model: <code>{fp8_filename}</code><br>
        Converted {stats['converted_tensors']} tensors to {fp8_format.upper()}
        """
        return (gr.HTML(result_html), "✅ FP8 conversion successful!", msg, "")

    except Exception as e:
        traceback.print_exc()
        return None, f"❌ Error: {str(e)}", "", ""
    finally:
        if temp_dir:
            shutil.rmtree(temp_dir, ignore_errors=True)
        shutil.rmtree(output_dir, ignore_errors=True)

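
# The Blocks UI below wires process_and_upload_fp8 to four output components, in order:
# the repo link (HTML), the status message (Markdown), the detailed processing log, and
# an additional-info box (always "" in this simple converter).
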
with gr.Blocks(title="Fast FP8 Model Converter") as demo:
    gr.Markdown("# ⚡ Fast FP8 Model Converter")
    gr.Markdown("Convert model files (safetensors, pth, ckpt) → **FP8**. Supports sharded files with auto-discovery. Simple and fast!")

    with gr.Row():
        with gr.Column():
            source_type = gr.Radio(["huggingface", "modelscope"], value="huggingface", label="Source")
            repo_url = gr.Textbox(label="Repo URL or ID", placeholder="https://huggingface.co/... or modelscope-id")

            with gr.Row():
                model_format = gr.Dropdown(
                    choices=["safetensors", "pth", "pt", "ckpt"],
                    value="safetensors",
                    label="Model Format"
                )
                filename_pattern = gr.Textbox(
                    label="Filename or Pattern",
                    placeholder="auto (detects sharded files) or model-*.safetensors",
                    value="auto"
                )

            with gr.Accordion("FP8 Settings", open=True):
                fp8_format = gr.Radio(["e4m3fn", "e5m2"], value="e5m2", label="FP8 Format")

            with gr.Accordion("Authentication", open=False):
                hf_token = gr.Textbox(label="Hugging Face Token", type="password")
                modelscope_token = gr.Textbox(label="ModelScope Token (optional)", type="password", visible=MODELScope_AVAILABLE)

        with gr.Column():
            target_type = gr.Radio(["huggingface", "modelscope"], value="huggingface", label="Target")
            new_repo_id = gr.Textbox(label="New Repo ID", placeholder="user/model-fp8")
            private_repo = gr.Checkbox(label="Private Repository (HF only)", value=False)
            status_output = gr.Markdown()
            detailed_log = gr.Textbox(label="Processing Log", interactive=False, lines=10)
            recovery_summary = gr.Textbox(label="Additional Info", interactive=False, lines=3)

    convert_btn = gr.Button("🚀 Convert & Upload", variant="primary")
    repo_link_output = gr.HTML()

    convert_btn.click(
        fn=process_and_upload_fp8,
        inputs=[
            source_type, repo_url, filename_pattern, model_format, fp8_format,
            target_type, new_repo_id, hf_token, modelscope_token, private_repo
        ],
        outputs=[repo_link_output, status_output, detailed_log, recovery_summary],
        show_progress=True
    )

    gr.Examples(
        examples=[
            ["huggingface", "https://huggingface.co/stabilityai/sdxl-vae", "auto", "safetensors", "e4m3fn", "huggingface"],
            ["huggingface", "https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/text_encoder", "auto", "safetensors", "e5m2", "huggingface"],
            ["huggingface", "https://huggingface.co/Yabo/FramePainter/tree/main", "auto", "safetensors", "e5m2", "huggingface"],
            ["huggingface", "https://huggingface.co/stabilityai/stable-diffusion-2-1", "model-*.safetensors", "safetensors", "e5m2", "huggingface"],
            ["huggingface", "https://huggingface.co/CompVis/stable-diffusion-v1-4", "sd-v1-4.ckpt", "ckpt", "e5m2", "huggingface"]
        ],
        inputs=[source_type, repo_url, filename_pattern, model_format, fp8_format, target_type],
        label="Example Conversions",
        cache_examples=False
    )

    gr.Markdown("""
## 📁 Fast FP8 Conversion Tool

This tool provides **fast and simple FP8 conversion** for various model formats:

### **Supported Formats:**
- **Safetensors**: Modern, secure format. Supports sharded files (e.g., `model-00001-of-00005.safetensors`)
- **PTH/PT**: PyTorch checkpoint files
- **CKPT**: Checkpoint files (commonly used for Stable Diffusion models)

### **Shard Support:**
- **Sharded Checkpoints**: Handles multi-shard models (up to 50 shards by default)
- **Auto-Detection**: Automatically finds all shards when using the "auto" pattern
- **Parallel Downloads**: Downloads multiple shards simultaneously (up to 4 at once)
- **Merged Output**: All shards are merged into a single FP8 safetensors file

### **Performance Features:**
- **Fast Conversion**: Simple dtype conversion without complex recovery strategies
- **Lightweight Processing**: Converts tensors in a single pass, updating progress every 100 tensors
- **Progress Tracking**: Shows detailed progress for each step

### **How It Works:**
1. **Discovery**: Automatically detects sharded files or uses your specified pattern
2. **Download**: Downloads files in parallel for maximum speed
3. **Conversion**: Converts float tensors to FP8, leaves other types unchanged (see the loading sketch below)
4. **Upload**: Uploads the converted model to your target repository
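
A minimal sketch of loading a converted file afterwards (the filename is an example; upcast the
weights before running operations that do not support `float8`):

```python
from safetensors.torch import load_file
import torch

state = load_file("model-fp8-e5m2.safetensors")  # example output name
state = {k: (v.to(torch.float16) if v.dtype in (torch.float8_e4m3fn, torch.float8_e5m2) else v)
         for k, v in state.items()}
```
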
### **Usage Tips:**
- Use the "auto" pattern to automatically detect all sharded safetensors files
- Use `model-*.safetensors` to match specific shard patterns
- For single files, just enter the filename (e.g., `model.safetensors`)
- FP8 conversion reduces model size by ~4x compared to FP32 (~2x compared to FP16)
- FP8 weights keep their `float8` dtype on load; upcast them (or let `load_state_dict` copy them into a higher-precision model) before running inference

> **Note**: This is a simple conversion tool. For precision recovery options, use the advanced version.
""")

demo.launch()