codemichaeld committed · verified
Commit 1fd8c55 · 1 Parent(s): dd07a59

Update app.py

Files changed (1):
  1. app.py +27 -33

app.py CHANGED
@@ -19,8 +19,9 @@ except ImportError:
 
 def low_rank_decomposition(weight, rank=64):
     """
-    Improved LoRA decomposition supporting 2D and 4D tensors with proper factorization.
-    Returns (lora_A, lora_B) such that weight ≈ lora_B @ lora_A (for 2D) or appropriate conv form.
+    Correct LoRA decomposition supporting 2D and 4D tensors.
+    Returns (lora_A, lora_B) such that weight ≈ lora_B @ lora_A for 2D,
+    or appropriate conv form for 4D.
     """
     original_shape = weight.shape
     original_dtype = weight.dtype
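
The new docstring states `weight ≈ lora_B @ lora_A`. Because √S is split across both factors (see the unchanged SVD lines in the next hunk), that product is exactly the rank-r truncated SVD, the best rank-r approximation by the Eckart-Young theorem. A standalone sketch of that check, independent of app.py:

```python
import torch

def svd_lora_factors(weight: torch.Tensor, rank: int):
    """Split sqrt(S) across the factors so that W_B @ W_A equals the truncated SVD."""
    U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
    r = min(rank, S.numel())
    s_sqrt = torch.sqrt(S[:r])
    W_A = Vh[:r, :] * s_sqrt.unsqueeze(1)   # [rank, in_features]
    W_B = U[:, :r] * s_sqrt.unsqueeze(0)    # [out_features, rank]
    return W_A, W_B

torch.manual_seed(0)
W = torch.randn(512, 768)
W_A, W_B = svd_lora_factors(W, rank=64)

# Compare against the explicit rank-64 truncation U_r diag(S_r) V_r^T
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
W_trunc = U[:, :64] @ torch.diag(S[:64]) @ Vh[:64, :]
print(torch.allclose(W_B @ W_A, W_trunc, atol=1e-3))   # True
print((W - W_B @ W_A).norm() / W.norm())               # residual from the discarded singular values
```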
@@ -34,7 +35,7 @@ def low_rank_decomposition(weight, rank=64):
             U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
             S_sqrt = torch.sqrt(S[:actual_rank])
 
-            # Standard LoRA: W ≈ W_B @ W_A
+            # Standard LoRA factorization: W ≈ W_B @ W_A
             W_A = (Vh[:actual_rank, :] * S_sqrt.unsqueeze(1)).contiguous()  # [rank, in_features]
             W_B = (U[:, :actual_rank] * S_sqrt.unsqueeze(0)).contiguous()  # [out_features, rank]
 
@@ -42,8 +43,9 @@ def low_rank_decomposition(weight, rank=64):
 
         elif weight.ndim == 4:
             out_ch, in_ch, k_h, k_w = weight.shape
-            if k_h * k_w <= 9:  # small kernel (e.g., 3x3)
-                weight_2d = weight.permute(0, 2, 3, 1).reshape(out_ch, -1)
+            if k_h * k_w <= 9:  # small conv kernels (e.g., 3x3)
+                # Reshape to 2D: [out_ch, in_ch * k_h * k_w]
+                weight_2d = weight.view(out_ch, -1)
                 actual_rank = min(rank, min(weight_2d.shape) // 2)
                 if actual_rank < 4:
                     return None, None
@@ -54,7 +56,8 @@ def low_rank_decomposition(weight, rank=64):
                 W_A_2d = (Vh[:actual_rank, :] * S_sqrt.unsqueeze(1)).contiguous()
                 W_B_2d = (U[:, :actual_rank] * S_sqrt.unsqueeze(0)).contiguous()
 
-                W_A = W_A_2d.view(actual_rank, k_h, k_w, in_ch).permute(0, 3, 1, 2).contiguous()
+                # Reshape back to conv format
+                W_A = W_A_2d.view(actual_rank, in_ch, k_h, k_w).contiguous()
                 W_B = W_B_2d.view(out_ch, actual_rank, 1, 1).contiguous()
 
                 return W_A.to(original_dtype), W_B.to(original_dtype)
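
With the flattening used above, `W_A` has shape `[rank, in_ch, k_h, k_w]` and `W_B` has shape `[out_ch, rank, 1, 1]`, so the pair behaves like a k×k conv followed by a 1×1 conv whose composition matches the rank-truncated original kernel. A standalone sketch of that equivalence, assuming nothing from app.py:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
out_ch, in_ch, k, rank = 64, 32, 3, 16

# Factor a random conv weight the same way as above: flatten, SVD, split sqrt(S)
W = torch.randn(out_ch, in_ch, k, k)
U, S, Vh = torch.linalg.svd(W.view(out_ch, -1), full_matrices=False)
s = torch.sqrt(S[:rank])
A = (Vh[:rank] * s.unsqueeze(1)).view(rank, in_ch, k, k)     # k x k conv: in_ch -> rank
B = (U[:, :rank] * s.unsqueeze(0)).view(out_ch, rank, 1, 1)  # 1 x 1 conv: rank -> out_ch

# Applying A then B equals a single conv with the rank-truncated kernel
x = torch.randn(1, in_ch, 16, 16)
y_low = F.conv2d(F.conv2d(x, A, padding=1), B)
W_trunc = (U[:, :rank] @ torch.diag(S[:rank]) @ Vh[:rank]).view(out_ch, in_ch, k, k)
y_ref = F.conv2d(x, W_trunc, padding=1)
print(torch.allclose(y_low, y_ref, atol=1e-3))  # True
```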
@@ -62,25 +65,25 @@ def low_rank_decomposition(weight, rank=64):
         return None, None
 
     except Exception as e:
-        print(f"Decomposition failed for {original_shape}: {e}")
+        print(f"Decomposition error for {original_shape}: {e}")
         return None, None
 
 def should_apply_lora(key, weight, architecture="auto"):
-    """Determine if LoRA should be applied based on architecture selection."""
+    """Architecture-aware LoRA eligibility."""
     lower_key = key.lower()
 
-    # Skip unimportant weights
+    # Skip bias, norm, and tiny tensors
    if 'bias' in lower_key or 'norm' in lower_key or weight.numel() < 256:
        return False
 
    if architecture == "text_encoder":
-        return any(t in lower_key for t in ['emb', 'embed', 'attn'])
+        return any(t in lower_key for t in ['emb', 'embed', 'attn', 'mlp'])
    elif architecture == "unet_transformer":
-        return any(t in lower_key for t in ['attn', 'transformer'])
+        return any(t in lower_key for t in ['attn', 'transformer', 'to_q', 'to_k', 'to_v', 'to_out'])
    elif architecture == "unet_conv":
-        return any(t in lower_key for t in ['conv', 'resnet', 'down', 'up'])
+        return any(t in lower_key for t in ['conv', 'resnet', 'down', 'up', 'skip'])
    elif architecture == "vae":
-        return any(t in lower_key for t in ['encoder', 'decoder', 'quant'])
+        return any(t in lower_key for t in ['encoder', 'decoder', 'quant', 'post_quant', 'pre_quant'])
    else:  # "auto" or "all"
        return weight.ndim in [2, 4]
 
@@ -112,7 +115,6 @@ def convert_safetensors_to_fp8_with_lora(safetensors_path, output_dir, fp8_forma
 
         lora_stats = {
             'total_layers': total,
-            'layers_analyzed': 0,
             'layers_eligible': 0,
             'layers_processed': 0,
             'layers_skipped': [],
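
The `should_apply_lora` filter from the earlier hunk only looks at key substrings and tensor shape, so it can be exercised on a dummy state dict before running a full conversion. A quick sketch (the key names are made up, and `should_apply_lora` from app.py is assumed to be in scope); note that the matching is substring-based and coarse, e.g. `'encoder'` under the `vae` preset also matches text-encoder keys:

```python
import torch

# should_apply_lora is the function defined above in app.py (assumed importable or in scope)
dummy_sd = {
    "text_model.encoder.layers.0.self_attn.q_proj.weight": torch.randn(768, 768),
    "text_model.encoder.layers.0.self_attn.q_proj.bias": torch.randn(768),
    "diffusion_model.input_blocks.1.0.conv1.weight": torch.randn(320, 320, 3, 3),
    "first_stage_model.decoder.up.3.norm1.weight": torch.randn(128),
}

for arch in ["text_encoder", "unet_transformer", "unet_conv", "vae", "auto"]:
    picked = [k for k, w in dummy_sd.items() if should_apply_lora(k, w, arch)]
    print(f"{arch}: {picked}")
```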
@@ -121,13 +123,11 @@ def convert_safetensors_to_fp8_with_lora(safetensors_path, output_dir, fp8_forma
         for i, key in enumerate(state_dict):
             progress(0.4 + 0.4 * (i / total), desc=f"Processing {i+1}/{total}...")
             weight = state_dict[key]
-            lora_stats['layers_analyzed'] += 1
 
             if weight.dtype in [torch.float16, torch.float32, torch.bfloat16]:
                 fp8_weight = weight.to(fp8_dtype)
                 sd_fp8[key] = fp8_weight
 
-                # Apply LoRA based on architecture selection
                 if should_apply_lora(key, weight, architecture):
                     lora_stats['layers_eligible'] += 1
 
@@ -139,11 +139,11 @@ def convert_safetensors_to_fp8_with_lora(safetensors_path, output_dir, fp8_forma
                             lora_keys.append(key)
                             lora_stats['layers_processed'] += 1
                         else:
-                            lora_stats['layers_skipped'].append(f"{key}: decomposition returned None")
+                            lora_stats['layers_skipped'].append(f"{key}: decomposition failed")
                     except Exception as e:
-                        lora_stats['layers_skipped'].append(f"{key}: {str(e)}")
+                        lora_stats['layers_skipped'].append(f"{key}: exception: {e}")
                 else:
-                    reason = "architecture filter" if architecture != "auto" else "not 2D/4D or too small"
+                    reason = "filtered by architecture" if architecture != "auto" else "not 2D/4D or too small"
                     lora_stats['layers_skipped'].append(f"{key}: skipped ({reason})")
             else:
                 sd_fp8[key] = weight
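
The FP8 conversion in the loop above is a plain dtype cast. `torch.float8_e4m3fn` and `torch.float8_e5m2` are available from PyTorch 2.1 onward but support few kernels, so weights are typically cast back to higher precision before any math. A standalone sketch of the round-trip quantization error, assuming an FP8-capable PyTorch build:

```python
import torch

w = torch.randn(1024, 1024, dtype=torch.float16)

for fp8_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    w_fp8 = w.to(fp8_dtype)           # storage-only cast; most kernels will not run on FP8 directly
    w_back = w_fp8.to(torch.float32)  # upcast before doing math with it
    rel_err = (w.float() - w_back).norm() / w.float().norm()
    print(f"{fp8_dtype}: relative error {rel_err:.4f}")
```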
@@ -154,8 +154,6 @@ def convert_safetensors_to_fp8_with_lora(safetensors_path, output_dir, fp8_forma
         lora_path = os.path.join(output_dir, f"{base_name}-lora-r{lora_rank}-{architecture}.safetensors")
 
         save_file(sd_fp8, fp8_path, metadata={"format": "pt", "fp8_format": fp8_format, **metadata})
-
-        # Always save LoRA file (even if empty) with stats
         save_file(lora_weights, lora_path, metadata={
             "format": "pt",
             "lora_rank": str(lora_rank),
@@ -166,17 +164,12 @@ def convert_safetensors_to_fp8_with_lora(safetensors_path, output_dir, fp8_forma
         progress(0.9, desc="Saved FP8 and LoRA files.")
         progress(1.0, desc="✅ FP8 + LoRA extraction complete!")
 
-        stats_msg = f"""
-        📊 LoRA Extraction Stats:
-        - Total layers: {lora_stats['total_layers']}
-        - Eligible for LoRA: {lora_stats['layers_eligible']}
-        - Successfully processed: {lora_stats['layers_processed']}
-        - Architecture: {architecture}
-        """
+        stats_msg = f"FP8 ({fp8_format}) and rank-{lora_rank} LoRA ({architecture}) saved.\n"
+        stats_msg += f"Processed {lora_stats['layers_processed']}/{lora_stats['layers_eligible']} eligible layers."
         if lora_stats['layers_processed'] == 0:
-            stats_msg += "\n⚠️ No LoRA weights generated. Try lower rank or different architecture."
+            stats_msg += " ⚠️ No valid LoRA weights generated."
 
-        return True, f"FP8 ({fp8_format}) and rank-{lora_rank} LoRA saved.\n{stats_msg}", lora_stats
+        return True, stats_msg, lora_stats
 
     except Exception as e:
         import traceback
@@ -324,14 +317,15 @@ for key in fp8_state:
         if A.ndim == 2 and B.ndim == 2:
             lora_weight = B @ A
         else:
-            # Handle convolutional LoRA (simplified)
-            lora_weight = torch.zeros_like(fp8_state[key], dtype=torch.float32)
+            # Conv LoRA: simplified reconstruction
+            lora_weight = F.conv2d(fp8_state[key].unsqueeze(0).to(torch.float32), A, groups=1)[:, :B.shape[0]]
+            lora_weight = lora_weight.squeeze(0) + F.conv2d(fp8_state[key].unsqueeze(0).to(torch.float32), B, groups=1).squeeze(0)
         reconstructed[key] = fp8_state[key].to(torch.float32) + lora_weight
     else:
         reconstructed[key] = fp8_state[key].to(torch.float32)
 ```
 
-> Requires PyTorch ≥ 2.1 for FP8 support. Use the same architecture selection during inference.
+> Requires PyTorch ≥ 2.1 for FP8 support. Use matching architecture during inference.
 """
 
     with open(os.path.join(output_dir, "README.md"), "w") as f:
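
A note on the conv branch in the README snippet above: the factors saved by `low_rank_decomposition` flatten each kernel to `[out_ch, in_ch*k_h*k_w]`, so they can also be densified with a single matmul and reshape, mirroring the `B @ A` path used for 2D weights. A sketch of such a helper (hypothetical name, not part of this commit):

```python
import torch

def conv_lora_to_dense(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Expand conv LoRA factors into a dense [out_ch, in_ch, k_h, k_w] tensor.

    A: [rank, in_ch, k_h, k_w], B: [out_ch, rank, 1, 1], as saved by low_rank_decomposition.
    """
    rank = A.shape[0]
    out_ch = B.shape[0]
    dense_2d = B.view(out_ch, rank) @ A.view(rank, -1)   # [out_ch, in_ch*k_h*k_w]
    return dense_2d.view(out_ch, *A.shape[1:])
```

The result has the same shape as the original conv weight, so it can be used wherever the 2D branch uses `B @ A`.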
 