Spaces:
Build error
Build error
| # ------------------------------------------------------------------------------------------ | |
| # Copyright (c) Microsoft Corporation. All rights reserved. | |
| # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. | |
| # ------------------------------------------------------------------------------------------ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import math | |
| from typing import Optional, List | |
class LoRALayer():
    """Mixin that stores the LoRA hyper-parameters shared by all LoRA layers.

    It is not an ``nn.Module`` itself; concrete layers inherit from both a
    torch module and this class and call ``LoRALayer.__init__`` explicitly.

    Args:
        r: Rank of the low-rank update. The layers in this file treat
            ``r <= 0`` as "LoRA disabled".
        lora_alpha: Numerator of the scaling factor ``lora_alpha / r`` that
            the concrete layers apply to the low-rank update.
        lora_dropout: Dropout probability. Values ``> 0`` create an
            ``nn.Dropout``; otherwise a pass-through is used.
        merge_weights: Whether the concrete layer should fold the low-rank
            update into its frozen weight when switched to eval mode.

    Attributes set:
        r, lora_alpha, merge_weights: Copies of the arguments.
        lora_dropout: Callable applied to the LoRA branch input.
        merged: ``False`` initially; tracks whether the update is currently
            folded into the base weight.
    """

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            # nn.Identity instead of a lambda: a locally defined lambda
            # cannot be pickled, which broke pickling/torch.save of any
            # object carrying this attribute. Identity returns its input
            # unchanged and owns no parameters or buffers.
            self.lora_dropout = nn.Identity()
        # Mark the weight as unmerged
        self.merged = False
        self.merge_weights = merge_weights
