import gradio as gr
import os
import tempfile
import shutil
import re
import json
from pathlib import Path
from huggingface_hub import HfApi, hf_hub_download, snapshot_download, list_repo_files
from safetensors.torch import load_file, save_file
import torch
import torch.nn.functional as F
import traceback
import glob
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
try:
from modelscope.hub.file_download import model_file_download as ms_file_download
from modelscope.hub.api import HubApi as ModelScopeApi
MODELScope_AVAILABLE = True
except ImportError:
MODELScope_AVAILABLE = False
def load_model_files(model_paths, model_format="safetensors", progress_callback=None):
"""
Load model weights from one or more files, supporting sharded safetensors and other formats.
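    Example (a minimal sketch; the shard filenames below are hypothetical and
    assumed to already exist on local disk):

        shards = ["model-00001-of-00002.safetensors",
                  "model-00002-of-00002.safetensors"]
        sd = load_model_files(shards, model_format="safetensors", progress_callback=print)
        # -> dict mapping tensor names to torch.Tensor, merged across all shards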
"""
state_dict = {}
if model_format == "safetensors":
# Handle sharded safetensors files
for i, path in enumerate(model_paths):
if progress_callback:
progress_callback(f"Loading shard {i+1}/{len(model_paths)}: {os.path.basename(path)}")
part_dict = load_file(path)
state_dict.update(part_dict)
    elif model_format in ["pth", "pt", "ckpt"]:
        # PyTorch checkpoint files (.pth, .pt and .ckpt share the same layout)
        for i, path in enumerate(model_paths):
            if progress_callback:
                progress_callback(f"Loading checkpoint {i+1}/{len(model_paths)}: {os.path.basename(path)}")
            checkpoint = torch.load(path, map_location="cpu")
            if isinstance(checkpoint, dict):
                # Try to extract the state dict from common wrapper keys
                if "state_dict" in checkpoint:
                    state_dict.update(checkpoint["state_dict"])
                elif "model_state_dict" in checkpoint:
                    state_dict.update(checkpoint["model_state_dict"])
                elif "model" in checkpoint and isinstance(checkpoint["model"], dict):
                    state_dict.update(checkpoint["model"])
                else:
                    # Assume the checkpoint itself is the state dict
                    state_dict.update(checkpoint)
return state_dict
def read_model_metadata(model_paths, model_format="safetensors"):
"""Read metadata from model files."""
metadata = {}
if model_format == "safetensors":
# Read metadata from the first safetensors file
if model_paths:
with open(model_paths[0], 'rb') as f:
header_size = int.from_bytes(f.read(8), 'little')
header_json = f.read(header_size).decode('utf-8')
header = json.loads(header_json)
metadata = header.get('__metadata__', {})
elif model_format in ["pth", "pt", "ckpt"]:
# Try to extract metadata from checkpoint files
if model_paths:
checkpoint = torch.load(model_paths[0], map_location="cpu")
if isinstance(checkpoint, dict):
# Look for common metadata keys
for key in ["hyperparameters", "args", "config", "metadata"]:
if key in checkpoint:
metadata[key] = checkpoint[key]
return metadata
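# For reference: a .safetensors file starts with an 8-byte little-endian length,
# followed by a JSON header of that size. The header maps tensor names to
# dtype/shape/offset entries and may include an optional "__metadata__" dict of
# string key/value pairs, which is what read_model_metadata returns for safetensors.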
def extract_base_name_from_sharded_files(model_paths):
"""Extract a common base name from sharded files."""
if not model_paths:
return "model"
if len(model_paths) == 1:
# Single file case
base_name = os.path.splitext(os.path.basename(model_paths[0]))[0]
# Remove common suffixes
for suffix in ["-fp8", "-fp16", "-bf16", "-32", "-16"]:
if base_name.endswith(suffix):
base_name = base_name[:-len(suffix)]
return base_name
# Multiple files case - find common prefix
base_names = [os.path.splitext(os.path.basename(p))[0] for p in model_paths]
# Handle Hugging Face pattern: model-00001-of-00002.safetensors
# Extract the part before the shard numbering
    if all("-of-" in name for name in base_names):
        # All files follow the "model-XXXXX-of-YYYYY" pattern
        common_parts = []
        for name in base_names:
            # Strip the trailing "-XXXXX-of-YYYYY" shard suffix
            match = re.match(r"(.*)-\d+-of-\d+$", name)
            if match:
                common_parts.append(match.group(1))
            else:
                common_parts.append(name)
        # Use the most common base name across shards
        from collections import Counter
        base_name = Counter(common_parts).most_common(1)[0][0]
        return base_name
# Fallback: find common prefix
common_prefix = ""
for chars in zip(*base_names):
if len(set(chars)) == 1:
common_prefix += chars[0]
else:
break
# Clean up the common prefix
base_name = re.sub(r'[-_]+$', '', common_prefix)
if not base_name:
base_name = "model"
return base_name
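# Illustrative behaviour of extract_base_name_from_sharded_files (hypothetical filenames):
#   ["unet-00001-of-00003.safetensors", "unet-00002-of-00003.safetensors",
#    "unet-00003-of-00003.safetensors"]  -> "unet"
#   ["flux-dev-fp16.safetensors"]        -> "flux-dev"  (known precision suffix stripped)
#   ["vae.safetensors"]                  -> "vae"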
def convert_model_to_fp8(model_paths, output_dir, fp8_format,
model_format="safetensors", progress=gr.Progress()):
"""Simple and fast FP8 conversion without recovery strategies."""
progress(0.05, desc=f"Starting FP8 conversion for {model_format}...")
try:
metadata = read_model_metadata(model_paths, model_format)
progress(0.1, desc="Loaded metadata.")
# Load model with progress tracking
state_dict = load_model_files(
model_paths,
model_format,
progress_callback=lambda msg: progress(0.15, desc=msg)
)
progress(0.25, desc=f"Loaded {len(model_paths)} model files with {len(state_dict)} tensors.")
# Setup FP8 format
fp8_dtype = torch.float8_e5m2 if fp8_format == "e5m2" else torch.float8_e4m3fn
# Initialize outputs
sd_fp8 = {}
conversion_stats = {
"total_tensors": len(state_dict),
"converted_tensors": 0,
"skipped_tensors": 0,
"skipped_reasons": []
}
# Process each tensor
total = len(state_dict)
for i, key in enumerate(state_dict):
if i % 100 == 0: # Update progress every 100 tensors for speed
progress(0.3 + 0.6 * (i / total), desc=f"Converting {i}/{total} tensors...")
weight = state_dict[key]
# Convert only float tensors to FP8
if weight.dtype in [torch.float16, torch.float32, torch.bfloat16]:
fp8_weight = weight.to(fp8_dtype)
sd_fp8[key] = fp8_weight
conversion_stats["converted_tensors"] += 1
else:
# Keep non-float tensors as-is (e.g., ints, bools)
sd_fp8[key] = weight
conversion_stats["skipped_tensors"] += 1
conversion_stats["skipped_reasons"].append(f"{key}: {weight.dtype}")
# Extract base name for output files
base_name = extract_base_name_from_sharded_files(model_paths)
# Save FP8 model
fp8_path = os.path.join(output_dir, f"{base_name}-fp8-{fp8_format}.safetensors")
save_file(sd_fp8, fp8_path, metadata={
"format": model_format,
"fp8_format": fp8_format,
"original_files": str(len(model_paths)),
"conversion_stats": json.dumps(conversion_stats),
**metadata
})
progress(0.95, desc="Saved FP8 file.")
# Generate stats message
        stats_msg = f"✅ FP8 ({fp8_format}) conversion complete!\n"
stats_msg += f"- Total tensors: {conversion_stats['total_tensors']}\n"
stats_msg += f"- Converted to FP8: {conversion_stats['converted_tensors']}\n"
stats_msg += f"- Skipped (non-float): {conversion_stats['skipped_tensors']}\n"
stats_msg += f"- Output file: {os.path.basename(fp8_path)}\n"
if conversion_stats["skipped_tensors"] > 0:
stats_msg += "\n⚠️ Some tensors were skipped (non-float types):\n"
for i, reason in enumerate(conversion_stats["skipped_reasons"][:5]): # Show first 5
stats_msg += f" - {reason}\n"
if len(conversion_stats["skipped_reasons"]) > 5:
stats_msg += f" - ... and {len(conversion_stats['skipped_reasons']) - 5} more\n"
        progress(1.0, desc="✅ FP8 conversion complete!")
return True, stats_msg, conversion_stats, fp8_path, None
except Exception as e:
traceback.print_exc()
return False, str(e), None, None, None
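# The conversion above is a plain dtype cast. A minimal sketch of what happens to a
# single tensor (FP8 is lossy, so a small quantization error is expected):
#   w = torch.randn(4, 4, dtype=torch.float16)
#   w_fp8 = w.to(torch.float8_e5m2)      # 8-bit storage
#   w_back = w_fp8.to(torch.float16)     # upcast again for computation
#   max_err = (w - w_back).abs().max()   # error introduced by the cast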
def parse_hf_url(url):
url = url.strip().rstrip("/")
if not url.startswith("https://huggingface.co/"):
raise ValueError("URL must start with https://huggingface.co/")
path = url.replace("https://huggingface.co/", "")
parts = path.split("/")
if len(parts) < 2:
raise ValueError("Invalid repo format")
repo_id = "/".join(parts[:2])
subfolder = ""
if len(parts) > 3 and parts[2] == "tree":
subfolder = "/".join(parts[4:]) if len(parts) > 4 else ""
elif len(parts) > 2:
subfolder = "/".join(parts[2:])
return repo_id, subfolder
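# Illustrative parse_hf_url results (hypothetical repos):
#   "https://huggingface.co/user/model"                 -> ("user/model", "")
#   "https://huggingface.co/user/model/tree/main/unet"  -> ("user/model", "unet")
#   "https://huggingface.co/user/model/text_encoder"    -> ("user/model", "text_encoder")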
def download_single_file(args):
"""Helper function for parallel downloads."""
repo_id, filename, subfolder, cache_dir, token = args
try:
path = hf_hub_download(
repo_id=repo_id,
filename=filename,
subfolder=subfolder,
cache_dir=cache_dir,
token=token,
resume_download=True
)
return path, None
except Exception as e:
return None, str(e)
def find_sharded_safetensors_files(repo_id, subfolder=None, hf_token=None, max_shards=50):
"""Find all sharded safetensors files in a repository."""
try:
# List all files in the repository
repo_files = list_repo_files(repo_id, repo_type="model", token=hf_token)
# Filter for safetensors files in the subfolder
if subfolder:
pattern = f"{subfolder}/" if not subfolder.endswith("/") else subfolder
safetensors_files = [f for f in repo_files if f.endswith('.safetensors') and f.startswith(pattern)]
# Remove subfolder prefix
safetensors_files = [f[len(pattern):] for f in safetensors_files if len(f) > len(pattern)]
else:
safetensors_files = [f for f in repo_files if f.endswith('.safetensors')]
# Check if files follow sharding pattern
sharded_files = []
single_files = []
for f in safetensors_files:
# Check for sharding pattern: model-XXXXX-of-YYYYY.safetensors
match = re.search(r'-\d{5}-of-\d{5}\.safetensors$', f)
if match:
sharded_files.append(f)
else:
single_files.append(f)
# If we have sharded files, return them sorted by shard number
if sharded_files:
# Sort by shard number for consistent ordering
def extract_shard_num(filename):
match = re.search(r'-(\d{5})-of-\d{5}\.safetensors$', filename)
return int(match.group(1)) if match else 0
sharded_files.sort(key=extract_shard_num)
# Limit number of shards to prevent accidental downloads of huge models
if len(sharded_files) > max_shards:
raise ValueError(f"Too many shards found ({len(sharded_files)}). Maximum allowed is {max_shards}. "
f"Please specify a more specific pattern.")
return sharded_files
elif single_files:
# Return single files (non-sharded)
return single_files
else:
return []
except Exception as e:
print(f"Error listing repository files: {e}")
return []
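# Filenames the regex above treats as shards (hypothetical examples):
#   "diffusion_pytorch_model-00001-of-00004.safetensors"  -> sharded (sorted by shard index)
#   "diffusion_pytorch_model.safetensors"                 -> single file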
def download_model_files(source_type, repo_url, filename_pattern, model_format, hf_token=None, progress=gr.Progress()):
temp_dir = tempfile.mkdtemp()
try:
if source_type == "huggingface":
repo_id, subfolder = parse_hf_url(repo_url)
if model_format == "safetensors":
# Handle different patterns for safetensors
if filename_pattern == "auto" or filename_pattern == "":
# Auto-detect sharded files
progress(0.1, desc="Discovering model files...")
found_files = find_sharded_safetensors_files(repo_id, subfolder, hf_token)
if not found_files:
raise ValueError("No safetensors files found in repository")
progress(0.2, desc=f"Found {len(found_files)} shard(s). Downloading...")
# Download files in parallel for better performance
model_paths = []
download_args = [
(repo_id, filename, subfolder, temp_dir, hf_token)
for filename in found_files
]
with ThreadPoolExecutor(max_workers=4) as executor:
futures = {executor.submit(download_single_file, args): args[1] for args in download_args}
for i, future in enumerate(as_completed(futures)):
filename = futures[future]
try:
path, error = future.result()
if error:
raise Exception(f"Failed to download {filename}: {error}")
model_paths.append(path)
progress(0.2 + 0.6 * (i + 1) / len(futures),
desc=f"Downloaded {i+1}/{len(futures)}: {filename}")
except Exception as e:
raise e
return model_paths, temp_dir
elif "*" in filename_pattern:
# For wildcard patterns, download the entire directory and filter
progress(0.1, desc="Downloading repository snapshot...")
local_dir = os.path.join(temp_dir, "download")
                    # snapshot_download has no subfolder argument; restrict the
                    # download to the subfolder via allow_patterns instead
                    snapshot_download(
                        repo_id=repo_id,
                        allow_patterns=[f"{subfolder}/*"] if subfolder else None,
                        local_dir=local_dir,
                        token=hf_token,
                        resume_download=True
                    )
# Find files matching the pattern
if subfolder:
pattern_dir = os.path.join(local_dir, subfolder)
else:
pattern_dir = local_dir
model_files = glob.glob(os.path.join(pattern_dir, filename_pattern))
if not model_files:
raise ValueError(f"No files found matching pattern: {filename_pattern}")
# Limit number of files
if len(model_files) > 50:
raise ValueError(f"Too many files found ({len(model_files)}). Please use a more specific pattern.")
return model_files, temp_dir
else:
# SINGLE FILE SAFETENSORS - separate from shard discovery
progress(0.2, desc=f"Downloading {filename_pattern}...")
model_path = hf_hub_download(
repo_id=repo_id,
filename=filename_pattern,
subfolder=subfolder or None,
cache_dir=temp_dir,
token=hf_token,
resume_download=True
)
return [model_path], temp_dir
else:
# For non-safetensors formats
if "*" in filename_pattern:
raise ValueError("Wildcards only supported for safetensors format")
progress(0.2, desc=f"Downloading {filename_pattern}...")
model_path = hf_hub_download(
repo_id=repo_id,
filename=filename_pattern,
subfolder=subfolder or None,
cache_dir=temp_dir,
token=hf_token,
resume_download=True
)
return [model_path], temp_dir
elif source_type == "modelscope":
if not MODELScope_AVAILABLE:
raise ImportError("ModelScope not installed")
repo_id = repo_url.strip()
if model_format == "safetensors" and "*" in filename_pattern:
# For ModelScope, we need to handle sharded files differently
# This is a simplified approach - in a real implementation, you might need to list files first
raise NotImplementedError("Pattern matching for ModelScope sharded files not fully implemented")
else:
progress(0.2, desc=f"Downloading {filename_pattern}...")
model_path = ms_file_download(model_id=repo_id, file_path=filename_pattern)
return [model_path], temp_dir
else:
raise ValueError("Unknown source")
except Exception as e:
shutil.rmtree(temp_dir, ignore_errors=True)
raise e
def upload_to_target(target_type, new_repo_id, output_dir, fp8_format, hf_token=None, modelscope_token=None, private_repo=False):
if target_type == "huggingface":
api = HfApi(token=hf_token)
api.create_repo(repo_id=new_repo_id, private=private_repo, repo_type="model", exist_ok=True)
api.upload_folder(repo_id=new_repo_id, folder_path=output_dir, repo_type="model", token=hf_token)
return f"https://huggingface.co/{new_repo_id}"
elif target_type == "modelscope":
api = ModelScopeApi()
if modelscope_token:
api.login(modelscope_token)
api.push_model(model_id=new_repo_id, model_dir=output_dir)
return f"https://modelscope.cn/models/{new_repo_id}"
else:
raise ValueError("Unknown target")
def process_and_upload_fp8(
source_type,
repo_url,
filename_pattern,
model_format,
fp8_format,
target_type,
new_repo_id,
hf_token,
modelscope_token,
private_repo,
progress=gr.Progress()
):
if not re.match(r"^[a-zA-Z0-9._-]+/[a-zA-Z0-9._-]+$", new_repo_id):
return None, "❌ Invalid repo ID format. Use 'username/model-name'.", "", ""
if source_type == "huggingface" and not hf_token:
return None, "❌ Hugging Face token required for source.", "", ""
if target_type == "huggingface" and not hf_token:
return None, "❌ Hugging Face token required for target.", "", ""
temp_dir = None
output_dir = tempfile.mkdtemp()
try:
progress(0.05, desc="Downloading model...")
model_paths, temp_dir = download_model_files(
source_type, repo_url, filename_pattern, model_format, hf_token, progress
)
progress(0.8, desc="Converting to FP8...")
success, msg, stats, fp8_path, _ = convert_model_to_fp8(
model_paths, output_dir, fp8_format, model_format, progress
)
if not success:
return None, f"❌ Conversion failed: {msg}", "", ""
progress(0.9, desc="Uploading...")
repo_url_final = upload_to_target(
target_type, new_repo_id, output_dir, fp8_format, hf_token, modelscope_token, private_repo
)
# Generate README
if len(model_paths) == 1:
original_filename = os.path.basename(model_paths[0])
else:
original_filename = f"{len(model_paths)} sharded files"
# Add the pattern if not auto
if filename_pattern != "auto":
original_filename += f" matching '{filename_pattern}'"
fp8_filename = os.path.basename(fp8_path)
readme = f"""---
library_name: diffusers
tags:
- fp8
- safetensors
- converted-by-gradio
---
# FP8 Model Conversion
- **Source**: `{repo_url}`
- **Original File(s)**: `{original_filename}`
- **Original Format**: `{model_format}`
- **FP8 Format**: `{fp8_format.upper()}`
- **FP8 File**: `{fp8_filename}`
## Usage
```python
from safetensors.torch import load_file
import torch
# Load the FP8 state dict
fp8_state = load_file("{fp8_filename}")
# Copy into your instantiated model; tensors are upcast to each parameter's
# dtype (e.g. float32) by load_state_dict
model.load_state_dict(fp8_state)
```
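If you prefer to upcast explicitly before building the model, a minimal sketch:

```python
for name in list(fp8_state):
    t = fp8_state[name]
    if t.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        fp8_state[name] = t.to(torch.float32)
```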
> **Note**: FP8 tensors are upcast to the target parameter dtype when copied via `load_state_dict`.
> Loading them requires PyTorch ≥ 2.1, which introduced the `float8_e4m3fn` / `float8_e5m2` dtypes.
## Statistics
- **Total tensors**: {stats['total_tensors']}
- **Converted to FP8**: {stats['converted_tensors']}
- **Skipped (non-float)**: {stats['skipped_tensors']}
"""
with open(os.path.join(output_dir, "README.md"), "w") as f:
f.write(readme)
if target_type == "huggingface":
HfApi(token=hf_token).upload_file(
path_or_fileobj=os.path.join(output_dir, "README.md"),
path_in_repo="README.md",
repo_id=new_repo_id,
repo_type="model",
token=hf_token
)
        progress(1.0, desc="✅ Done!")
# Generate result HTML
        result_html = f"""
        <p>✅ <b>Success!</b></p>
        <p>Model uploaded to: <a href="{repo_url_final}" target="_blank">{new_repo_id}</a></p>
        <ul>
          <li>FP8 model: <code>{fp8_filename}</code></li>
          <li>Converted {stats['converted_tensors']} tensors to {fp8_format.upper()}</li>
        </ul>
        """
        return (gr.HTML(result_html),
                "✅ FP8 conversion successful!",
                msg,
                "")
except Exception as e:
traceback.print_exc()
return None, f"❌ Error: {str(e)}", "", ""
finally:
if temp_dir:
shutil.rmtree(temp_dir, ignore_errors=True)
shutil.rmtree(output_dir, ignore_errors=True)
with gr.Blocks(title="Fast FP8 Model Converter") as demo:
    gr.Markdown("# ⚡ Fast FP8 Model Converter")
    gr.Markdown("Convert model files (safetensors, pth, ckpt) → **FP8**. Supports sharded files with auto-discovery. Simple and fast!")
with gr.Row():
with gr.Column():
source_type = gr.Radio(["huggingface", "modelscope"], value="huggingface", label="Source")
repo_url = gr.Textbox(label="Repo URL or ID", placeholder="https://huggingface.co/... or modelscope-id")
with gr.Row():
model_format = gr.Dropdown(
choices=["safetensors", "pth", "pt", "ckpt"],
value="safetensors",
label="Model Format"
)
filename_pattern = gr.Textbox(
label="Filename or Pattern",
placeholder="auto (detects sharded files) or model-*.safetensors",
value="auto"
)
with gr.Accordion("FP8 Settings", open=True):
fp8_format = gr.Radio(["e4m3fn", "e5m2"], value="e5m2", label="FP8 Format")
with gr.Accordion("Authentication", open=False):
hf_token = gr.Textbox(label="Hugging Face Token", type="password")
modelscope_token = gr.Textbox(label="ModelScope Token (optional)", type="password", visible=MODELScope_AVAILABLE)
with gr.Column():
target_type = gr.Radio(["huggingface", "modelscope"], value="huggingface", label="Target")
new_repo_id = gr.Textbox(label="New Repo ID", placeholder="user/model-fp8")
private_repo = gr.Checkbox(label="Private Repository (HF only)", value=False)
status_output = gr.Markdown()
detailed_log = gr.Textbox(label="Processing Log", interactive=False, lines=10)
recovery_summary = gr.Textbox(label="Additional Info", interactive=False, lines=3)
    convert_btn = gr.Button("🚀 Convert & Upload", variant="primary")
repo_link_output = gr.HTML()
convert_btn.click(
fn=process_and_upload_fp8,
inputs=[
source_type,
repo_url,
filename_pattern,
model_format,
fp8_format,
target_type,
new_repo_id,
hf_token,
modelscope_token,
private_repo
],
outputs=[repo_link_output, status_output, detailed_log, recovery_summary],
show_progress=True
)
gr.Examples(
examples=[
[
"huggingface",
"https://huggingface.co/stabilityai/sdxl-vae",
"auto",
"safetensors",
"e4m3fn",
"huggingface"
],
[
"huggingface",
"https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/text_encoder",
"auto",
"safetensors",
"e5m2",
"huggingface"
],
[
"huggingface",
"https://huggingface.co/Yabo/FramePainter/tree/main",
"auto",
"safetensors",
"e5m2",
"huggingface"
],
[
"huggingface",
"https://huggingface.co/stabilityai/stable-diffusion-2-1",
"model-*.safetensors",
"safetensors",
"e5m2",
"huggingface"
],
[
"huggingface",
"https://huggingface.co/CompVis/stable-diffusion-v1-4",
"sd-v1-4.ckpt",
"ckpt",
"e5m2",
"huggingface"
]
],
inputs=[source_type, repo_url, filename_pattern, model_format, fp8_format, target_type],
label="Example Conversions",
cache_examples=False
)
gr.Markdown("""
## Fast FP8 Conversion Tool
This tool provides **fast and simple FP8 conversion** for various model formats:
### **Supported Formats:**
- **Safetensors**: Modern, secure format. Supports sharded files (e.g., `model-00001-of-00005.safetensors`)
- **PTH/PT**: PyTorch checkpoint files
- **CKPT**: Checkpoint files (commonly used for stable diffusion models)
### **Shard Support:**
- **Sharded Models**: Handles multi-file models with up to 50 shards per conversion (a built-in safety limit)
- **Auto-Detection**: Automatically finds all shards when using "auto" pattern
- **Parallel Downloads**: Downloads multiple shards simultaneously (up to 4 at once)
- **CPU Conversion**: Shards are loaded and converted on the CPU; the full state dict is held in RAM, so make sure enough memory is available for the model
### **Performance Features:**
- **Fast Conversion**: Simple dtype conversion without complex recovery strategies
- **Low Overhead**: Tensors are cast one at a time, with progress updates every 100 tensors
- **Progress Tracking**: Shows detailed progress for each step
### **How It Works:**
1. **Discovery**: Automatically detects sharded files or uses your specified pattern
2. **Download**: Downloads files in parallel for maximum speed
3. **Conversion**: Converts float tensors to FP8, leaves other types unchanged
4. **Upload**: Uploads the converted model to your target repository
### **Usage Tips:**
- Use "auto" pattern to automatically detect all sharded safetensors files
- Use `model-*.safetensors` to match specific shard patterns
- For single files, just enter the filename (e.g., `model.safetensors`)
- FP8 storage is roughly 4x smaller than FP32 and 2x smaller than FP16/BF16
- FP8 tensors must be upcast (e.g. to float32) before computation; loading them into a float32 model via `load_state_dict` does this, or you can cast explicitly as in the sketch below
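A minimal sketch of loading a converted file and upcasting it for use (the filename is hypothetical):
```python
from safetensors.torch import load_file
import torch

sd = load_file("model-fp8-e5m2.safetensors")  # hypothetical output name
sd = {k: (v.to(torch.float32)
          if v.dtype in (torch.float8_e4m3fn, torch.float8_e5m2) else v)
      for k, v in sd.items()}
```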
> **Note**: This is a simple conversion tool. For precision recovery options, use the advanced version.
""")
demo.launch()