import torch
import torch.nn as nn
import torch.nn.functional as F

class LoRALinear(nn.Module):
    """Linear layer with a LoRA (low-rank adaptation) update: W0 + scale * B @ A."""

    def __init__(self, in_features, out_features, merge, rank, lora_alpha, dropout):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.merge = merge
        self.rank = rank
        self.dropout = dropout
        self.lora_alpha = lora_alpha
        # Pre-trained projection W0; its weight is frozen below when rank > 0
        self.linear = nn.Linear(in_features, out_features)
        if rank > 0:
            # Low-rank factors: delta_W = B @ A, with A: (rank, in) and B: (out, rank)
            self.lora_a = nn.Parameter(torch.zeros(rank, in_features))
            self.lora_b = nn.Parameter(torch.zeros(out_features, rank))
            # Scaling applied when the LoRA update is added to W0
            self.scale = self.lora_alpha / self.rank
            # Freeze the pre-trained weight matrix; only A and B are trained
            self.linear.weight.requires_grad = False

        if self.dropout > 0:
            self.dropout = nn.Dropout(self.dropout)
        else:
            self.dropout = lambda x: x

        self.initial_weights()
    def initial_weights(self):
        if self.rank > 0:
            nn.init.normal_(self.lora_a)  # defaults to a N(0, 1) normal distribution
            nn.init.zeros_(self.lora_b)   # B = 0, so the initial LoRA update is zero
    def forward(self, x):
        if self.rank > 0 and self.merge:
            # Merged path: (W0 + scale * B @ A) x + b in a single matmul
            output = F.linear(
                x,
                self.linear.weight + self.scale * (self.lora_b @ self.lora_a),
                self.linear.bias,
            )
        elif self.rank > 0:
            # Unmerged path: W0 x + b, plus the low-rank update kept separate
            output = self.linear(x) + self.scale * (x @ self.lora_a.T @ self.lora_b.T)
        else:
            output = self.linear(x)
        return self.dropout(output)
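
# ----------------------------------------------------------------------------
# Usage sketch. This block is illustrative: the shapes and the hyperparameter
# values (rank=8, lora_alpha=16, dropout=0.0) are assumptions chosen for the
# demo, not values from the class above. With dropout=0.0 the forward pass is
# deterministic, so the merged and unmerged paths can be compared directly,
# since (W0 + scale * B @ A) x = W0 x + scale * (B @ A) x.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    merged = LoRALinear(in_features=128, out_features=64, merge=True,
                        rank=8, lora_alpha=16, dropout=0.0)
    x = torch.randn(4, 128)
    print(merged(x).shape)  # torch.Size([4, 64])

    # Only the LoRA factors (and the bias) are trainable; W0 is frozen.
    for name, p in merged.named_parameters():
        print(name, p.requires_grad)

    # Make the LoRA update nonzero, then check that the unmerged forward
    # matches the merged one up to floating-point error.
    nn.init.normal_(merged.lora_b)
    unmerged = LoRALinear(in_features=128, out_features=64, merge=False,
                          rank=8, lora_alpha=16, dropout=0.0)
    unmerged.load_state_dict(merged.state_dict())
    print(torch.allclose(merged(x), unmerged(x), atol=1e-4))  # True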