import gradio as gr
import os
import tempfile
import shutil
import re
import json
from huggingface_hub import HfApi, hf_hub_download
from safetensors.torch import load_file, save_file
import torch
import traceback

try:
    from modelscope.hub.file_download import model_file_download as ms_file_download
    from modelscope.hub.api import HubApi as ModelScopeApi
    MODELSCOPE_AVAILABLE = True
except ImportError:
    MODELSCOPE_AVAILABLE = False


def low_rank_decomposition(weight, rank=64):
    """Low-rank (SVD) decomposition for LoRA extraction; handles 2D, 3D and 4D tensors."""
    original_shape = weight.shape

    try:
        # Handle 2D tensors (linear layers, attention projections)
        if weight.ndim == 2:
            U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
            if rank > len(S):
                rank = len(S) // 2  # Use half the available rank if the requested rank is too high
            U = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank]))
            Vh = torch.diag(torch.sqrt(S[:rank])) @ Vh[:rank, :]
            return U.contiguous(), Vh.contiguous()

        # Handle 4D tensors (convolutional layers)
        elif weight.ndim == 4:
            out_ch, in_ch, kH, kW = weight.shape

            # For small conv kernels (3x3 or smaller), fold the spatial dims into the rows
            if kH * kW <= 9:
                weight_2d = weight.permute(0, 2, 3, 1).reshape(out_ch * kH * kW, in_ch)
                U, S, Vh = torch.linalg.svd(weight_2d.float(), full_matrices=False)
                if rank > len(S):
                    rank = max(8, len(S) // 2)
                U = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank]))
                Vh = torch.diag(torch.sqrt(S[:rank])) @ Vh[:rank, :]
                # Reshape back to convolutional format
                U = U.view(out_ch, kH, kW, rank).permute(0, 3, 1, 2).contiguous()
                Vh = Vh.view(rank, in_ch, 1, 1).contiguous()
                return U, Vh

            # For larger kernels, use channel-wise decomposition
            else:
                weight_2d = weight.view(out_ch, -1)
                U, S, Vh = torch.linalg.svd(weight_2d.float(), full_matrices=False)
                if rank > len(S):
                    rank = max(8, len(S) // 2)
                U = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank]))
                Vh = torch.diag(torch.sqrt(S[:rank])) @ Vh[:rank, :]
                U = U.view(out_ch, rank, 1, 1).contiguous()
                Vh = Vh.view(rank, in_ch, kH, kW).contiguous()
                return U, Vh

        # Handle 3D tensors (rare, but sometimes found in attention mechanisms)
        elif weight.ndim == 3:
            out_ch, mid_ch, in_ch = weight.shape
            weight_2d = weight.reshape(out_ch * mid_ch, in_ch)
            U, S, Vh = torch.linalg.svd(weight_2d.float(), full_matrices=False)
            if rank > len(S):
                rank = max(8, len(S) // 2)
            U = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank]))
            Vh = torch.diag(torch.sqrt(S[:rank])) @ Vh[:rank, :]
            U = U.view(out_ch, mid_ch, rank).contiguous()
            Vh = Vh.view(rank, in_ch).contiguous()
            return U, Vh

        # Unsupported dimensionality
        return None, None

    except Exception as e:
        print(f"Decomposition error for tensor with shape {original_shape}: {str(e)}")
        traceback.print_exc()
        return None, None

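
# Illustrative sanity check (added for documentation; never called by the app):
# for a 2D weight, the two factors returned by low_rank_decomposition should
# multiply back to a close approximation of the input once the rank is large
# enough. A relative Frobenius error well below 1.0 indicates a usable rank.
def _relative_decomposition_error(weight, rank=64):
    """Return ||W - U @ Vh||_F / ||W||_F for a 2D weight (lower is better)."""
    if weight.ndim != 2:
        raise ValueError("This check only covers the 2D (linear layer) case.")
    U, Vh = low_rank_decomposition(weight, rank=rank)
    if U is None or Vh is None:
        return float("nan")
    return ((weight.float() - U @ Vh).norm() / weight.float().norm()).item()
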

def should_apply_lora(key, weight, architecture, lora_rank):
    """Determine if LoRA should be applied to a specific weight based on architecture selection."""
    # Skip bias terms, batchnorm, and very small tensors
    if 'bias' in key or 'norm' in key.lower() or 'bn' in key.lower():
        return False

    # Skip very small tensors
    if weight.numel() < 100:
        return False

    # Architecture-specific rules
    lower_key = key.lower()

    if architecture == "text_encoder":
        # Text encoder: focus on embeddings and attention layers
        return ('emb' in lower_key or 'embed' in lower_key or 'attn' in lower_key
                or 'qkv' in lower_key or 'mlp' in lower_key)
    elif architecture == "unet_transformer":
        # UNet transformers: focus on attention blocks
        return ('attn' in lower_key or 'transformer' in lower_key
                or 'qkv' in lower_key or 'to_out' in lower_key)
    elif architecture == "unet_conv":
        # UNet convolutional layers
        return ('conv' in lower_key or 'resnet' in lower_key
                or 'downsample' in lower_key or 'upsample' in lower_key)
    elif architecture == "vae":
        # VAE components
        return ('encoder' in lower_key or 'decoder' in lower_key
                or 'conv' in lower_key or 'post_quant' in lower_key)
    elif architecture == "all":
        # Apply to all eligible tensors
        return True
    elif architecture == "auto":
        # Auto-detect based on tensor properties
        if weight.ndim == 2 and min(weight.shape) > lora_rank:
            return True
        if weight.ndim == 4 and (weight.shape[0] > lora_rank or weight.shape[1] > lora_rank):
            return True
        return False

    return False


def convert_safetensors_to_fp8_with_lora(safetensors_path, output_dir, fp8_format,
                                          lora_rank=64, architecture="auto",
                                          progress=gr.Progress()):
    progress(0.1, desc="Starting FP8 conversion with LoRA extraction...")

    try:
        def read_safetensors_metadata(path):
            with open(path, 'rb') as f:
                header_size = int.from_bytes(f.read(8), 'little')
                header_json = f.read(header_size).decode('utf-8')
                header = json.loads(header_json)
                return header.get('__metadata__', {})

        metadata = read_safetensors_metadata(safetensors_path)
        progress(0.2, desc="Loaded metadata.")

        state_dict = load_file(safetensors_path)
        progress(0.4, desc="Loaded weights.")

        # Architecture analysis
        architecture_stats = {
            'text_encoder': 0,
            'unet_transformer': 0,
            'unet_conv': 0,
            'vae': 0,
            'other': 0
        }
        for key in state_dict:
            lower_key = key.lower()
            if 'text' in lower_key or 'emb' in lower_key:
                architecture_stats['text_encoder'] += 1
            elif 'attn' in lower_key or 'transformer' in lower_key:
                architecture_stats['unet_transformer'] += 1
            elif 'conv' in lower_key or 'resnet' in lower_key:
                architecture_stats['unet_conv'] += 1
            elif 'vae' in lower_key or 'encoder' in lower_key or 'decoder' in lower_key:
                architecture_stats['vae'] += 1
            else:
                architecture_stats['other'] += 1

        print("Architecture analysis:")
        for arch, count in architecture_stats.items():
            print(f"- {arch}: {count} layers")
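        # Explanatory note (added comment) on the two FP8 formats exposed in the UI:
        # float8_e4m3fn keeps 4 exponent / 3 mantissa bits (more precision, max ~448,
        # no infinities), while float8_e5m2 keeps 5 exponent / 2 mantissa bits (less
        # precision but a much wider dynamic range, max ~57344). As a rule of thumb,
        # e5m2 is safer for weights with large outliers; e4m3fn preserves slightly
        # more detail when magnitudes are moderate.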
        if fp8_format == "e5m2":
            fp8_dtype = torch.float8_e5m2
        else:
            fp8_dtype = torch.float8_e4m3fn

        sd_fp8 = {}
        lora_weights = {}
        lora_stats = {
            'total_layers': len(state_dict),
            'layers_analyzed': 0,
            'layers_eligible': 0,
            'layers_processed': 0,
            'layers_skipped': [],
            'architecture_distribution': architecture_stats
        }

        total = len(state_dict)
        lora_keys = []

        for i, key in enumerate(state_dict):
            progress(0.4 + 0.4 * (i / total), desc=f"Processing {i+1}/{total}: {key.split('.')[-1]}")
            weight = state_dict[key]
            lora_stats['layers_analyzed'] += 1

            if weight.dtype in [torch.float16, torch.float32, torch.bfloat16]:
                fp8_weight = weight.to(fp8_dtype)
                sd_fp8[key] = fp8_weight

                # Determine if we should apply LoRA
                eligible_for_lora = should_apply_lora(key, weight, architecture, lora_rank)

                if eligible_for_lora:
                    lora_stats['layers_eligible'] += 1
                    try:
                        # Adjust rank based on tensor size
                        actual_rank = lora_rank
                        if weight.ndim == 2:
                            actual_rank = min(lora_rank, min(weight.shape) // 2)
                        elif weight.ndim == 4:
                            actual_rank = min(lora_rank, max(weight.shape[0], weight.shape[1]) // 4)

                        if actual_rank < 4:  # Minimum rank threshold
                            lora_stats['layers_skipped'].append(f"{key}: rank too small ({actual_rank})")
                            continue

                        # Decompose the FP8 quantization error (original minus the FP8
                        # round-trip) so that FP8 weight + LoRA recovers precision.
                        residual = weight.float() - fp8_weight.to(torch.float32)
                        U, V = low_rank_decomposition(residual, rank=actual_rank)
                        if U is not None and V is not None:
                            # Standard LoRA convention: A is the down-projection,
                            # B is the up-projection, so delta = B @ A.
                            lora_weights[f"lora_A.{key}"] = V.to(torch.float16)
                            lora_weights[f"lora_B.{key}"] = U.to(torch.float16)
                            lora_keys.append(key)
                            lora_stats['layers_processed'] += 1
                        else:
                            lora_stats['layers_skipped'].append(f"{key}: decomposition returned None")
                    except Exception as e:
                        error_msg = f"{key}: {str(e)}"
                        lora_stats['layers_skipped'].append(error_msg)
                        print(f"LoRA decomposition error: {error_msg}")
                        traceback.print_exc()
                else:
                    reason = ("not eligible for selected architecture" if architecture != "auto"
                              else f"ndim={weight.ndim}, min(shape)={min(weight.shape) if weight.ndim > 0 else 'N/A'}")
                    lora_stats['layers_skipped'].append(f"{key}: {reason}")
            else:
                sd_fp8[key] = weight
                lora_stats['layers_skipped'].append(f"{key}: unsupported dtype {weight.dtype}")

        base_name = os.path.splitext(os.path.basename(safetensors_path))[0]
        fp8_path = os.path.join(output_dir, f"{base_name}-fp8-{fp8_format}.safetensors")
        lora_path = os.path.join(output_dir, f"{base_name}-lora-r{lora_rank}-{architecture}.safetensors")

        save_file(sd_fp8, fp8_path, metadata={"format": "pt", "fp8_format": fp8_format, **metadata})

        # Always save the LoRA file, even if it ended up empty
        lora_metadata = {
            "format": "pt",
            "lora_rank": str(lora_rank),
            "architecture": architecture,
            "stats": json.dumps(lora_stats)
        }
        save_file(lora_weights, lora_path, metadata=lora_metadata)

        # Generate detailed statistics message
        stats_msg = f"""
šŸ“Š LoRA Extraction Statistics:
- Total layers analyzed: {lora_stats['layers_analyzed']}
- Layers eligible for LoRA: {lora_stats['layers_eligible']}
- Successfully processed: {lora_stats['layers_processed']}
- Architecture: {architecture}
- FP8 Format: {fp8_format.upper()}

Top skipped layers:
{chr(10).join(lora_stats['layers_skipped'][:10])}
"""

        progress(0.9, desc="Saved FP8 and LoRA files.")
        progress(1.0, desc="āœ… FP8 + LoRA extraction complete!")

        if lora_stats['layers_processed'] == 0:
            stats_msg += "\n\nāš ļø WARNING: No LoRA weights were generated. Try a different architecture selection or a lower rank."

        return True, f"FP8 ({fp8_format}) and rank-{lora_rank} LoRA saved.\n{stats_msg}", lora_stats

    except Exception as e:
        error_msg = f"Conversion error: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return False, error_msg, None


def parse_hf_url(url):
    url = url.strip().rstrip("/")
    if not url.startswith("https://huggingface.co/"):
        raise ValueError("URL must start with https://huggingface.co/")
    path = url.replace("https://huggingface.co/", "")
    parts = path.split("/")
    if len(parts) < 2:
        raise ValueError("Invalid repo format")
    repo_id = "/".join(parts[:2])
    subfolder = ""
    if len(parts) > 3 and parts[2] == "tree":
        subfolder = "/".join(parts[4:]) if len(parts) > 4 else ""
    elif len(parts) > 2:
        subfolder = "/".join(parts[2:])
    return repo_id, subfolder

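
# Illustrative only (not used by the app): expected behaviour of parse_hf_url on a
# hypothetical repo URL that points at a subfolder via a "/tree/<branch>/" path.
def _parse_hf_url_example():
    repo_id, subfolder = parse_hf_url("https://huggingface.co/user/model/tree/main/unet")
    assert repo_id == "user/model"
    assert subfolder == "unet"
    return repo_id, subfolder
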

def download_safetensors_file(source_type, repo_url, filename, hf_token=None, progress=gr.Progress()):
    temp_dir = tempfile.mkdtemp()
    try:
        if source_type == "huggingface":
            repo_id, subfolder = parse_hf_url(repo_url)
            safetensors_path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                subfolder=subfolder or None,
                cache_dir=temp_dir,
                token=hf_token,
                resume_download=True
            )
        elif source_type == "modelscope":
            if not MODELSCOPE_AVAILABLE:
                raise ImportError("ModelScope not installed")
            repo_id = repo_url.strip()
            safetensors_path = ms_file_download(model_id=repo_id, file_path=filename)
        else:
            raise ValueError("Unknown source")
        return safetensors_path, temp_dir
    except Exception as e:
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise e


def upload_to_target(target_type, new_repo_id, output_dir, fp8_format, architecture,
                     hf_token=None, modelscope_token=None, private_repo=False):
    if target_type == "huggingface":
        api = HfApi(token=hf_token)
        api.create_repo(repo_id=new_repo_id, private=private_repo, repo_type="model", exist_ok=True)
        api.upload_folder(repo_id=new_repo_id, folder_path=output_dir, repo_type="model", token=hf_token)
        return f"https://huggingface.co/{new_repo_id}"
    elif target_type == "modelscope":
        api = ModelScopeApi()
        if modelscope_token:
            api.login(modelscope_token)
        api.push_model(model_id=new_repo_id, model_dir=output_dir)
        return f"https://modelscope.cn/models/{new_repo_id}"
    else:
        raise ValueError("Unknown target")

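
# Naming convention used by the converter (illustrative helper, never called by the
# app): for "model.safetensors" with rank 64, format "e5m2" and target
# "unet_transformer" this returns
# ("model-fp8-e5m2.safetensors", "model-lora-r64-unet_transformer.safetensors").
def _output_filenames(filename, fp8_format, lora_rank, architecture):
    base = os.path.splitext(os.path.basename(filename))[0]
    return (
        f"{base}-fp8-{fp8_format}.safetensors",
        f"{base}-lora-r{lora_rank}-{architecture}.safetensors",
    )
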

def process_and_upload_fp8(
    source_type, repo_url, safetensors_filename, fp8_format, lora_rank, architecture,
    target_type, new_repo_id, hf_token, modelscope_token, private_repo,
    progress=gr.Progress()
):
    if not re.match(r"^[a-zA-Z0-9._-]+/[a-zA-Z0-9._-]+$", new_repo_id):
        return None, "āŒ Invalid repo ID format. Use 'username/model-name'.", ""
    if source_type == "huggingface" and not hf_token:
        return None, "āŒ Hugging Face token required for source.", ""
    if target_type == "huggingface" and not hf_token:
        return None, "āŒ Hugging Face token required for target.", ""

    # Validate lora_rank
    if lora_rank < 4:
        return None, "āŒ LoRA rank must be at least 4.", ""

    temp_dir = None
    output_dir = tempfile.mkdtemp()

    try:
        progress(0.05, desc="Downloading model...")
        safetensors_path, temp_dir = download_safetensors_file(
            source_type, repo_url, safetensors_filename, hf_token, progress
        )

        progress(0.25, desc=f"Converting to FP8 with LoRA ({architecture})...")
        success, msg, stats = convert_safetensors_to_fp8_with_lora(
            safetensors_path, output_dir, fp8_format, lora_rank, architecture, progress
        )
        if not success:
            return None, f"āŒ Conversion failed: {msg}", ""

        progress(0.9, desc="Uploading...")
        repo_url_final = upload_to_target(
            target_type, new_repo_id, output_dir, fp8_format, architecture,
            hf_token, modelscope_token, private_repo
        )

        base_name = os.path.splitext(safetensors_filename)[0]
        lora_filename = f"{base_name}-lora-r{lora_rank}-{architecture}.safetensors"
        fp8_filename = f"{base_name}-fp8-{fp8_format}.safetensors"

        readme = f"""---
library_name: diffusers
tags:
- fp8
- safetensors
- lora
- low-rank
- diffusion
- architecture-{architecture}
- converted-by-ai-toolkit
---

# FP8 Model with Low-Rank LoRA

- **Source**: `{repo_url}`
- **File**: `{safetensors_filename}`
- **FP8 Format**: `{fp8_format.upper()}`
- **LoRA Rank**: {lora_rank}
- **Architecture Target**: {architecture}
- **LoRA File**: `{lora_filename}`
- **FP8 File**: `{fp8_filename}`

## Architecture Distribution
"""

        # Add architecture stats to the README if available
        if stats and 'architecture_distribution' in stats:
            readme += "\n| Component | Layer Count |\n|-----------|------------|\n"
            for arch, count in stats['architecture_distribution'].items():
                readme += f"| {arch.replace('_', ' ').title()} | {count} |\n"

        readme += f"""
## Usage (Inference)

```python
from safetensors.torch import load_file
import torch

# Load the FP8 weights and the low-rank correction factors
fp8_state = load_file("{fp8_filename}")
lora_state = load_file("{lora_filename}")

# Reconstruct approximate original weights: W ā‰ˆ W_fp8 + B @ A
reconstructed = {{}}
for key in fp8_state:
    lora_a_key = f"lora_A.{{key}}"
    lora_b_key = f"lora_B.{{key}}"
    base = fp8_state[key].to(torch.float32)
    if lora_a_key in lora_state and lora_b_key in lora_state:
        A = lora_state[lora_a_key].to(torch.float32)  # down-projection
        B = lora_state[lora_b_key].to(torch.float32)  # up-projection
        if A.ndim == 2 and B.ndim == 2:
            delta = B @ A
        elif B.ndim == 4 and A.shape[-2:] == (1, 1):
            # Conv LoRA (small kernels): spatial dims are stored in B
            out_ch, r, kh, kw = B.shape
            delta = B.permute(0, 2, 3, 1).reshape(-1, r) @ A.reshape(r, -1)
            delta = delta.view(out_ch, kh, kw, -1).permute(0, 3, 1, 2)
        elif B.ndim == 4 and A.ndim == 4:
            # Conv LoRA (large kernels): spatial dims are stored in A
            out_ch, r = B.shape[0], B.shape[1]
            delta = (B.reshape(out_ch, r) @ A.reshape(r, -1)).view(out_ch, *A.shape[1:])
        else:
            delta = None  # other layouts are left unpatched in this example
        if delta is not None:
            reconstructed[key] = base + delta.reshape(base.shape)
        else:
            reconstructed[key] = base
    else:
        reconstructed[key] = base
```

> **Note**: Requires PyTorch ≄ 2.1 for FP8 support. For best results, use the same architecture selection ({architecture}) during inference as was used during extraction.
"""

        with open(os.path.join(output_dir, "README.md"), "w") as f:
            f.write(readme)

        if target_type == "huggingface":
            HfApi(token=hf_token).upload_file(
                path_or_fileobj=os.path.join(output_dir, "README.md"),
                path_in_repo="README.md",
                repo_id=new_repo_id,
                repo_type="model",
                token=hf_token
            )

        progress(1.0, desc="āœ… Done!")

        result_html = f"""
<p>āœ… Success! Model uploaded to: <a href="{repo_url_final}" target="_blank">{new_repo_id}</a></p>
<p>Includes:</p>
<ul>
  <li>FP8 model: <code>{fp8_filename}</code></li>
  <li>LoRA weights: <code>{lora_filename}</code> (rank {lora_rank}, architecture: {architecture})</li>
</ul>
<p>šŸ“Š Stats: {stats['layers_processed']}/{stats['layers_eligible']} eligible layers processed</p>
"""
        return result_html, "āœ… FP8 + LoRA upload successful!", msg

    except Exception as e:
        error_msg = f"āŒ Error: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return None, error_msg, ""
    finally:
        if temp_dir:
            shutil.rmtree(temp_dir, ignore_errors=True)
        shutil.rmtree(output_dir, ignore_errors=True)

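
# Minimal sketch of running the converter without the web UI (illustrative only,
# never called by the app). The input path and output directory are placeholders;
# a no-op progress callback replaces the Gradio one so this can run headless.
def _convert_locally(path="model.safetensors", out_dir="fp8_out"):
    os.makedirs(out_dir, exist_ok=True)
    ok, message, stats = convert_safetensors_to_fp8_with_lora(
        path, out_dir, fp8_format="e5m2", lora_rank=64, architecture="auto",
        progress=lambda *args, **kwargs: None
    )
    print(message)
    return ok, stats
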

with gr.Blocks(title="FP8 + LoRA Extractor (HF ↔ ModelScope)") as demo:
    gr.Markdown("# šŸ”„ Advanced FP8 Pruner with Architecture-Specific LoRA Extraction")
    gr.Markdown("Convert `.safetensors` → **FP8** + **targeted LoRA** weights for precision recovery. Supports Hugging Face ↔ ModelScope.")

    with gr.Row():
        with gr.Column():
            source_type = gr.Radio(["huggingface", "modelscope"], value="huggingface", label="Source")
            repo_url = gr.Textbox(label="Repo URL or ID", placeholder="https://huggingface.co/... or modelscope-id")
            safetensors_filename = gr.Textbox(label="Filename", placeholder="model.safetensors")

            with gr.Accordion("Advanced LoRA Settings", open=True):
                fp8_format = gr.Radio(["e4m3fn", "e5m2"], value="e5m2", label="FP8 Format")
                lora_rank = gr.Slider(minimum=4, maximum=256, step=4, value=64, label="LoRA Rank")
                architecture = gr.Dropdown(
                    choices=[
                        ("Auto-detect components", "auto"),
                        ("Text Encoder (embeddings, attention)", "text_encoder"),
                        ("UNet Transformers (attention blocks)", "unet_transformer"),
                        ("UNet Convolutions (resnets, downsampling)", "unet_conv"),
                        ("VAE (encoder/decoder)", "vae"),
                        ("All components", "all")
                    ],
                    value="auto",
                    label="Target Architecture",
                    info="Select which model components to apply LoRA to"
                )

            with gr.Accordion("Authentication", open=False):
                hf_token = gr.Textbox(label="Hugging Face Token", type="password")
                modelscope_token = gr.Textbox(label="ModelScope Token (optional)", type="password", visible=MODELSCOPE_AVAILABLE)

        with gr.Column():
            target_type = gr.Radio(["huggingface", "modelscope"], value="huggingface", label="Target")
            new_repo_id = gr.Textbox(label="New Repo ID", placeholder="user/model-fp8-lora")
            private_repo = gr.Checkbox(label="Private Repository (HF only)", value=False)
            status_output = gr.Markdown()
            detailed_log = gr.Textbox(label="Processing Log", interactive=False, lines=10)
            convert_btn = gr.Button("šŸš€ Convert & Upload", variant="primary")
            repo_link_output = gr.HTML()

    convert_btn.click(
        fn=process_and_upload_fp8,
        inputs=[
            source_type, repo_url, safetensors_filename, fp8_format, lora_rank, architecture,
            target_type, new_repo_id, hf_token, modelscope_token, private_repo
        ],
        outputs=[repo_link_output, status_output, detailed_log],
        show_progress=True
    )

    gr.Examples(
        examples=[
            ["huggingface", "https://huggingface.co/Yabo/FramePainter/tree/main", "unet_diffusion_pytorch_model.safetensors", "e5m2", 64, "unet_transformer"],
            ["huggingface", "https://huggingface.co/stabilityai/sdxl-vae", "diffusion_pytorch_model.safetensors", "e4m3fn", 32, "vae"],
            ["huggingface", "https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/text_encoder", "model.safetensors", "e5m2", 48, "text_encoder"]
        ],
        inputs=[source_type, repo_url, safetensors_filename, fp8_format, lora_rank, architecture],
        label="Example Conversions"
    )

    gr.Markdown("""
## šŸ’” Usage Tips

- **For Text Encoders**: Use rank 32-64 with `text_encoder` architecture for optimal results.
- **For UNet Attention**: Use `unet_transformer` with rank 64-128 for best quality preservation.
- **For UNet Convolutions**: Use `unet_conv` with lower ranks (16-32) as convolutions compress better.
- **For VAE**: Use `vae` architecture with rank 16-32.
- **Auto Mode**: Let the tool analyze and target appropriate layers automatically.

āš ļø **Note**: Higher ranks produce better quality but larger LoRA files. Start with lower ranks and increase if needed.
""")

demo.launch()