一段Lora代码示例:
class LoRALayer(nn.Module):
def __init__(self, input_dim, low_rank_dim):
super(LoRALayer, self).__init__()
self.U = nn.Linear(input_dim, low_rank_dim, bias=False)
self.B = nn.Linear(low_rank_dim, input_dim, bias=False)
def forward(self, x):
return x + self.B(self.U(x))
Lora的原理图如图所示
由此可知,Lora模型其实就相当于你在指定的两层大模型之间,新增了一层带计算的残差结构,这样的话可以利用打Patch的方式,去对原始模型进行功能扩展,而仅需要如下基础信息:
- Lora的模型参数
- Lora的相对于原始模型的配置位置
这种方式其实是改变了原始模型结构,虽然改动不大,但是对于一些基于Graph加速的编译器来说,会影响这些模型的编译。