import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class Loar_LINER(nn.Module):
    """Linear layer with a LoRA (Low-Rank Adaptation) branch.

    The frozen base weight ``W`` is augmented with a trainable low-rank
    update ``(B @ A) * (loar_alpha / rank)``.  When ``merge`` is True
    (and ``rank > 0``) the update is folded into the weight at forward
    time; otherwise the plain linear layer is applied.

    Args:
        infeature:  input feature dimension.
        outfeature: output feature dimension.
        merge:      if True, forward uses ``W + B @ A * scale``.
        rank:       LoRA rank; ``rank <= 0`` disables the LoRA branch.
        loar_alpha: scaling numerator; effective scale = loar_alpha / rank.
        dropout:    dropout probability applied to the layer OUTPUT
                    (note: reference LoRA applies it to the input of the
                    LoRA branch — kept as-is to preserve behavior);
                    ``<= 0`` means no dropout.
    """

    def __init__(self, infeature, outfeature, merge, rank=16, loar_alpha=16, dropout=0.5):
        super(Loar_LINER, self).__init__()
        self.infeature = infeature
        self.outfeature = outfeature
        self.merge = merge
        self.rank = rank
        self.loar_alpha = loar_alpha
        self.linear = nn.Linear(infeature, outfeature)
        if rank > 0:
            # B starts at zero so training begins from the unmodified
            # base weight; A gets a Kaiming init in initial_weight().
            self.loar_b = nn.Parameter(torch.zeros(outfeature, rank))
            self.loar_a = nn.Parameter(torch.zeros(rank, infeature))
            self.scale = self.loar_alpha / self.rank
            # Freeze the base weight: only the LoRA factors train.
            self.linear.weight.requires_grad = False
        # Use > 0: Dropout(0) is a no-op anyway, and this avoids first
        # storing the float probability in self.dropout only to
        # immediately overwrite it with a module.
        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = nn.Identity()
        self.initial_weight()

    def initial_weight(self):
        """Initialize the LoRA factors (A: Kaiming uniform, B: zeros)."""
        # BUG FIX: guard on rank — the original raised AttributeError
        # for rank <= 0 because loar_a / loar_b are never created.
        if self.rank > 0:
            nn.init.kaiming_uniform_(self.loar_a, a=math.sqrt(5))
            nn.init.zeros_(self.loar_b)

    def forward(self, x):
        # Merged path: one matmul against W_eff = W + (B @ A) * scale.
        # (@ and * are left-associative here: (B @ A) * scale.)
        if self.rank > 0 and self.merge:
            output = F.linear(
                x,
                self.linear.weight + self.loar_b @ self.loar_a * self.scale,
                self.linear.bias,
            )
        else:
            output = self.linear(x)
        return self.dropout(output)
# LoRA: simple code reproduction
# (blog scrape residue: "latest recommended article published 2024-09-14 19:18:15")