import gradio as gr import os import tempfile import shutil import re import json import datetime from pathlib import Path from huggingface_hub import HfApi, hf_hub_download from safetensors.torch import load_file, save_file import torch import torch.nn.functional as F import traceback import math try: from modelscope.hub.file_download import model_file_download as ms_file_download from modelscope.hub.api import HubApi as ModelScopeApi MODELScope_AVAILABLE = True except ImportError: MODELScope_AVAILABLE = False def low_rank_decomposition(weight, rank=64, approximation_factor=0.8): """Low-rank decomposition with controlled approximation error.""" original_shape = weight.shape original_dtype = weight.dtype try: # Handle 2D tensors (linear layers, attention) if weight.ndim == 2: # Compute SVD U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False) # Calculate how much variance we want to keep total_variance = torch.sum(S ** 2) cumulative_variance = torch.cumsum(S ** 2, dim=0) # Find minimal rank that preserves approximation_factor of variance minimal_rank = torch.searchsorted(cumulative_variance, approximation_factor * total_variance).item() + 1 # Use the smaller of: requested rank or minimal rank for approximation_factor actual_rank = min(rank, len(S)) # If actual_rank is too close to full rank, reduce it to create meaningful approximation if actual_rank > len(S) * 0.8: # If using more than 80% of full rank actual_rank = max(min(rank // 2, len(S) // 2), 8) # Use half the requested rank # Ensure we're actually approximating, not just reparameterizing if actual_rank >= min(weight.shape): # Force approximation by using lower rank actual_rank = max(min(weight.shape) // 4, 8) U_k = U[:, :actual_rank] @ torch.diag(torch.sqrt(S[:actual_rank])) Vh_k = torch.diag(torch.sqrt(S[:actual_rank])) @ Vh[:actual_rank, :] return U_k.contiguous(), Vh_k.contiguous() # Handle 4D tensors (convolutional layers) elif weight.ndim == 4: out_ch, in_ch, kH, kW = weight.shape # Reshape to 2D for SVD weight_2d = weight.view(out_ch, in_ch * kH * kW) # Compute SVD on flattened version U, S, Vh = torch.linalg.svd(weight_2d.float(), full_matrices=False) # Calculate appropriate rank total_variance = torch.sum(S ** 2) cumulative_variance = torch.cumsum(S ** 2, dim=0) minimal_rank = torch.searchsorted(cumulative_variance, approximation_factor * total_variance).item() + 1 # Adjust rank for convolutions - typically need lower ranks conv_rank = min(rank // 2, len(S)) if conv_rank > len(S) * 0.7: conv_rank = max(len(S) // 4, 8) actual_rank = max(min(conv_rank, minimal_rank), 8) # Decompose U_k = U[:, :actual_rank] @ torch.diag(torch.sqrt(S[:actual_rank])) Vh_k = torch.diag(torch.sqrt(S[:actual_rank])) @ Vh[:actual_rank, :] # Reshape back to convolutional format if kH == 1 and kW == 1: # 1x1 convolutions U_k = U_k.view(out_ch, actual_rank, 1, 1) Vh_k = Vh_k.view(actual_rank, in_ch, 1, 1) else: # For larger kernels, use spatial decomposition U_k = U_k.view(out_ch, actual_rank, 1, 1) Vh_k = Vh_k.view(actual_rank, in_ch, kH, kW) return U_k.contiguous(), Vh_k.contiguous() # Handle 1D tensors (biases, embeddings) elif weight.ndim == 1: # Don't decompose 1D tensors return None, None except Exception as e: print(f"Decomposition error for tensor with shape {original_shape}: {str(e)[:100]}") return None, None def get_architecture_specific_settings(architecture, base_rank): """Get optimal settings for different model architectures.""" settings = { "text_encoder": { "rank": base_rank, "approximation_factor": 0.95, # Text encoders need high accuracy "min_rank": 8, "max_rank_factor": 0.5 # Use at most 50% of full rank }, "unet_transformer": { "rank": base_rank, "approximation_factor": 0.90, "min_rank": 16, "max_rank_factor": 0.4 }, "unet_conv": { "rank": base_rank // 2, # Convolutions compress better "approximation_factor": 0.85, "min_rank": 8, "max_rank_factor": 0.3 }, "vae": { "rank": base_rank // 3, # VAE compresses very well "approximation_factor": 0.80, "min_rank": 4, "max_rank_factor": 0.25 }, "auto": { "rank": base_rank, "approximation_factor": 0.90, "min_rank": 8, "max_rank_factor": 0.5 }, "all": { "rank": base_rank, "approximation_factor": 0.90, "min_rank": 8, "max_rank_factor": 0.5 } } return settings.get(architecture, settings["auto"]) def should_apply_lora(key, weight, architecture, lora_rank): """Determine if LoRA should be applied to a specific weight based on architecture selection.""" # Skip bias terms, batchnorm, and very small tensors if 'bias' in key or 'norm' in key.lower() or 'bn' in key.lower(): return False # Skip very small tensors if weight.numel() < 100: return False # Skip 1D tensors if weight.ndim == 1: return False # Architecture-specific rules lower_key = key.lower() if architecture == "text_encoder": # Text encoder: focus on embeddings and attention layers return ('emb' in lower_key or 'embed' in lower_key or 'attn' in lower_key or 'qkv' in lower_key or 'mlp' in lower_key) elif architecture == "unet_transformer": # UNet transformers: focus on attention blocks return ('attn' in lower_key or 'transformer' in lower_key or 'qkv' in lower_key or 'to_out' in lower_key) elif architecture == "unet_conv": # UNet convolutional layers return ('conv' in lower_key or 'resnet' in lower_key or 'downsample' in lower_key or 'upsample' in lower_key) elif architecture == "vae": # VAE components return ('encoder' in lower_key or 'decoder' in lower_key or 'conv' in lower_key or 'post_quant' in lower_key) elif architecture == "all": # Apply to all eligible tensors return True elif architecture == "auto": # Auto-detect based on tensor properties if weight.ndim == 2 and min(weight.shape) > lora_rank // 4: return True if weight.ndim == 4 and (weight.shape[0] > lora_rank // 4 or weight.shape[1] > lora_rank // 4): return True return False return False def convert_safetensors_to_fp8_with_lora(safetensors_path, output_dir, fp8_format, lora_rank=64, architecture="auto", progress=gr.Progress()): progress(0.1, desc="Starting FP8 conversion with LoRA extraction...") try: def read_safetensors_metadata(path): with open(path, 'rb') as f: header_size = int.from_bytes(f.read(8), 'little') header_json = f.read(header_size).decode('utf-8') header = json.loads(header_json) return header.get('__metadata__', {}) metadata = read_safetensors_metadata(safetensors_path) progress(0.2, desc="Loaded metadata.") state_dict = load_file(safetensors_path) progress(0.4, desc="Loaded weights.") # Architecture analysis architecture_stats = { 'text_encoder': 0, 'unet_transformer': 0, 'unet_conv': 0, 'vae': 0, 'other': 0 } for key in state_dict: lower_key = key.lower() if 'text' in lower_key or 'emb' in lower_key: architecture_stats['text_encoder'] += 1 elif 'attn' in lower_key or 'transformer' in lower_key: architecture_stats['unet_transformer'] += 1 elif 'conv' in lower_key or 'resnet' in lower_key: architecture_stats['unet_conv'] += 1 elif 'vae' in lower_key or 'encoder' in lower_key or 'decoder' in lower_key: architecture_stats['vae'] += 1 else: architecture_stats['other'] += 1 print("Architecture analysis:") for arch, count in architecture_stats.items(): print(f"- {arch}: {count} layers") if fp8_format == "e5m2": fp8_dtype = torch.float8_e5m2 else: fp8_dtype = torch.float8_e4m3fn sd_fp8 = {} lora_weights = {} lora_stats = { 'total_layers': len(state_dict), 'layers_analyzed': 0, 'layers_eligible': 0, 'layers_processed': 0, 'layers_skipped': [], 'architecture_distro': architecture_stats, 'reconstruction_errors': [] } total = len(state_dict) lora_keys = [] for i, key in enumerate(state_dict): progress(0.4 + 0.4 * (i / total), desc=f"Processing {i+1}/{total}: {key.split('.')[-1]}") weight = state_dict[key] lora_stats['layers_analyzed'] += 1 if weight.dtype in [torch.float16, torch.float32, torch.bfloat16]: fp8_weight = weight.to(fp8_dtype) sd_fp8[key] = fp8_weight # Determine if we should apply LoRA eligible_for_lora = should_apply_lora(key, weight, architecture, lora_rank) if eligible_for_lora: lora_stats['layers_eligible'] += 1 try: # Get architecture-specific settings arch_settings = get_architecture_specific_settings(architecture, lora_rank) # Adjust rank based on tensor properties if weight.ndim == 2: max_possible_rank = min(weight.shape) actual_rank = min( arch_settings["rank"], int(max_possible_rank * arch_settings["max_rank_factor"]) ) actual_rank = max(actual_rank, arch_settings["min_rank"]) elif weight.ndim == 4: # For conv layers, use smaller rank actual_rank = min( arch_settings["rank"], max(weight.shape[0], weight.shape[1]) // 4 ) actual_rank = max(actual_rank, arch_settings["min_rank"]) else: # Skip non-2D/4D tensors for LoRA lora_stats['layers_skipped'].append(f"{key}: unsupported ndim={weight.ndim}") continue if actual_rank < 4: lora_stats['layers_skipped'].append(f"{key}: rank too small ({actual_rank})") continue # Perform decomposition with approximation U, V = low_rank_decomposition( weight, rank=actual_rank, approximation_factor=arch_settings["approximation_factor"] ) if U is not None and V is not None: # Store as half-precision lora_weights[f"lora_A.{key}"] = U.to(torch.float16) lora_weights[f"lora_B.{key}"] = V.to(torch.float16) lora_keys.append(key) lora_stats['layers_processed'] += 1 # Calculate and store reconstruction error if U.ndim == 2 and V.ndim == 2: if V.shape[0] == U.shape[1]: reconstructed = V @ U else: reconstructed = U @ V error = torch.norm(weight.float() - reconstructed.float()) / torch.norm(weight.float()) lora_stats['reconstruction_errors'].append({ 'key': key, 'error': error.item(), 'original_shape': list(weight.shape), 'rank': actual_rank }) else: lora_stats['layers_skipped'].append(f"{key}: decomposition returned None") except Exception as e: error_msg = f"{key}: {str(e)[:100]}" lora_stats['layers_skipped'].append(error_msg) else: reason = "not eligible for selected architecture" if architecture != "auto" else f"ndim={weight.ndim}" lora_stats['layers_skipped'].append(f"{key}: {reason}") else: sd_fp8[key] = weight lora_stats['layers_skipped'].append(f"{key}: unsupported dtype {weight.dtype}") # Add reconstruction error statistics if lora_stats['reconstruction_errors']: errors = [e['error'] for e in lora_stats['reconstruction_errors']] lora_stats['avg_reconstruction_error'] = sum(errors) / len(errors) if errors else 0 lora_stats['max_reconstruction_error'] = max(errors) if errors else 0 lora_stats['min_reconstruction_error'] = min(errors) if errors else 0 base_name = os.path.splitext(os.path.basename(safetensors_path))[0] fp8_path = os.path.join(output_dir, f"{base_name}-fp8-{fp8_format}.safetensors") lora_path = os.path.join(output_dir, f"{base_name}-lora-r{lora_rank}-{architecture}.safetensors") save_file(sd_fp8, fp8_path, metadata={"format": "pt", "fp8_format": fp8_format, **metadata}) # Always save LoRA file, even if empty lora_metadata = { "format": "pt", "lora_rank": str(lora_rank), "architecture": architecture, "original_filename": os.path.basename(safetensors_path), "fp8_format": fp8_format, "stats": json.dumps(lora_stats) } save_file(lora_weights, lora_path, metadata=lora_metadata) # Generate detailed statistics message stats_msg = f""" šŸ“Š LoRA Extraction Statistics: - Total layers analyzed: {lora_stats['layers_analyzed']} - Layers eligible for LoRA: {lora_stats['layers_eligible']} - Successfully processed: {lora_stats['layers_processed']} - Architecture: {architecture} - FP8 Format: {fp8_format.upper()} """ if 'avg_reconstruction_error' in lora_stats: stats_msg += f"- Avg reconstruction error: {lora_stats['avg_reconstruction_error']:.6f}\n" stats_msg += f"- Max reconstruction error: {lora_stats['max_reconstruction_error']:.6f}\n" progress(0.9, desc="Saved FP8 and LoRA files.") progress(1.0, desc="āœ… FP8 + LoRA extraction complete!") if lora_stats['layers_processed'] == 0: stats_msg += "\n\nāš ļø WARNING: No LoRA weights were generated. Try a different architecture selection or lower rank." elif lora_stats.get('avg_reconstruction_error', 1) < 0.0001: stats_msg += "\n\nā„¹ļø NOTE: Very low reconstruction error detected. LoRA may be reconstructing almost perfectly. Consider using lower rank for better compression." return True, f"FP8 ({fp8_format}) and rank-{lora_rank} LoRA saved.\n{stats_msg}", lora_stats except Exception as e: error_msg = f"Conversion error: {str(e)}\n{traceback.format_exc()}" print(error_msg) return False, error_msg, None def parse_hf_url(url): url = url.strip().rstrip("/") if not url.startswith("https://huggingface.co/"): raise ValueError("URL must start with https://huggingface.co/") path = url.replace("https://huggingface.co/", "") parts = path.split("/") if len(parts) < 2: raise ValueError("Invalid repo format") repo_id = "/".join(parts[:2]) subfolder = "" if len(parts) > 3 and parts[2] == "tree": subfolder = "/".join(parts[4:]) if len(parts) > 4 else "" elif len(parts) > 2: subfolder = "/".join(parts[2:]) return repo_id, subfolder def download_safetensors_file(source_type, repo_url, filename, hf_token=None, progress=gr.Progress()): temp_dir = tempfile.mkdtemp() try: if source_type == "huggingface": repo_id, subfolder = parse_hf_url(repo_url) safetensors_path = hf_hub_download( repo_id=repo_id, filename=filename, subfolder=subfolder or None, cache_dir=temp_dir, token=hf_token, resume_download=True ) elif source_type == "modelscope": if not MODELScope_AVAILABLE: raise ImportError("ModelScope not installed") repo_id = repo_url.strip() safetensors_path = ms_file_download(model_id=repo_id, file_path=filename) else: raise ValueError("Unknown source") return safetensors_path, temp_dir except Exception as e: shutil.rmtree(temp_dir, ignore_errors=True) raise e def upload_to_target(target_type, new_repo_id, output_dir, fp8_format, architecture, hf_token=None, modelscope_token=None, private_repo=False): if target_type == "huggingface": api = HfApi(token=hf_token) api.create_repo(repo_id=new_repo_id, private=private_repo, repo_type="model", exist_ok=True) api.upload_folder(repo_id=new_repo_id, folder_path=output_dir, repo_type="model", token=hf_token) return f"https://huggingface.co/{new_repo_id}" elif target_type == "modelscope": api = ModelScopeApi() if modelscope_token: api.login(modelscope_token) api.push_model(model_id=new_repo_id, model_dir=output_dir) return f"https://modelscope.cn/models/{new_repo_id}" else: raise ValueError("Unknown target") def process_and_upload_fp8( source_type, repo_url, safetensors_filename, fp8_format, lora_rank, architecture, target_type, new_repo_id, hf_token, modelscope_token, private_repo, progress=gr.Progress() ): if not re.match(r"^[a-zA-Z0-9._-]+/[a-zA-Z0-9._-]+$", new_repo_id): return None, "āŒ Invalid repo ID format. Use 'username/model-name'.", "" if source_type == "huggingface" and not hf_token: return None, "āŒ Hugging Face token required for source.", "" if target_type == "huggingface" and not hf_token: return None, "āŒ Hugging Face token required for target.", "" # Validate lora_rank if lora_rank < 4: return None, "āŒ LoRA rank must be at least 4.", "" temp_dir = None output_dir = tempfile.mkdtemp() try: progress(0.05, desc="Downloading model...") safetensors_path, temp_dir = download_safetensors_file( source_type, repo_url, safetensors_filename, hf_token, progress ) progress(0.25, desc=f"Converting to FP8 with LoRA ({architecture})...") success, msg, stats = convert_safetensors_to_fp8_with_lora( safetensors_path, output_dir, fp8_format, lora_rank, architecture, progress ) if not success: return None, f"āŒ Conversion failed: {msg}", "" progress(0.9, desc="Uploading...") repo_url_final = upload_to_target( target_type, new_repo_id, output_dir, fp8_format, architecture, hf_token, modelscope_token, private_repo ) base_name = os.path.splitext(safetensors_filename)[0] lora_filename = f"{base_name}-lora-r{lora_rank}-{architecture}.safetensors" fp8_filename = f"{base_name}-fp8-{fp8_format}.safetensors" readme = f"""--- library_name: diffusers tags: - fp8 - safetensors - lora - low-rank - diffusion - architecture-{architecture} - converted-by-ai-toolkit --- # FP8 Model with Low-Rank LoRA - **Source**: `{repo_url}` - **File**: `{safetensors_filename}` - **FP8 Format**: `{fp8_format.upper()}` - **LoRA Rank**: {lora_rank} - **Architecture Target**: {architecture} - **LoRA File**: `{lora_filename}` - **FP8 File**: `{fp8_filename}` ## Architecture Distribution """ # Add architecture stats to README if available if stats and 'architecture_distro' in stats: readme += "\n| Component | Layer Count |\n|-----------|------------|\n" for arch, count in stats['architecture_distro'].items(): readme += f"| {arch.replace('_', ' ').title()} | {count} |\n" readme += f""" ## Usage (Inference) ```python from safetensors.torch import load_file import torch # Load FP8 model fp8_state = load_file("{fp8_filename}") lora_state = load_file("{lora_filename}") # Reconstruct approximate original weights reconstructed = {{}} for key in fp8_state: lora_a_key = f"lora_A.{{key}}" lora_b_key = f"lora_B.{{key}}" if lora_a_key in lora_state and lora_b_key in lora_state: A = lora_state[lora_a_key].to(torch.float32) B = lora_state[lora_b_key].to(torch.float32) # Handle different tensor dimensions if A.ndim == 2 and B.ndim == 2: lora_weight = B @ A elif A.ndim == 4 and B.ndim == 4: # For convolutional LoRA lora_weight = F.conv2d(fp8_state[key].to(torch.float32), B, padding=1) + F.conv2d(fp8_state[key].to(torch.float32), A, padding=1) else: # Fallback for mixed dimension cases lora_weight = B @ A.view(B.shape[1], -1) if lora_weight.shape != fp8_state[key].shape: lora_weight = lora_weight.view_as(fp8_state[key]) reconstructed[key] = fp8_state[key].to(torch.float32) + lora_weight else: reconstructed[key] = fp8_state[key].to(torch.float32) ``` > **Note**: Requires PyTorch ≄ 2.1 for FP8 support. For best results, use the same architecture selection ({architecture}) during inference as was used during extraction. """ with open(os.path.join(output_dir, "README.md"), "w") as f: f.write(readme) if target_type == "huggingface": HfApi(token=hf_token).upload_file( path_or_fileobj=os.path.join(output_dir, "README.md"), path_in_repo="README.md", repo_id=new_repo_id, repo_type="model", token=hf_token ) progress(1.0, desc="āœ… Done!") result_html = f""" āœ… Success! Model uploaded to: {new_repo_id} Includes: - FP8 model: `{fp8_filename}` - LoRA weights: `{lora_filename}` (rank {lora_rank}, architecture: {architecture}) šŸ“Š Stats: {stats['layers_processed']}/{stats['layers_eligible']} eligible layers processed """ if 'avg_reconstruction_error' in stats: result_html += f"
Avg reconstruction error: {stats['avg_reconstruction_error']:.6f}" return gr.HTML(result_html), "āœ… FP8 + LoRA upload successful!", msg except Exception as e: error_msg = f"āŒ Error: {str(e)}\n{traceback.format_exc()}" print(error_msg) return None, error_msg, "" finally: if temp_dir: shutil.rmtree(temp_dir, ignore_errors=True) shutil.rmtree(output_dir, ignore_errors=True) with gr.Blocks(title="FP8 + LoRA Extractor (HF ↔ ModelScope)") as demo: gr.Markdown("# šŸ”„ Advanced FP8 Pruner with Architecture-Specific LoRA Extraction") gr.Markdown("Convert `.safetensors` → **FP8** + **targeted LoRA** weights for precision recovery. Supports Hugging Face ↔ ModelScope.") with gr.Row(): with gr.Column(): source_type = gr.Radio(["huggingface", "modelscope"], value="huggingface", label="Source") repo_url = gr.Textbox(label="Repo URL or ID", placeholder="https://huggingface.co/... or modelscope-id") safetensors_filename = gr.Textbox(label="Filename", placeholder="model.safetensors") with gr.Accordion("Advanced LoRA Settings", open=True): fp8_format = gr.Radio(["e4m3fn", "e5m2"], value="e5m2", label="FP8 Format") lora_rank = gr.Slider(minimum=4, maximum=256, step=4, value=64, label="LoRA Rank") architecture = gr.Dropdown( choices=[ ("Auto-detect components", "auto"), ("Text Encoder (embeddings, attention)", "text_encoder"), ("UNet Transformers (attention blocks)", "unet_transformer"), ("UNet Convolutions (resnets, downsampling)", "unet_conv"), ("VAE (encoder/decoder)", "vae"), ("All components", "all") ], value="auto", label="Target Architecture", info="Select which model components to apply LoRA to" ) with gr.Accordion("Authentication", open=False): hf_token = gr.Textbox(label="Hugging Face Token", type="password") modelscope_token = gr.Textbox(label="ModelScope Token (optional)", type="password", visible=MODELScope_AVAILABLE) with gr.Column(): target_type = gr.Radio(["huggingface", "modelscope"], value="huggingface", label="Target") new_repo_id = gr.Textbox(label="New Repo ID", placeholder="user/model-fp8-lora") private_repo = gr.Checkbox(label="Private Repository (HF only)", value=False) status_output = gr.Markdown() detailed_log = gr.Textbox(label="Processing Log", interactive=False, lines=10) convert_btn = gr.Button("šŸš€ Convert & Upload", variant="primary") repo_link_output = gr.HTML() convert_btn.click( fn=process_and_upload_fp8, inputs=[ source_type, repo_url, safetensors_filename, fp8_format, lora_rank, architecture, target_type, new_repo_id, hf_token, modelscope_token, private_repo ], outputs=[repo_link_output, status_output, detailed_log], show_progress=True ) gr.Examples( examples=[ ["huggingface", "https://huggingface.co/Yabo/FramePainter/tree/main", "unet_diffusion_pytorch_model.safetensors", "e5m2", 64, "unet_transformer"], ["huggingface", "https://huggingface.co/stabilityai/sdxl-vae", "diffusion_pytorch_model.safetensors", "e4m3fn", 32, "vae"], ["huggingface", "https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/text_encoder", "model.safetensors", "e5m2", 48, "text_encoder"] ], inputs=[source_type, repo_url, safetensors_filename, fp8_format, lora_rank, architecture], label="Example Conversions" ) gr.Markdown(""" ## šŸ’” Usage Tips - **For Text Encoders**: Use rank 32-64 with `text_encoder` architecture for optimal results. - **For UNet Attention**: Use `unet_transformer` with rank 64-128 for best quality preservation. - **For UNet Convolutions**: Use `unet_conv` with lower ranks (16-32) as convolutions compress better. - **For VAE**: Use `vae` architecture with rank 16-32. - **Auto Mode**: Let the tool analyze and target appropriate layers automatically. āš ļø **Note**: Higher ranks produce better quality but larger LoRA files. Start with lower ranks and increase if needed. """) demo.launch()