import gradio as gr
import os
import tempfile
import shutil
import re
import json
import datetime
from huggingface_hub import HfApi, hf_hub_download
from safetensors.torch import load_file, save_file
import torch

# Optional ModelScope integration
try:
    from modelscope.hub.file_download import model_file_download as ms_file_download
    from modelscope.hub.api import HubApi as ModelScopeApi
    MODELSCOPE_AVAILABLE = True
except ImportError:
    MODELSCOPE_AVAILABLE = False


# --- Conversion Function: Safetensors → FP8 Safetensors ---
def convert_safetensors_to_fp8(safetensors_path, output_dir, fp8_format, progress=gr.Progress()):
    progress(0.1, desc="Starting FP8 conversion...")
    try:
        def read_safetensors_metadata(path):
            # A safetensors file starts with a little-endian uint64 header length,
            # followed by a JSON header; user metadata lives under "__metadata__".
            with open(path, "rb") as f:
                header_size = int.from_bytes(f.read(8), "little")
                header = json.loads(f.read(header_size).decode("utf-8"))
            return header.get("__metadata__", {})

        metadata = read_safetensors_metadata(safetensors_path)
        progress(0.3, desc="Loaded model metadata.")

        state_dict = load_file(safetensors_path)
        progress(0.5, desc="Loaded model weights.")

        fp8_dtype = torch.float8_e5m2 if fp8_format == "e5m2" else torch.float8_e4m3fn

        # Cast floating-point tensors to FP8; leave other dtypes untouched.
        sd_pruned = {}
        total = len(state_dict)
        for i, key in enumerate(state_dict):
            progress(0.5 + 0.4 * (i / total), desc=f"Converting tensor {i+1}/{total} to FP8 ({fp8_format})...")
            if state_dict[key].dtype in (torch.float16, torch.float32, torch.bfloat16):
                sd_pruned[key] = state_dict[key].to(fp8_dtype)
            else:
                sd_pruned[key] = state_dict[key]

        base_name = os.path.splitext(os.path.basename(safetensors_path))[0]
        output_path = os.path.join(output_dir, f"{base_name}-fp8-{fp8_format}.safetensors")
        save_file(sd_pruned, output_path, metadata={"format": "pt", "fp8_format": fp8_format, **metadata})
        progress(0.9, desc="Saved FP8 safetensors file.")

        progress(1.0, desc="FP8 conversion complete!")
        return True, f"Model successfully pruned to FP8 ({fp8_format})."
    except Exception as e:
        return False, str(e)
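
# A minimal sketch (not called by the app) of what the cast above does to a single
# tensor. The float8 dtypes require torch >= 2.1; most ops cannot run on FP8
# directly, so weights are typically upcast again at load time.
def _fp8_roundtrip_demo():
    w = torch.randn(4, 4, dtype=torch.float32)
    w_fp8 = w.to(torch.float8_e4m3fn)        # lossy 8-bit cast, as in convert_safetensors_to_fp8
    w_back = w_fp8.to(torch.float16)         # upcast before use at inference
    return (w - w_back.float()).abs().max()  # the quantization error FP8 introduces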

# --- Source download helper ---
def download_safetensors_file(
    source_type, repo_url, filename,
    hf_token=None, modelscope_token=None,
    progress=gr.Progress()
):
    temp_dir = tempfile.mkdtemp()
    try:
        if source_type == "huggingface":
            clean_url = repo_url.strip().rstrip("/")
            if "huggingface.co" not in clean_url:
                raise ValueError("Invalid Hugging Face URL")
            src_repo_id = clean_url.replace("https://huggingface.co/", "")
            safetensors_path = hf_hub_download(
                repo_id=src_repo_id,
                filename=filename,
                cache_dir=temp_dir,
                token=hf_token
            )
        elif source_type == "modelscope":
            if not MODELSCOPE_AVAILABLE:
                raise ImportError("ModelScope not installed. Install with: pip install modelscope")
            clean_url = repo_url.strip().rstrip("/")
            if "modelscope.cn" in clean_url:
                src_repo_id = "/".join(clean_url.split("/")[-2:])
            else:
                src_repo_id = clean_url
            # ModelScope authenticates via HubApi.login rather than a per-call token.
            if modelscope_token:
                ModelScopeApi().login(modelscope_token)
            safetensors_path = ms_file_download(
                model_id=src_repo_id,
                file_path=filename,
                cache_dir=temp_dir
            )
        else:
            raise ValueError("Unknown source type")
        return safetensors_path, temp_dir
    except Exception:
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise


# --- Upload helper ---
def upload_to_target(
    target_type, new_repo_id, output_dir, fp8_format,
    hf_token=None, modelscope_token=None, private_repo=False,
    progress=gr.Progress()
):
    if target_type == "huggingface":
        if not hf_token:
            raise ValueError("Hugging Face token required")
        api = HfApi(token=hf_token)
        api.create_repo(
            repo_id=new_repo_id,
            private=private_repo,
            repo_type="model",
            exist_ok=True
        )
        api.upload_folder(
            repo_id=new_repo_id,
            folder_path=output_dir,
            repo_type="model",
            commit_message=f"Upload FP8 ({fp8_format}) model"
        )
        return f"https://huggingface.co/{new_repo_id}"
    elif target_type == "modelscope":
        if not MODELSCOPE_AVAILABLE:
            raise ImportError("ModelScope not installed")
        api = ModelScopeApi()
        if modelscope_token:
            api.login(modelscope_token)
        # Depending on the ModelScope version, push_model may also expect
        # visibility/license arguments; defaults are used here.
        api.push_model(
            model_id=new_repo_id,
            model_dir=output_dir,
            commit_message=f"Upload FP8 ({fp8_format}) model"
        )
        return f"https://modelscope.cn/models/{new_repo_id}"
    else:
        raise ValueError("Unknown target type")
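
# Optional helper sketch, assuming huggingface_hub's HfApi.list_repo_files: lists
# the .safetensors files in a source repo, which can help fill in the
# "Safetensors Filename" field. Not wired into the UI.
def list_safetensors_files(repo_id, hf_token=None):
    api = HfApi(token=hf_token)
    return [f for f in api.list_repo_files(repo_id) if f.endswith(".safetensors")]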

# --- Main Processing Function ---
def process_and_upload_fp8(
    source_type, repo_url, safetensors_filename, fp8_format,
    target_type, new_repo_id,
    hf_token, modelscope_token, private_repo,
    progress=gr.Progress()
):
    required_fields = [repo_url, safetensors_filename, new_repo_id]
    if "huggingface" in (source_type, target_type):
        required_fields.append(hf_token)
    if target_type == "modelscope":
        required_fields.append(modelscope_token)  # ModelScope uploads require authentication
    if not all(required_fields):
        return None, "❌ Error: Please fill in all required fields."

    if not re.match(r"^[a-zA-Z0-9._-]+/[a-zA-Z0-9._-]+$", new_repo_id):
        return None, "❌ Invalid repository ID format. Use 'username/model-name'."

    temp_dir = None
    output_dir = tempfile.mkdtemp()
    try:
        # Authenticate & download
        progress(0.05, desc="Authenticating and downloading...")
        safetensors_path, temp_dir = download_safetensors_file(
            source_type=source_type,
            repo_url=repo_url,
            filename=safetensors_filename,
            hf_token=hf_token,
            modelscope_token=modelscope_token,
            progress=progress
        )
        progress(0.25, desc="Download complete.")

        # Convert
        success, msg = convert_safetensors_to_fp8(safetensors_path, output_dir, fp8_format, progress)
        if not success:
            return None, f"❌ Conversion failed: {msg}"

        # Write the README before uploading so it ships with the model folder
        # on both platforms.
        base_name = os.path.splitext(safetensors_filename)[0]
        fp8_filename = f"{base_name}-fp8-{fp8_format}.safetensors"
        readme = f"""---
library_name: diffusers
tags:
- fp8
- safetensors
- pruned
- diffusion
- converted-by-gradio
- fp8-{fp8_format}
---

# FP8 Pruned Model ({fp8_format.upper()})

Converted from: `{repo_url}`
File: `{safetensors_filename}` → `{fp8_filename}`
Quantization: **FP8 ({fp8_format.upper()})**
Converted on: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

> ⚠️ Requires PyTorch ≥ 2.1 and compatible hardware for FP8 acceleration.
"""
        with open(os.path.join(output_dir, "README.md"), "w") as f:
            f.write(readme)

        # Upload
        progress(0.92, desc="Uploading model...")
        repo_url_final = upload_to_target(
            target_type=target_type,
            new_repo_id=new_repo_id,
            output_dir=output_dir,
            fp8_format=fp8_format,
            hf_token=hf_token,
            modelscope_token=modelscope_token,
            private_repo=private_repo,
            progress=progress
        )

        progress(1.0, desc="✅ Done!")
        result_html = f"""
✅ Success! Your FP8 model is uploaded to:
<a href="{repo_url_final}" target="_blank">{new_repo_id}</a><br>
Source: {source_type.title()} → Target: {target_type.title()}
"""
        return gr.HTML(result_html), "✅ FP8 conversion and upload successful!"
    except Exception as e:
        return None, f"❌ Error: {str(e)}"
    finally:
        if temp_dir:
            shutil.rmtree(temp_dir, ignore_errors=True)
        shutil.rmtree(output_dir, ignore_errors=True)
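
# Usage sketch (commented out; the target repo ID and token below are hypothetical;
# the source repo and filename come from the Examples at the bottom of this file).
# Note that the default progress=gr.Progress() only tracks inside a Gradio event.
#
#   html, status = process_and_upload_fp8(
#       source_type="huggingface",
#       repo_url="https://huggingface.co/Yabo/FramePainter",
#       safetensors_filename="unet_diffusion_pytorch_model.safetensors",
#       fp8_format="e5m2",
#       target_type="huggingface",
#       new_repo_id="your-username/my-model-fp8",
#       hf_token="hf_...",
#       modelscope_token=None,
#       private_repo=False,
#   )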

# --- Gradio UI ---
with gr.Blocks(title="Safetensors → FP8 Pruner (HF + ModelScope)") as demo:
    gr.Markdown("# 🔄 Safetensors to FP8 Pruner")
    gr.Markdown("Convert `.safetensors` models to **FP8** and upload to **Hugging Face** or **ModelScope**.")

    with gr.Row():
        with gr.Column():
            source_type = gr.Radio(
                choices=["huggingface", "modelscope"],
                value="huggingface",
                label="Source Platform"
            )
            repo_url = gr.Textbox(
                label="Source Repository URL",
                placeholder="e.g., https://huggingface.co/Yabo/FramePainter OR your-modelscope-id",
                info="Hugging Face URL or ModelScope model ID"
            )
            safetensors_filename = gr.Textbox(
                label="Safetensors Filename",
                placeholder="unet_diffusion_pytorch_model.safetensors"
            )
            fp8_format = gr.Radio(
                choices=["e4m3fn", "e5m2"],
                value="e5m2",
                label="FP8 Format",
                info="E5M2: wider dynamic range; E4M3FN: finer precision near zero"
            )
            hf_token = gr.Textbox(
                label="Hugging Face Token (if using HF)",
                type="password"
            )
            modelscope_token = gr.Textbox(
                label="ModelScope Token (optional)",
                type="password",
                visible=MODELSCOPE_AVAILABLE
            )
        with gr.Column():
            target_type = gr.Radio(
                choices=["huggingface", "modelscope"],
                value="huggingface",
                label="Target Platform"
            )
            new_repo_id = gr.Textbox(
                label="New Repository ID",
                placeholder="your-username/my-model-fp8"
            )
            private_repo = gr.Checkbox(label="Make Private (HF only)", value=False)
            convert_btn = gr.Button("🚀 Convert & Upload", variant="primary")

    with gr.Row():
        status_output = gr.Markdown()
        repo_link_output = gr.HTML()

    convert_btn.click(
        fn=process_and_upload_fp8,
        inputs=[
            source_type, repo_url, safetensors_filename, fp8_format,
            target_type, new_repo_id,
            hf_token, modelscope_token, private_repo
        ],
        outputs=[repo_link_output, status_output],
        show_progress="full"
    )

    gr.Examples(
        examples=[
            ["huggingface", "https://huggingface.co/Yabo/FramePainter",
             "unet_diffusion_pytorch_model.safetensors", "e5m2", "huggingface"]
        ],
        inputs=[source_type, repo_url, safetensors_filename, fp8_format, target_type]
    )

if __name__ == "__main__":
    demo.launch()
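
# Assumed dependencies (versions are not pinned anywhere in this script):
#   pip install gradio huggingface_hub safetensors "torch>=2.1" modelscope
# torch >= 2.1 provides torch.float8_e5m2 / torch.float8_e4m3fn;
# modelscope is optional and only needed for the ModelScope source/target paths.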