本博客主要介绍了:
1. 如何冻结模型的部分参数,使其不更新(不参与训练)。
2. 部分参数被冻结的模型,保存后应如何加载。
3. 如何拓展模型:在现有模型的基础上增加可训练参数,进行增量训练。
class MyModel(nn.Module):
    """BERT-based model with a frozen encoder and trainable user embeddings.

    Only the non-BERT parameters are trained, saved and loaded; the BERT
    weights are always restored from the pretrained checkpoint in __init__.
    """

    def __init__(self, num_user, pretrain):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrain)
        # Freeze BERT: its parameters never receive gradients, so only the
        # layers defined below participate in training.
        for param in self.bert.parameters():
            param.requires_grad = False
        self.text_pooler = nn.Linear(768, 768)
        self.user_emb = nn.Embedding(num_user, 768)
        self.linear1 = nn.Linear(768, 768 * 3)
        self.drop = nn.Dropout(0.3)
        self.linear2 = nn.Linear(768 * 3, 768)

    def extend_user(self, n):
        """Grow the user embedding table by ``n`` rows for incremental training.

        Existing user vectors are copied over unchanged; the new rows keep
        nn.Embedding's default random initialization.
        """
        old_weight = self.user_emb.weight.data
        num_user = self.user_emb.num_embeddings
        # Use the current embedding dim instead of a hard-coded 768 so the
        # method keeps working if the dimension ever changes.
        extended = nn.Embedding(num_user + n, self.user_emb.embedding_dim)
        # Match the device/dtype of the old table (e.g. when the model lives
        # on GPU); otherwise the row copy below would raise a device mismatch.
        extended = extended.to(device=old_weight.device, dtype=old_weight.dtype)
        with torch.no_grad():
            extended.weight.data[:num_user] = old_weight
        self.user_emb = extended

    def save(self, path):
        """Save only the trainable (non-BERT) parameters, keeping the file small."""
        state = self.state_dict()
        # Match the submodule prefix exactly; a plain substring test would
        # also drop unrelated parameters whose names happen to contain "bert".
        non_bert_params = {k: v for k, v in state.items() if not k.startswith('bert.')}
        torch.save(non_bert_params, path)

    def load(self, path):
        """Load the non-BERT parameters; BERT weights come from __init__.

        NOTE(review): this expects ``path`` to be a directory containing
        'model.bin', while save() writes directly to its ``path`` argument —
        confirm callers pass matching paths.
        """
        file_path = os.path.join(path, 'model.bin')
        # map_location='cpu' lets a checkpoint saved on GPU be restored on a
        # CPU-only machine; load_state_dict then copies the tensors onto the
        # devices of the existing parameters.
        non_bert_params = torch.load(file_path, map_location='cpu')
        # strict=False because the BERT keys are intentionally absent.
        self.load_state_dict(non_bert_params, strict=False)