class LoRAEmbedding(nn.Embedding, LoRALayer):
    """``nn.Embedding`` with an optional trainable low-rank (LoRA) update.

    When ``r > 0`` two extra parameters are created, ``lora_A`` of shape
    ``(r, num_embeddings)`` and ``lora_B`` of shape ``(embedding_dim, r)``,
    and the original ``weight`` is frozen. The update added to the embedding
    table is ``(lora_B @ lora_A).T * (lora_alpha / r)``.

    Args:
        num_embeddings: Forwarded to ``nn.Embedding``.
        embedding_dim: Forwarded to ``nn.Embedding``.
        r: LoRA rank; ``0`` disables the low-rank branch.
        lora_alpha: Numerator of the scaling factor ``lora_alpha / r``.
        merge_weights: If true, the update is folded into ``weight`` when the
            module is switched to eval mode and removed again in train mode.
        **kwargs: Forwarded to ``nn.Embedding``.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        r: int = 0,
        lora_alpha: int = 1,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=0,
                           merge_weights=merge_weights)
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the embedding table and, if present, the LoRA factors."""
        nn.Embedding.reset_parameters(self)
        # nn.Embedding.__init__ may call this before lora_A exists, hence the
        # hasattr guard.
        if hasattr(self, 'lora_A'):
            # A is zeroed and B is drawn from a normal distribution, so the
            # product B @ A (the update) starts at exactly zero.
            nn.init.zeros_(self.lora_A)
            nn.init.normal_(self.lora_B)

    def train(self, mode: bool = True):
        """Set train/eval mode and merge or unmerge the LoRA update.

        Entering train mode removes a previously merged update from
        ``weight``; entering eval mode adds it. Both only happen when
        ``merge_weights`` is true.

        Returns:
            ``self``, matching ``nn.Module.train`` so that ``eval()`` and
            ``train()`` can be chained.
        """
        nn.Embedding.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    # Pure weight bookkeeping: no autograd graph is needed.
                    with torch.no_grad():
                        self.weight.data -= (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    with torch.no_grad():
                        self.weight.data += (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
                self.merged = True
        return self

    def forward(self, x: torch.Tensor):
        """Look up ``x`` and, while unmerged, add the scaled low-rank term."""
        if self.r > 0 and not self.merged:
            result = nn.Embedding.forward(self, x)
            # Embedding lookup into A^T gives one length-r vector per index,
            # which is then projected to embedding_dim through B^T.
            after_A = F.embedding(
                x, self.lora_A.transpose(0, 1), self.padding_idx, self.max_norm,
                self.norm_type, self.scale_grad_by_freq, self.sparse
            )
            result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling
            return result
        else:
            return nn.Embedding.forward(self, x)
class LoRALinear(nn.Linear, LoRALayer):
    """``nn.Linear`` with an optional trainable low-rank (LoRA) update.

    When ``r > 0`` two extra parameters are created, ``lora_A`` of shape
    ``(r, in_features)`` and ``lora_B`` of shape ``(out_features, r)``, and
    the original ``weight`` is frozen. The branch added to the output is
    ``dropout(x) @ lora_A.T @ lora_B.T * (lora_alpha / r)``.

    Args:
        in_features: Forwarded to ``nn.Linear``.
        out_features: Forwarded to ``nn.Linear``.
        r: LoRA rank; ``0`` disables the low-rank branch.
        lora_alpha: Numerator of the scaling factor ``lora_alpha / r``.
        lora_dropout: Dropout probability applied to the LoRA branch input.
        fan_in_fan_out: Set to True if the layer being replaced stores its
            weight as ``(fan_in, fan_out)``; the stored weight is then kept
            transposed and transposed back on use.
        merge_weights: If true, the update is folded into ``weight`` when the
            module is switched to eval mode and removed again in train mode.
        **kwargs: Forwarded to ``nn.Linear``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def _maybe_transpose(self, w: torch.Tensor) -> torch.Tensor:
        """Transpose ``w`` only when the weight is stored as (fan_in, fan_out)."""
        return w.transpose(0, 1) if self.fan_in_fan_out else w

    def reset_parameters(self):
        """Re-initialise the linear weight/bias and, if present, the LoRA factors."""
        nn.Linear.reset_parameters(self)
        # nn.Linear.__init__ calls this before lora_A exists, hence the guard.
        if hasattr(self, 'lora_A'):
            # A gets the same Kaiming-uniform init nn.Linear uses for its
            # weight and B is zeroed, so the update B @ A starts at zero.
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
        """Set train/eval mode and merge or unmerge the LoRA update.

        Entering train mode removes a previously merged update from
        ``weight``; entering eval mode adds it. Both only happen when
        ``merge_weights`` is true.

        Returns:
            ``self``, matching ``nn.Module.train`` so that ``eval()`` and
            ``train()`` can be chained.
        """
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    # Pure weight bookkeeping: no autograd graph is needed.
                    with torch.no_grad():
                        self.weight.data -= self._maybe_transpose(self.lora_B @ self.lora_A) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    with torch.no_grad():
                        self.weight.data += self._maybe_transpose(self.lora_B @ self.lora_A) * self.scaling
                self.merged = True
        return self

    def forward(self, x: torch.Tensor):
        """Apply the linear map and, while unmerged, add the scaled LoRA branch."""
        result = F.linear(x, self._maybe_transpose(self.weight), bias=self.bias)
        if self.r > 0 and not self.merged:
            result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
        return result
class MergedLoRALinear(nn.Linear, LoRALayer):
    """``nn.Linear`` whose output is split into groups with per-group LoRA.

    ``out_features`` is divided into ``len(enable_lora)`` equally sized
    groups of output rows. Only the groups whose flag is true receive a
    rank-``r`` update; the rows of the other groups get a zero update.

    Args:
        in_features: Forwarded to ``nn.Linear``.
        out_features: Forwarded to ``nn.Linear``; must be divisible by
            ``len(enable_lora)``.
        r: LoRA rank; ``0`` disables the low-rank branch.
        lora_alpha: Numerator of the scaling factor ``lora_alpha / r``.
        lora_dropout: Dropout probability applied to the LoRA branch input.
        enable_lora: One flag per output group. ``None`` (the default) is
            treated as ``[False]``, i.e. a single group without LoRA.
        fan_in_fan_out: Set to True if the layer being replaced stores its
            weight as ``(fan_in, fan_out)``.
        merge_weights: If true, the update is folded into ``weight`` when the
            module is switched to eval mode and removed again in train mode.
        **kwargs: Forwarded to ``nn.Linear``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        enable_lora: Optional[List[bool]] = None,
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        # Private copy: avoids a shared mutable default and aliasing the
        # caller's list.
        enable_lora = [False] if enable_lora is None else list(enable_lora)
        assert out_features % len(enable_lora) == 0, \
            'The length of enable_lora must divide out_features'
        self.enable_lora = enable_lora
        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0 and any(enable_lora):
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r * sum(enable_lora), in_features)))
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r))
            )  # weights for Conv1D with groups=sum(enable_lora)
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            # Compute the indices: boolean mask over output rows, True for
            # every row that belongs to an enabled group.
            self.lora_ind = self.weight.new_zeros(
                (out_features, ), dtype=torch.bool
            ).view(len(enable_lora), -1)
            self.lora_ind[enable_lora, :] = True
            self.lora_ind = self.lora_ind.view(-1)
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def _maybe_transpose(self, w: torch.Tensor) -> torch.Tensor:
        """Transpose ``w`` only when the weight is stored as (fan_in, fan_out)."""
        return w.transpose(0, 1) if self.fan_in_fan_out else w

    def _lora_active(self) -> bool:
        """True when LoRA parameters exist (``r > 0`` and some group enabled)."""
        return self.r > 0 and any(self.enable_lora)

    def reset_parameters(self):
        """Re-initialise the linear weight/bias and, if present, the LoRA factors."""
        nn.Linear.reset_parameters(self)
        # nn.Linear.__init__ calls this before lora_A exists, hence the guard.
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def zero_pad(self, x):
        """Scatter the rows of ``x`` into the enabled rows of a zero tensor.

        The result has ``len(self.lora_ind)`` rows; rows where ``lora_ind``
        is false stay zero.
        """
        result = x.new_zeros((len(self.lora_ind), *x.shape[1:]))
        result[self.lora_ind] = x
        return result

    def merge_AB(self):
        """Return the (unscaled) weight update, laid out like ``self.weight``.

        A grouped 1-D convolution with kernel size 1 multiplies each enabled
        group's block of ``lora_B`` with its block of ``lora_A``.
        """
        delta_w = F.conv1d(
            self.lora_A.unsqueeze(0),
            self.lora_B.unsqueeze(-1),
            groups=sum(self.enable_lora)
        ).squeeze(0)
        return self._maybe_transpose(self.zero_pad(delta_w))

    def train(self, mode: bool = True):
        """Set train/eval mode and merge or unmerge the LoRA update.

        Entering train mode removes a previously merged update from
        ``weight``; entering eval mode adds it. Both only happen when
        ``merge_weights`` is true.

        Returns:
            ``self``, matching ``nn.Module.train`` so that ``eval()`` and
            ``train()`` can be chained.
        """
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self._lora_active():
                    # Pure weight bookkeeping: no autograd graph is needed.
                    with torch.no_grad():
                        self.weight.data -= self.merge_AB() * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self._lora_active():
                    with torch.no_grad():
                        self.weight.data += self.merge_AB() * self.scaling
                self.merged = True
        return self

    def forward(self, x: torch.Tensor):
        """Apply the linear map and, while unmerged, add the scaled LoRA branch."""
        result = F.linear(x, self._maybe_transpose(self.weight), bias=self.bias)
        # Also require an enabled group: with r > 0 but every flag false no
        # LoRA parameters were created, and merge_AB() would raise
        # AttributeError on the missing lora_A.
        if not self.merged and self._lora_active():
            result += self.lora_dropout(x) @ self._maybe_transpose(self.merge_AB().T) * self.scaling
        return result
