Update app.py
app.py
CHANGED
@@ -9,6 +9,7 @@ from pathlib import Path
 from huggingface_hub import HfApi, hf_hub_download
 from safetensors.torch import load_file, save_file
 import torch
+import torch.nn.functional as F

 try:
     from modelscope.hub.file_download import model_file_download as ms_file_download
@@ -17,8 +18,16 @@ try:
 except ImportError:
     MODELScope_AVAILABLE = False

-def convert_safetensors_to_fp8_with_delta(safetensors_path, output_dir, fp8_format, progress=gr.Progress()):
-    progress(0.1, desc="Starting FP8 conversion with delta extraction...")
+def low_rank_decomposition(weight, rank=64):
+    if weight.ndim != 2:
+        return None, None  # keep the tuple shape that call sites unpack
+    U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
+    U = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank]))
+    Vh = torch.diag(torch.sqrt(S[:rank])) @ Vh[:rank, :]
+    return U.contiguous(), Vh.contiguous()
+
+def convert_safetensors_to_fp8_with_lora(safetensors_path, output_dir, fp8_format, lora_rank=64, progress=gr.Progress()):
+    progress(0.1, desc="Starting FP8 conversion with LoRA extraction...")

     try:
         def read_safetensors_metadata(path):
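The new `low_rank_decomposition` splits a matrix by truncated SVD, folding `sqrt(S)` into both factors so neither side carries the full singular values. Here is a minimal sanity check of that factorization, written against a standalone copy of the function; the tensor shapes are illustrative, not from the app:

```python
import torch

def low_rank_decomposition(weight, rank=64):
    # Same balanced-SVD split as in app.py
    if weight.ndim != 2:
        return None, None
    U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
    U = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank]))
    Vh = torch.diag(torch.sqrt(S[:rank])) @ Vh[:rank, :]
    return U.contiguous(), Vh.contiguous()

w = torch.randn(512, 256)
A, B = low_rank_decomposition(w, rank=64)  # A: (512, 64), B: (64, 256)
approx = A @ B                             # rank-64 approximation of w
print((w - approx).norm() / w.norm())      # relative error left in the discarded tail
```

By Eckart-Young, `A @ B` is the best rank-64 approximation of `w` in the Frobenius norm, which is why decomposing the small quantization residual (as the loop below does) preserves more than decomposing the raw weight.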
@@ -40,31 +49,41 @@ def convert_safetensors_to_fp8_with_delta(safetensors_path, output_dir, fp8_format, progress=gr.Progress()):
         fp8_dtype = torch.float8_e4m3fn

         sd_fp8 = {}
-        sd_delta = {}
+        lora_weights = {}
         total = len(state_dict)
+        lora_keys = []

         for i, key in enumerate(state_dict):
             progress(0.4 + 0.4 * (i / total), desc=f"Processing {i+1}/{total}...")
             weight = state_dict[key]
             if weight.dtype in [torch.float16, torch.float32, torch.bfloat16]:
                 fp8_weight = weight.to(fp8_dtype)
-                fp8_recon = fp8_weight.to(weight.dtype)
-                delta = weight - fp8_recon
                 sd_fp8[key] = fp8_weight
-                sd_delta[key] = delta
+
+                # Attempt LoRA decomposition of the FP8 quantization residual (2D tensors only)
+                if weight.ndim == 2 and min(weight.shape) > lora_rank:
+                    try:
+                        U, V = low_rank_decomposition(weight - fp8_weight.to(weight.dtype), rank=lora_rank)
+                        if U is not None and V is not None:
+                            lora_weights[f"lora_A.{key}"] = U.to(torch.float16)
+                            lora_weights[f"lora_B.{key}"] = V.to(torch.float16)
+                            lora_keys.append(key)
+                    except Exception:
+                        pass
             else:
                 sd_fp8[key] = weight

         base_name = os.path.splitext(os.path.basename(safetensors_path))[0]
         fp8_path = os.path.join(output_dir, f"{base_name}-fp8-{fp8_format}.safetensors")
-        delta_path = os.path.join(output_dir, f"{base_name}-delta.safetensors")
+        lora_path = os.path.join(output_dir, f"{base_name}-lora-r{lora_rank}.safetensors")

         save_file(sd_fp8, fp8_path, metadata={"format": "pt", "fp8_format": fp8_format, **metadata})
-        save_file(sd_delta, delta_path, metadata={"format": "pt", **metadata})
+        if lora_weights:
+            save_file(lora_weights, lora_path, metadata={"format": "pt", "lora_rank": str(lora_rank)})

-        progress(0.9, desc="Saved FP8 and delta files.")
-        progress(1.0, desc="✅ FP8 + delta extraction complete!")
-        return True, f"FP8 ({fp8_format}) and delta saved."
+        progress(0.9, desc="Saved FP8 and LoRA files.")
+        progress(1.0, desc="✅ FP8 + LoRA extraction complete!")
+        return True, f"FP8 ({fp8_format}) and rank-{lora_rank} LoRA saved."

     except Exception as e:
         return False, str(e)
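The loop now writes two artifacts per input file: the FP8 state dict and the fp16 factor pairs. A quick sketch for inspecting them after a run, assuming hypothetical filenames that follow the app's naming scheme:

```python
import os
from safetensors.torch import load_file

# Hypothetical outputs following the naming scheme above
fp8_path = "model-fp8-e5m2.safetensors"
lora_path = "model-lora-r64.safetensors"

fp8_state = load_file(fp8_path)    # FP8 tensors (needs a recent safetensors/torch)
lora_state = load_file(lora_path)  # fp16 lora_A.* / lora_B.* factor pairs

print(f"FP8:  {os.path.getsize(fp8_path) / 1e6:.1f} MB, {len(fp8_state)} tensors")
print(f"LoRA: {os.path.getsize(lora_path) / 1e6:.1f} MB, {len(lora_state) // 2} factored layers")
```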
@@ -130,6 +149,7 @@ def process_and_upload_fp8(
     repo_url,
     safetensors_filename,
     fp8_format,
+    lora_rank,
     target_type,
     new_repo_id,
     hf_token,
@@ -154,8 +174,10 @@ def process_and_upload_fp8(
             source_type, repo_url, safetensors_filename, hf_token, progress
         )

-        progress(0.25, desc="Converting to FP8 with delta extraction...")
-        success, msg = convert_safetensors_to_fp8_with_delta(safetensors_path, output_dir, fp8_format, progress)
+        progress(0.25, desc="Converting to FP8 with LoRA extraction...")
+        success, msg = convert_safetensors_to_fp8_with_lora(
+            safetensors_path, output_dir, fp8_format, lora_rank, progress
+        )
         if not success:
            return None, f"❌ Conversion failed: {msg}", ""

@@ -165,42 +187,47 @@ def process_and_upload_fp8(
         )

         base_name = os.path.splitext(safetensors_filename)[0]
+        lora_filename = f"{base_name}-lora-r{lora_rank}.safetensors"
         readme = f"""---
 library_name: diffusers
 tags:
 - fp8
 - safetensors
-- delta
+- lora
+- low-rank
 - diffusion
 - converted-by-gradio
 ---

-# FP8 Model with Delta Compensation
+# FP8 Model with Low-Rank LoRA

 - **Source**: `{repo_url}`
 - **File**: `{safetensors_filename}`
 - **FP8 Format**: `{fp8_format.upper()}`
-- **Delta File**: `{base_name}-delta.safetensors`
+- **LoRA Rank**: {lora_rank}
+- **LoRA File**: `{lora_filename}`

 ## Usage (Inference)

-To restore near-original precision:
-
 ```python
-import torch
 from safetensors.torch import load_file
+import torch

+# Load FP8 model
 fp8_state = load_file("{base_name}-fp8-{fp8_format}.safetensors")
-delta_state = load_file("{base_name}-delta.safetensors")
+lora_state = load_file("{lora_filename}")

-restored_state = {{}}
+# Reconstruct approximate original weights
+reconstructed = {{}}
 for key in fp8_state:
-    if f"delta.{{key}}" in delta_state:
+    if f"lora_A.{{key}}" in lora_state and f"lora_B.{{key}}" in lora_state:
+        A = lora_state[f"lora_A.{{key}}"].to(torch.float32)
+        B = lora_state[f"lora_B.{{key}}"].to(torch.float32)
+        lora_weight = A @ B  # (out, rank) @ (rank, in) -> (out, in)
         fp8_weight = fp8_state[key].to(torch.float32)
-        delta = delta_state[f"delta.{{key}}"].to(torch.float32)
-        restored_state[key] = fp8_weight + delta
+        reconstructed[key] = fp8_weight + lora_weight
     else:
-        restored_state[key] = fp8_weight
+        reconstructed[key] = fp8_state[key].to(torch.float32)
 ```

 > Requires PyTorch ≥ 2.1 for FP8 support.
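The README snippet ends with a plain `reconstructed` dict. One plausible next step is loading it into a module; the model class and checkpoint below are assumptions for illustration, not something the app does:

```python
import torch
from diffusers import UNet2DConditionModel  # assumed target architecture

# 'reconstructed' is the float32 state dict built by the README snippet
unet = UNet2DConditionModel.from_pretrained("path/to/base", subfolder="unet")  # assumed checkpoint
missing, unexpected = unet.load_state_dict(reconstructed, strict=False)
print(f"missing: {len(missing)}, unexpected: {len(unexpected)}")
unet.to(dtype=torch.float16)  # cast once weights are restored
```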
@@ -221,9 +248,9 @@ for key in fp8_state:
         result_html = f"""
 ✅ Success!
 Model uploaded to: <a href="{repo_url_final}" target="_blank">{new_repo_id}</a>
-Includes: FP8 model + delta compensation file.
+Includes: FP8 model + rank-{lora_rank} LoRA.
 """
-        return gr.HTML(result_html), "✅ FP8 + delta upload successful!", ""
+        return gr.HTML(result_html), "✅ FP8 + LoRA upload successful!", ""

     except Exception as e:
         return None, f"❌ Error: {str(e)}", ""
@@ -232,9 +259,9 @@ Includes: FP8 model + delta compensation file.
         shutil.rmtree(temp_dir, ignore_errors=True)
         shutil.rmtree(output_dir, ignore_errors=True)

-with gr.Blocks(title="FP8 + Delta Converter (HF ↔ ModelScope)") as demo:
-    gr.Markdown("# 🚀 FP8 Pruner with Delta Compensation")
-    gr.Markdown("Convert `.safetensors` → **FP8** + **delta** for precision recovery. Supports Hugging Face ↔ ModelScope.")
+with gr.Blocks(title="FP8 + LoRA Extractor (HF ↔ ModelScope)") as demo:
+    gr.Markdown("# 🚀 FP8 Pruner with Low-Rank LoRA Extraction")
+    gr.Markdown("Convert `.safetensors` → **FP8** + **compact LoRA** for precision recovery. Supports Hugging Face ↔ ModelScope.")

     with gr.Row():
         with gr.Column():
@@ -242,6 +269,7 @@ with gr.Blocks(title="FP8 + Delta Converter (HF ↔ ModelScope)") as demo:
            repo_url = gr.Textbox(label="Repo URL or ID", placeholder="https://huggingface.co/... or modelscope-id")
            safetensors_filename = gr.Textbox(label="Filename", placeholder="model.safetensors")
            fp8_format = gr.Radio(["e4m3fn", "e5m2"], value="e5m2", label="FP8 Format")
+           lora_rank = gr.Slider(minimum=8, maximum=256, step=8, value=64, label="LoRA Rank")
            hf_token = gr.Textbox(label="HF Token (only if using HF)", type="password")
            modelscope_token = gr.Textbox(label="ModelScope Token (optional)", type="password", visible=MODELScope_AVAILABLE)
        with gr.Column():
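The slider trades recovery quality against LoRA size: each factored `(out, in)` layer stores `rank * (out + in)` fp16 values instead of an `out * in` residual. A rough sizing sketch with illustrative layer shapes (not taken from any real model):

```python
def lora_bytes(shapes, rank):
    # Two fp16 factors per layer: A is (out, rank), B is (rank, in)
    return sum(rank * (out + inp) * 2 for out, inp in shapes)

shapes = [(4096, 4096)] * 48 + [(4096, 11008)] * 24  # illustrative only
for rank in (8, 64, 256):
    print(f"rank {rank:3d}: ~{lora_bytes(shapes, rank) / 1e6:.0f} MB")
```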
@@ -260,6 +288,7 @@ with gr.Blocks(title="FP8 + Delta Converter (HF ↔ ModelScope)") as demo:
            repo_url,
            safetensors_filename,
            fp8_format,
+           lora_rank,
            target_type,
            new_repo_id,
            hf_token,
@@ -272,9 +301,9 @@ with gr.Blocks(title="FP8 + Delta Converter (HF ↔ ModelScope)") as demo:

    gr.Examples(
        examples=[
-            ["huggingface", "https://huggingface.co/Yabo/FramePainter/tree/main", "unet_diffusion_pytorch_model.safetensors", "e5m2", "modelscope"]
+            ["huggingface", "https://huggingface.co/Yabo/FramePainter/tree/main", "unet_diffusion_pytorch_model.safetensors", "e5m2", 64, "modelscope"]
        ],
-        inputs=[source_type, repo_url, safetensors_filename, fp8_format, target_type]
+        inputs=[source_type, repo_url, safetensors_filename, fp8_format, lora_rank, target_type]
    )

demo.launch()