anyMODE committed on
Commit faa1b64 · verified · 1 parent: 59d2585

Upload extract_lora.py

Files changed (1)
  1. extract_lora.py +139 -0
extract_lora.py ADDED
@@ -0,0 +1,139 @@
+import torch
+import argparse
+from safetensors.torch import save_file, safe_open
+from tqdm import tqdm
+import sys
+
+
+def get_torch_dtype(dtype_str: str):
+    """Converts a string to a torch.dtype object."""
+    if dtype_str == "fp32":
+        return torch.float32
+    if dtype_str == "fp16":
+        return torch.float16
+    if dtype_str == "bf16":
+        return torch.bfloat16
+    raise ValueError(f"Unsupported dtype: {dtype_str}")
+
+
+def extract_and_svd_lora(model_a_path: str, model_b_path: str, output_path: str, rank: int, device: str, alpha: float,
+                         dtype: torch.dtype):
+    """
+    Extracts the difference between two models, applies SVD to reduce the rank,
+    and saves the result as a LoRA file.
+    """
+    print(f"Loading base model A: {model_a_path}")
+    print(f"Loading finetuned model B: {model_b_path}")
+
+    lora_tensors = {}
+
+    with safe_open(model_a_path, framework="pt", device="cpu") as f_a, \
+            safe_open(model_b_path, framework="pt", device="cpu") as f_b:
+
+        keys_a = set(f_a.keys())
+        keys_b = set(f_b.keys())
+        common_keys = keys_a.intersection(keys_b)
+
+        # Filter for processable layers (typically linear and conv weights)
+        # We exclude biases and non-weight tensors.
+        weight_keys = {k for k in common_keys if k.endswith('.weight') and 'lora_' not in k}
+
+        if not weight_keys:
+            print("No common weight keys found between the two models. Exiting.")
+            sys.exit(1)
+
+        print(f"Found {len(weight_keys)} common weight keys to process.")
+
+        # Main processing loop with progress bar
+        for key in tqdm(sorted(weight_keys), desc="Processing Layers"):
+            try:
+                # Load tensors and move to the selected device and dtype
+                tensor_a = f_a.get_tensor(key).to(device=device, dtype=dtype)
+                tensor_b = f_b.get_tensor(key).to(device=device, dtype=dtype)
+
+                if tensor_a.shape != tensor_b.shape:
+                    print(f"Skipping key {key} due to shape mismatch: A={tensor_a.shape}, B={tensor_b.shape}")
+                    continue
+
+                # Calculate the difference (delta weight)
+                delta_w = tensor_b - tensor_a
+
+                # SVD works on 2D matrices. Reshape conv layers and other ND tensors.
+                original_shape = delta_w.shape
+                if delta_w.dim() > 2:
+                    delta_w = delta_w.view(original_shape[0], -1)
+
+                # --- Core SVD Logic ---
+                # ΔW ≈ U * S * Vh
+                # U: Left singular vectors
+                # S: Singular values (a 1D vector)
+                # Vh: Right singular vectors (transposed)
+                U, S, Vh = torch.linalg.svd(delta_w.float(), full_matrices=False)  # SVD needs fp32/fp64, so upcast half-precision deltas
+
+                # Truncate to the desired rank
+                current_rank = min(rank, S.size(0))  # Ensure rank is not > possible rank
+                U = U[:, :current_rank]
+                S = S[:current_rank]
+                Vh = Vh[:current_rank, :]
+
+                # --- Decompose into LoRA A and B matrices ---
+                # LoRA A (lora_down) is Vh
+                # LoRA B (lora_up) is U * S
+                # We scale lora_up by the singular values to retain the magnitude
+                lora_down = Vh
+                lora_up = U @ torch.diag(S)
+
+                # Reshape back to original conv format if necessary
+                if len(original_shape) > 2:
+                    # For Conv2D, lora_down is (rank, in_channels * k_h * k_w)
+                    # and lora_up is (out_channels, rank). No reshape needed for up.
+                    pass  # The matrix form is standard for LoRA conv layers
+
+                # Create LoRA tensor names
+                base_name = key.replace('.weight', '')
+                lora_down_name = f"{base_name}.lora_down.weight"
+                lora_up_name = f"{base_name}.lora_up.weight"
+                alpha_name = f"{base_name}.alpha"
+
+                # Store tensors, moving them to CPU for saving
+                lora_tensors[lora_down_name] = lora_down.contiguous().cpu().to(torch.float32)
+                lora_tensors[lora_up_name] = lora_up.contiguous().cpu().to(torch.float32)
+                lora_tensors[alpha_name] = torch.tensor(alpha, dtype=torch.float32)
+
+            except Exception as e:
+                print(f"Failed to process key {key}: {e}")
+
+    # Save the final LoRA file
+    if not lora_tensors:
+        print("No tensors were processed. Output file will not be created.")
+        return
+
+    print(f"\nSaving {len(lora_tensors)} tensors to {output_path}...")
+    save_file(lora_tensors, output_path)
+    print("✅ Done!")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Extract and SVD a LoRA from two SafeTensors checkpoints.")
+
+    parser.add_argument("model_a", type=str, help="Path to the base model (A) checkpoint in .safetensors format.")
+    parser.add_argument("model_b", type=str, help="Path to the finetuned model (B) checkpoint in .safetensors format.")
+    parser.add_argument("output", type=str, help="Path to save the output LoRA file in .safetensors format.")
+
+    parser.add_argument("--rank", type=int, required=True, help="The target rank for the SVD.")
+    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"],
+                        help="Device to use for computation ('cuda' or 'cpu').")
+    parser.add_argument("--alpha", type=float, default=1.0, help="The alpha (scaling) factor for the LoRA.")
+    parser.add_argument("--precision", type=str, default="fp32", choices=["fp32", "fp16", "bf16"],
+                        help="Precision to use for calculations.")
+
+    args = parser.parse_args()
+
+    # Device check
+    if args.device == "cuda" and not torch.cuda.is_available():
+        print("CUDA is not available. Falling back to CPU.")
+        args.device = "cpu"
+
+    dtype = get_torch_dtype(args.precision)
+
+    extract_and_svd_lora(args.model_a, args.model_b, args.output, args.rank, args.device, args.alpha, dtype)
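
For reference, the per-layer factorization the script performs is the truncated SVD of the weight delta, which by the Eckart–Young theorem is the best rank-r approximation in the Frobenius norm:

    ΔW = W_B − W_A ≈ U_r · diag(S_r) · Vh_r = lora_up · lora_down

where lora_up = U_r · diag(S_r) has shape (out_features, r) and lora_down = Vh_r has shape (r, in_features). The script saves the raw factors plus a per-layer alpha; loaders following the common lora_up/lora_down convention usually apply W' = W + (alpha / r) · lora_up · lora_down at merge time, so the stored alpha interacts with the chosen --rank. Note this apply-time scaling is a loader convention, not something the script itself enforces.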
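
A typical invocation, with hypothetical file names, would be:

    python extract_lora.py base.safetensors finetuned.safetensors extracted_lora.safetensors --rank 32 --precision fp32

The following sketch (assuming the same hypothetical paths) sanity-checks the output by reconstructing one layer's delta from the saved factors; a small relative error means the chosen rank captures most of the finetuning difference:

    import torch
    from safetensors.torch import safe_open

    base_path = "base.safetensors"            # hypothetical model A
    tuned_path = "finetuned.safetensors"      # hypothetical model B
    lora_path = "extracted_lora.safetensors"  # output of extract_lora.py

    with safe_open(lora_path, framework="pt", device="cpu") as f:
        # Keys follow the "<base>.lora_down.weight" / "<base>.lora_up.weight"
        # naming used above; pick the first extracted layer (assumes at least one).
        key = sorted(k for k in f.keys() if k.endswith(".lora_down.weight"))[0]
        base = key[: -len(".lora_down.weight")]
        lora_down = f.get_tensor(f"{base}.lora_down.weight")
        lora_up = f.get_tensor(f"{base}.lora_up.weight")

    with safe_open(base_path, framework="pt", device="cpu") as f_a, \
            safe_open(tuned_path, framework="pt", device="cpu") as f_b:
        w_key = f"{base}.weight"
        delta = f_b.get_tensor(w_key).float() - f_a.get_tensor(w_key).float()
        delta = delta.view(delta.shape[0], -1)  # same 2D view the script feeds to SVD

    approx = lora_up @ lora_down  # rank-r reconstruction of the delta
    rel_err = (delta - approx).norm() / delta.norm()
    print(f"{base}: rank-{lora_up.shape[1]} relative error = {rel_err:.4f}")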