class ConvLoRA(nn.Module, LoRALayer):
    """Wrap a convolution module and add an optional low-rank (LoRA) update.

    The wrapped convolution is stored as ``self.conv``; its ``weight`` and
    ``bias`` are additionally exposed as ``self.weight`` / ``self.bias``
    (the same objects). When ``r > 0`` the convolution weight is frozen and
    two parameters are created, ``lora_A`` of shape
    ``(r * kernel_size, in_channels * kernel_size)`` and ``lora_B`` of shape
    ``(out_channels // groups * kernel_size, r * kernel_size)``. Their
    product is reshaped to the convolution weight's shape and scaled by
    ``lora_alpha / r``.

    Args:
        conv_module: Convolution class, called as
            ``conv_module(in_channels, out_channels, kernel_size, **kwargs)``.
        in_channels: Forwarded to ``conv_module``.
        out_channels: Forwarded to ``conv_module``.
        kernel_size: Forwarded to ``conv_module``; must be an ``int``.
        r: LoRA rank; ``0`` disables the low-rank branch.
        lora_alpha: Numerator of the scaling factor ``lora_alpha / r``.
        lora_dropout: Stored through ``LoRALayer``; ``forward`` in this class
            does not apply it.
        merge_weights: If true, the update is folded into the convolution
            weight in eval mode and removed again in train mode.
        **kwargs: Forwarded to ``conv_module``.
    """

    def __init__(self, conv_module, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
        super(ConvLoRA, self).__init__()
        self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs)
        self.weight = self.conv.weight
        self.bias = self.conv.bias
        # LoRALayer.__init__ also initialises self.merged to False.
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
        assert isinstance(kernel_size, int)
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(
                self.conv.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
            )
            self.lora_B = nn.Parameter(
                self.conv.weight.new_zeros((out_channels//self.conv.groups*kernel_size, r*kernel_size))
            )
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.conv.weight.requires_grad = False
        self.reset_parameters()

    def _lora_delta(self):
        """Return ``lora_B @ lora_A`` reshaped like the conv weight and scaled."""
        return (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling

    def reset_parameters(self):
        """Re-initialise the wrapped convolution and, if present, the LoRA factors."""
        self.conv.reset_parameters()
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode=True):
        """Set train/eval mode and merge or unmerge the LoRA update.

        Entering train mode removes a previously merged update from the
        convolution weight; entering eval mode adds it. Both only happen
        when ``merge_weights`` is true.

        Returns:
            ``self``, matching ``nn.Module.train`` so that ``eval()`` and
            ``train()`` can be chained.
        """
        super(ConvLoRA, self).train(mode)
        if mode:
            if self.merge_weights and self.merged:
                if self.r > 0:
                    # Make sure that the weights are not merged
                    # (pure weight bookkeeping: no autograd graph is needed).
                    with torch.no_grad():
                        self.conv.weight.data -= self._lora_delta()
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                if self.r > 0:
                    # Merge the weights and mark it
                    with torch.no_grad():
                        self.conv.weight.data += self._lora_delta()
                self.merged = True
        return self

    def forward(self, x):
        """Run the convolution, adding the LoRA update to the weight while unmerged."""
        if self.r > 0 and not self.merged:
            return self.conv._conv_forward(
                x,
                self.conv.weight + self._lora_delta(),
                self.conv.bias
            )
        return self.conv(x)
class LoRAConv2d(ConvLoRA):
    """LoRA wrapper around ``nn.Conv2d``."""

    def __init__(self, *args, **kwargs):
        """Build a ``ConvLoRA`` whose wrapped module is ``nn.Conv2d``.

        The base class's handling of ``self.conv.weight`` works for the 2D
        case, so it is inherited directly. For 1D it would be a factor of
        ``kernel_size`` too large, and for 3D a factor of ``kernel_size``
        too small.
        """
        super().__init__(nn.Conv2d, *args, **kwargs)
class LoRAConv1d(ConvLoRA):
    """LoRA wrapper around ``nn.Conv1d``.

    Builds its own state instead of reusing ``ConvLoRA.__init__`` because
    ``lora_A`` has a different shape here: ``(r * kernel_size, in_channels)``.
    ``lora_B`` has shape
    ``(out_channels // groups * kernel_size, r * kernel_size)``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
        # Deliberately start the super() lookup *after* ConvLoRA so that
        # ConvLoRA.__init__ is skipped and only the next initialiser in the
        # MRO runs.
        super(ConvLoRA, self).__init__()
        wrapped = nn.Conv1d(in_channels, out_channels, kernel_size, **kwargs)
        self.conv = wrapped
        # Expose the wrapped convolution's parameters on this module too.
        self.weight = wrapped.weight
        self.bias = wrapped.bias
        LoRALayer.__init__(
            self,
            r=r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            merge_weights=merge_weights,
        )
        assert isinstance(kernel_size, int)
        # Actual trainable parameters
        if r > 0:
            shape_a = (r * kernel_size, in_channels)
            shape_b = (out_channels // wrapped.groups * kernel_size, r * kernel_size)
            self.lora_A = nn.Parameter(wrapped.weight.new_zeros(shape_a))
            self.lora_B = nn.Parameter(wrapped.weight.new_zeros(shape_b))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            wrapped.weight.requires_grad = False
        self.reset_parameters()
        self.merged = False
| # Can Extend to other ones like this | |
class LoRAConv3d(ConvLoRA):
    """LoRA wrapper around ``nn.Conv3d``.

    Builds its own state instead of reusing ``ConvLoRA.__init__`` because
    ``lora_A`` has a different shape here:
    ``(r * kernel_size, in_channels * kernel_size * kernel_size)``.
    ``lora_B`` has shape
    ``(out_channels // groups * kernel_size, r * kernel_size)``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
        # Deliberately start the super() lookup *after* ConvLoRA so that
        # ConvLoRA.__init__ is skipped and only the next initialiser in the
        # MRO runs.
        super(ConvLoRA, self).__init__()
        wrapped = nn.Conv3d(in_channels, out_channels, kernel_size, **kwargs)
        self.conv = wrapped
        # Expose the wrapped convolution's parameters on this module too.
        self.weight = wrapped.weight
        self.bias = wrapped.bias
        LoRALayer.__init__(
            self,
            r=r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            merge_weights=merge_weights,
        )
        assert isinstance(kernel_size, int)
        # Actual trainable parameters
        if r > 0:
            shape_a = (r * kernel_size, in_channels * kernel_size * kernel_size)
            shape_b = (out_channels // wrapped.groups * kernel_size, r * kernel_size)
            self.lora_A = nn.Parameter(wrapped.weight.new_zeros(shape_a))
            self.lora_B = nn.Parameter(wrapped.weight.new_zeros(shape_b))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            wrapped.weight.requires_grad = False
        self.reset_parameters()
        self.merged = False
if __name__ == '__main__':
    # Smoke test: build a 1-D LoRA convolution, switch it to train mode and
    # then to eval mode, printing the `merged` flag after each switch.
    demo_layer = LoRAConv1d(3, 32, kernel_size=3, stride=1, r=8)
    for switch_mode in (demo_layer.train, demo_layer.eval):
        switch_mode()
        print(demo_layer.merged)