Introduction to LoRA
- Paper: https://arxiv.org/abs/2106.09685
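In one sentence (a brief summary of the paper, in my own notation): LoRA freezes a pretrained weight matrix $W_0 \in \mathbb{R}^{d \times k}$ and learns a low-rank update $\Delta W = BA$, so the adapted layer computes

$$h = W_0 x + \frac{\alpha}{r} B A x, \qquad B \in \mathbb{R}^{d \times r},\ A \in \mathbb{R}^{r \times k},\ r \ll \min(d, k)$$

Only $A$ and $B$ are trained; $A$ is initialized randomly, $B$ is initialized to zero (so the update starts at zero), and $\alpha/r$ is a fixed scaling factor.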
How LoRA Is Implemented in Code
In this tutorial we use peft for LoRA fine-tuning. The key code that applies LoRA to the model is:
from transformers import AutoModelForCausalLM
from peft import get_peft_model, LoraConfig, TaskType

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct", device_map="auto", trust_remote_code=True)
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
)
model = get_peft_model(model, peft_config)
The LoRA injection happens inside get_peft_model. In this example, print(model) shows:
PeftModelForCausalLM(
(base_model): LoraModel(
(model): Qwen2ForCausalLM(
(model): Qwen2Model(
(embed_tokens): Embedding(151936, 1536)
(layers): ModuleList(
(0-27): 28 x Qwen2DecoderLayer(
(self_attn): Qwen2SdpaAttention(
(q_proj): lora.Linear(
(base_layer): Linear(in_features=1536, out_features=1536, bias=True)
(lora_dropout): ModuleDict(
(default): Dropout(p=0.1, inplace=False)
)
(lora_A): ModuleDict(
(default): Linear(in_features=1536, out_features=8, bias=False)
)
(lora_B): ModuleDict(
(default): Linear(in_features=8, out_features=1536, bias=False)
)
(lora_embedding_A): ParameterDict()
(lora_embedding_B): ParameterDict()
(lora_magnitude_vector): ModuleDict()
)
(k_proj): Linear(in_features=1536, out_features=256, bias=True)
(v_proj): lora.Linear(
(base_layer): Linear(in_features=1536, out_features=256, bias=True)
(lora_dropout): ModuleDict(
(default): Dropout(p=0.1, inplace=False)
)
(lora_A): ModuleDict(
(default): Linear(in_features=1536, out_features=8, bias=False)
)
(lora_B): ModuleDict(
(default): Linear(in_features=8, out_features=256, bias=False)
)
(lora_embedding_A): ParameterDict()
(lora_embedding_B): ParameterDict()
(lora_magnitude_vector): ModuleDict()
)
(o_proj): Linear(in_features=1536, out_features=1536, bias=False)
(rotary_emb): Qwen2RotaryEmbedding()
)
(mlp): Qwen2MLP(
(gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
(up_proj): Linear(in_features=1536, out_features=8960, bias=False)
(down_proj): Linear(in_features=8960, out_features=1536, bias=False)
(act_fn): SiLU()
)
(input_layernorm): Qwen2RMSNorm()
(post_attention_layernorm): Qwen2RMSNorm()
)
)
(norm): Qwen2RMSNorm()
)
(lm_head): Linear(in_features=1536, out_features=151936, bias=False)
)
)
)
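Besides injecting the adapters, get_peft_model freezes all of the original weights, so only the new lora_A / lora_B matrices receive gradients. A quick way to verify this (my own sketch; peft's model.print_trainable_parameters() prints a similar one-line summary):

trainable, total = 0, 0
for name, param in model.named_parameters():
    total += param.numel()
    if param.requires_grad:
        # only names containing "lora_A" / "lora_B" should show up here
        trainable += param.numel()
print(f"trainable params: {trainable} / {total} ({100 * trainable / total:.2f}%)")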
At this point, q_proj and v_proj have been turned into lora.Linear layers. The source implementation is here:
https://github.com/huggingface/peft/blob/900f96c40ddebae9d76bed374c8baed60e8b34e9/src/peft/tuners/lora/layer.py#L374
To make the implementation details easier to read, I have extracted the key parts:
import torch
from torch import nn
import math
class Lora_Linear(nn.Module):
    def __init__(
        self,
        base_layer,
        r: int = 0,              # rank of the update; must be > 0
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.base_layer = base_layer  # the frozen pretrained nn.Linear being wrapped
        in_features, out_features = base_layer.in_features, base_layer.out_features
        if lora_dropout > 0.0:
            self.dropout = nn.Dropout(p=lora_dropout)
        else:
            self.dropout = nn.Identity()
        # low-rank update: delta_W = B @ A with rank r, scaled by alpha / r
        self.lora_A = nn.Linear(in_features, r, bias=False)
        self.lora_B = nn.Linear(r, out_features, bias=False)
        self.scaling = lora_alpha / r
        self.reset_lora_parameters()

    def reset_lora_parameters(self):
        # A gets a random init, B starts at zero, so the LoRA branch initially adds nothing
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        result = self.base_layer(x)
        result = result + self.lora_B(self.lora_A(self.dropout(x))) * self.scaling
        return result
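A quick sanity check of this simplified class (my own usage example, not peft code): because lora_B is zero-initialized, a freshly wrapped layer reproduces the base layer exactly, and only training moves it away from the pretrained behavior:

base = nn.Linear(1536, 256, bias=True)   # same shape as v_proj in Qwen2-1.5B
layer = Lora_Linear(base, r=8, lora_alpha=32, lora_dropout=0.0)

x = torch.randn(2, 10, 1536)             # (batch, seq_len, hidden)
with torch.no_grad():
    assert torch.allclose(layer(x), base(x))  # B == 0, so the LoRA branch adds nothing yet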
So how does peft know where in the model to insert these LoRA layers? You can specify it explicitly via target_modules in the config, for example:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct", device_map="auto", trust_remote_code=True)
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
model = get_peft_model(model, peft_config)
print(model)
The output now looks like this:
PeftModelForCausalLM(
(base_model): LoraModel(
(model): Qwen2ForCausalLM(
(model): Qwen2Model(
(embed_tokens): Embedding(151936, 1536)
(layers): ModuleList(
(0-27): 28 x Qwen2DecoderLayer(
(self_attn): Qwen2SdpaAttention(
(q_proj): lora.Linear(
(base_layer): Linear(in_features=1536, out_features=1536, bias=True)
(lora_dropout): ModuleDict(
(default): Dropout(p=0.1, inplace=False)
)
(lora_A): ModuleDict(
(default): Linear(in_features=1536, out_features=8, bias=False)
)
(lora_B): ModuleDict(
(default): Linear(in_features=8, out_features=1536, bias=False)
)
(lora_embedding_A): ParameterDict()
(lora_embedding_B): ParameterDict()
(lora_magnitude_vector): ModuleDict()
)
(k_proj): lora.Linear(
(base_layer): Linear(in_features=1536, out_features=256, bias=True)
(lora_dropout): ModuleDict(
(default): Dropout(p=0.1, inplace=False)
)
(lora_A): ModuleDict(
(default): Linear(in_features=1536, out_features=8, bias=False)
)
(lora_B): ModuleDict(
(default): Linear(in_features=8, out_features=256, bias=False)
)
(lora_embedding_A): ParameterDict()
(lora_embedding_B): ParameterDict()
(lora_magnitude_vector): ModuleDict()
)
(v_proj): lora.Linear(
(base_layer): Linear(in_features=1536, out_features=256, bias=True)
(lora_dropout): ModuleDict(
(default): Dropout(p=0.1, inplace=False)
)
(lora_A): ModuleDict(
(default): Linear(in_features=1536, out_features=8, bias=False)
)
(lora_B): ModuleDict(
(default): Linear(in_features=8, out_features=256, bias=False)
)
(lora_embedding_A): ParameterDict()
(lora_embedding_B): ParameterDict()
(lora_magnitude_vector): ModuleDict()
)
(o_proj): lora.Linear(
(base_layer): Linear(in_features=1536, out_features=1536, bias=False)
(lora_dropout): ModuleDict(
(default): Dropout(p=0.1, inplace=False)
)
(lora_A): ModuleDict(
(default): Linear(in_features=1536, out_features=8, bias=False)
)
(lora_B): ModuleDict(
(default): Linear(in_features=8, out_features=1536, bias=False)
)
(lora_embedding_A): ParameterDict()
(lora_embedding_B): ParameterDict()
(lora_magnitude_vector): ModuleDict()
)
(rotary_emb): Qwen2RotaryEmbedding()
)
(mlp): Qwen2MLP(
(gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
(up_proj): Linear(in_features=1536, out_features=8960, bias=False)
(down_proj): Linear(in_features=8960, out_features=1536, bias=False)
(act_fn): SiLU()
)
(input_layernorm): Qwen2RMSNorm()
(post_attention_layernorm): Qwen2RMSNorm()
)
)
(norm): Qwen2RMSNorm()
)
(lm_head): Linear(in_features=1536, out_features=151936, bias=False)
)
)
)
As you can see, q_proj, k_proj, v_proj, and o_proj have all been wrapped with LoRA.
If target_modules is not passed (as in the first example above), peft fills it with a built-in per-architecture default; the defaults are defined here:
https://github.com/huggingface/peft/blob/900f96c40ddebae9d76bed374c8baed60e8b34e9/src/peft/utils/constants.py#L78
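If you want to check which modules end up being targeted, one option (my own sketch; the mapping below is a peft-internal constant, so its name and keys may differ across versions) is to inspect the resolved config on the wrapped model or look up the built-in defaults directly:

# modules that were actually targeted on the wrapped model (adapter name "default")
print(model.peft_config["default"].target_modules)

# built-in per-architecture defaults (internal constant, may change between peft versions)
from peft.utils.constants import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
print(TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.get("qwen2"))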
LoRA Fine-tuning in Practice
Here is a complete example of fine-tuning Qwen with LoRA so that you can change Qwen's self-reported identity:
https://github.com/yinpu/llm-tutorial/tree/main/1.Lora%E5%BE%AE%E8%B0%83qwen2
Final result:
You can train it yourself and replace the identity/author information with whatever you want Qwen to answer.
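After training, a common workflow (standard peft API, not specific to the linked repo; the adapter path below is just a placeholder) is to save only the small adapter weights and re-attach them to the base model at inference time:

# save only the LoRA adapter weights (a few MB), not the full model
model.save_pretrained("qwen2-lora-adapter")

# later: reload the base model and attach the trained adapter
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct", device_map="auto", trust_remote_code=True)
model = PeftModel.from_pretrained(base, "qwen2-lora-adapter")

# optionally merge the LoRA weights into the base weights for deployment
model = model.merge_and_unload()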