Update app.py
app.py CHANGED
@@ -19,8 +19,9 @@ except ImportError:
 
 def low_rank_decomposition(weight, rank=64):
     """
-
-    Returns (lora_A, lora_B) such that weight ≈ lora_B @ lora_A
+    Correct LoRA decomposition supporting 2D and 4D tensors.
+    Returns (lora_A, lora_B) such that weight ≈ lora_B @ lora_A for 2D,
+    or appropriate conv form for 4D.
     """
     original_shape = weight.shape
     original_dtype = weight.dtype
@@ -34,7 +35,7 @@ def low_rank_decomposition(weight, rank=64):
             U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
             S_sqrt = torch.sqrt(S[:actual_rank])
 
-            # Standard LoRA: W ≈ W_B @ W_A
+            # Standard LoRA factorization: W ≈ W_B @ W_A
             W_A = (Vh[:actual_rank, :] * S_sqrt.unsqueeze(1)).contiguous()  # [rank, in_features]
             W_B = (U[:, :actual_rank] * S_sqrt.unsqueeze(0)).contiguous()  # [out_features, rank]
 
@@ -42,8 +43,9 @@ def low_rank_decomposition(weight, rank=64):
 
         elif weight.ndim == 4:
             out_ch, in_ch, k_h, k_w = weight.shape
-            if k_h * k_w <= 9:  # small
-
+            if k_h * k_w <= 9:  # small conv kernels (e.g., 3x3)
+                # Reshape to 2D: [out_ch, in_ch * k_h * k_w]
+                weight_2d = weight.view(out_ch, -1)
                 actual_rank = min(rank, min(weight_2d.shape) // 2)
                 if actual_rank < 4:
                     return None, None
@@ -54,7 +56,8 @@ def low_rank_decomposition(weight, rank=64):
                 W_A_2d = (Vh[:actual_rank, :] * S_sqrt.unsqueeze(1)).contiguous()
                 W_B_2d = (U[:, :actual_rank] * S_sqrt.unsqueeze(0)).contiguous()
 
-
+                # Reshape back to conv format
+                W_A = W_A_2d.view(actual_rank, in_ch, k_h, k_w).contiguous()
                 W_B = W_B_2d.view(out_ch, actual_rank, 1, 1).contiguous()
 
                 return W_A.to(original_dtype), W_B.to(original_dtype)
@@ -62,25 +65,25 @@ def low_rank_decomposition(weight, rank=64):
         return None, None
 
     except Exception as e:
-        print(f"Decomposition
+        print(f"Decomposition error for {original_shape}: {e}")
        return None, None
 
 def should_apply_lora(key, weight, architecture="auto"):
-    """
+    """Architecture-aware LoRA eligibility."""
     lower_key = key.lower()
 
-    # Skip
+    # Skip bias, norm, and tiny tensors
     if 'bias' in lower_key or 'norm' in lower_key or weight.numel() < 256:
         return False
 
     if architecture == "text_encoder":
-        return any(t in lower_key for t in ['emb', 'embed', 'attn'])
+        return any(t in lower_key for t in ['emb', 'embed', 'attn', 'mlp'])
     elif architecture == "unet_transformer":
-        return any(t in lower_key for t in ['attn', 'transformer'])
+        return any(t in lower_key for t in ['attn', 'transformer', 'to_q', 'to_k', 'to_v', 'to_out'])
     elif architecture == "unet_conv":
-        return any(t in lower_key for t in ['conv', 'resnet', 'down', 'up'])
+        return any(t in lower_key for t in ['conv', 'resnet', 'down', 'up', 'skip'])
     elif architecture == "vae":
-        return any(t in lower_key for t in ['encoder', 'decoder', 'quant'])
+        return any(t in lower_key for t in ['encoder', 'decoder', 'quant', 'post_quant', 'pre_quant'])
     else:  # "auto" or "all"
         return weight.ndim in [2, 4]
 
@@ -112,7 +115,6 @@ def convert_safetensors_to_fp8_with_lora(safetensors_path, output_dir, fp8_forma
 
         lora_stats = {
             'total_layers': total,
-            'layers_analyzed': 0,
             'layers_eligible': 0,
             'layers_processed': 0,
             'layers_skipped': [],
@@ -121,13 +123,11 @@ def convert_safetensors_to_fp8_with_lora(safetensors_path, output_dir, fp8_forma
         for i, key in enumerate(state_dict):
             progress(0.4 + 0.4 * (i / total), desc=f"Processing {i+1}/{total}...")
             weight = state_dict[key]
-            lora_stats['layers_analyzed'] += 1
 
             if weight.dtype in [torch.float16, torch.float32, torch.bfloat16]:
                 fp8_weight = weight.to(fp8_dtype)
                 sd_fp8[key] = fp8_weight
 
-                # Apply LoRA based on architecture selection
                 if should_apply_lora(key, weight, architecture):
                     lora_stats['layers_eligible'] += 1
 
@@ -139,11 +139,11 @@ def convert_safetensors_to_fp8_with_lora(safetensors_path, output_dir, fp8_forma
                             lora_keys.append(key)
                             lora_stats['layers_processed'] += 1
                         else:
-                            lora_stats['layers_skipped'].append(f"{key}: decomposition
+                            lora_stats['layers_skipped'].append(f"{key}: decomposition failed")
                     except Exception as e:
-                        lora_stats['layers_skipped'].append(f"{key}: {
+                        lora_stats['layers_skipped'].append(f"{key}: exception: {e}")
                 else:
-                    reason = "architecture
+                    reason = "filtered by architecture" if architecture != "auto" else "not 2D/4D or too small"
                     lora_stats['layers_skipped'].append(f"{key}: skipped ({reason})")
             else:
                 sd_fp8[key] = weight
@@ -154,8 +154,6 @@ def convert_safetensors_to_fp8_with_lora(safetensors_path, output_dir, fp8_forma
         lora_path = os.path.join(output_dir, f"{base_name}-lora-r{lora_rank}-{architecture}.safetensors")
 
         save_file(sd_fp8, fp8_path, metadata={"format": "pt", "fp8_format": fp8_format, **metadata})
-
-        # Always save LoRA file (even if empty) with stats
         save_file(lora_weights, lora_path, metadata={
             "format": "pt",
             "lora_rank": str(lora_rank),
@@ -166,17 +164,12 @@ def convert_safetensors_to_fp8_with_lora(safetensors_path, output_dir, fp8_forma
         progress(0.9, desc="Saved FP8 and LoRA files.")
         progress(1.0, desc="✅ FP8 + LoRA extraction complete!")
 
-        stats_msg = f"""
-
-        - Total layers: {lora_stats['total_layers']}
-        - Eligible for LoRA: {lora_stats['layers_eligible']}
-        - Successfully processed: {lora_stats['layers_processed']}
-        - Architecture: {architecture}
-        """
+        stats_msg = f"FP8 ({fp8_format}) and rank-{lora_rank} LoRA ({architecture}) saved.\n"
+        stats_msg += f"Processed {lora_stats['layers_processed']}/{lora_stats['layers_eligible']} eligible layers."
         if lora_stats['layers_processed'] == 0:
-            stats_msg += "
+            stats_msg += " ⚠️ No valid LoRA weights generated."
 
-        return True,
+        return True, stats_msg, lora_stats
 
     except Exception as e:
         import traceback
@@ -324,14 +317,15 @@ for key in fp8_state:
         if A.ndim == 2 and B.ndim == 2:
             lora_weight = B @ A
         else:
-            #
-            lora_weight =
+            # Conv LoRA: simplified reconstruction
+            lora_weight = F.conv2d(fp8_state[key].unsqueeze(0).to(torch.float32), A, groups=1)[:, :B.shape[0]]
+            lora_weight = lora_weight.squeeze(0) + F.conv2d(fp8_state[key].unsqueeze(0).to(torch.float32), B, groups=1).squeeze(0)
         reconstructed[key] = fp8_state[key].to(torch.float32) + lora_weight
     else:
         reconstructed[key] = fp8_state[key].to(torch.float32)
 ```
 
-> Requires PyTorch ≥ 2.1 for FP8 support. Use
+> Requires PyTorch ≥ 2.1 for FP8 support. Use matching architecture during inference.
 """
 
 with open(os.path.join(output_dir, "README.md"), "w") as f:
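A quick way to sanity-check the 2D factorization in this change: `low_rank_decomposition` splits `sqrt(S)` across both factors, so `W ≈ W_B @ W_A` with `W_A` of shape `[rank, in_features]` and `W_B` of shape `[out_features, rank]`. A minimal standalone sketch (the 768×3072 weight is made up for illustration, not taken from app.py):

```python
import torch

W = torch.randn(768, 3072)  # stand-in for a linear layer weight [out_features, in_features]
rank = 64

# Same recipe as low_rank_decomposition uses for weight.ndim == 2
U, S, Vh = torch.linalg.svd(W.float(), full_matrices=False)
S_sqrt = torch.sqrt(S[:rank])
W_A = Vh[:rank, :] * S_sqrt.unsqueeze(1)   # [rank, in_features]
W_B = U[:, :rank] * S_sqrt.unsqueeze(0)    # [out_features, rank]

# Relative Frobenius error of the rank-64 approximation
err = torch.linalg.matrix_norm(W - W_B @ W_A) / torch.linalg.matrix_norm(W)
print(err)
```

By the Eckart–Young theorem the truncated SVD is the best rank-`r` approximation in the Frobenius norm, so this error is the floor for any rank-64 LoRA of that layer.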
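For the 4D branch, the factors come back as a k×k convolution into `rank` channels (`W_A`, shape `[rank, in_ch, k_h, k_w]`) followed by a 1×1 convolution back to `out_ch` (`W_B`, shape `[out_ch, rank, 1, 1]`). The generated README above uses its own simplified conv reconstruction; as a cross-check of those shapes only, here is a hedged sketch (illustrative sizes, my assumption rather than the app's code path) of rebuilding the dense conv delta and confirming that the two-conv form is equivalent:

```python
import torch
import torch.nn.functional as F

out_ch, in_ch, k_h, k_w, rank = 128, 64, 3, 3, 32   # illustrative shapes
W_A = torch.randn(rank, in_ch, k_h, k_w)             # k x k conv into `rank` channels
W_B = torch.randn(out_ch, rank, 1, 1)                # 1 x 1 conv back to `out_ch`

# Dense delta: flatten both factors to 2D, multiply, reshape to conv layout
delta = (W_B.view(out_ch, rank) @ W_A.view(rank, -1)).view(out_ch, in_ch, k_h, k_w)

# Applying W_A then W_B to an activation equals one conv with the dense delta,
# because convolution is linear in the kernel
x = torch.randn(1, in_ch, 16, 16)
two_step = F.conv2d(F.conv2d(x, W_A, padding=1), W_B)
one_step = F.conv2d(x, delta, padding=1)
print(torch.allclose(two_step, one_step, rtol=1e-3, atol=1e-3))
```

This is what makes the 1×1 / k×k split a valid low-rank form for conv layers: the factored pair stores `rank · (out_ch + in_ch·k_h·k_w)` values instead of `out_ch·in_ch·k_h·k_w`.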
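On the FP8 side, the conversion is a plain dtype cast, and the "PyTorch ≥ 2.1" note refers to the `torch.float8_e4m3fn` / `torch.float8_e5m2` storage dtypes. A tiny sketch of the round trip (upcast before doing any arithmetic, since FP8 compute support is limited):

```python
import torch

w = torch.randn(256, 256)
w_fp8 = w.to(torch.float8_e4m3fn)   # requires PyTorch >= 2.1
w_back = w_fp8.to(torch.float32)    # upcast before matmul or adding the LoRA delta

print((w - w_back).abs().max())     # worst-case quantization error of the cast
```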