问题:使用pytorch训练加载预训练模型会修改网络传入的参数,例如类别等,修改后的网络加载预训练网络则会出现维度不匹配的错误,例如如下
size mismatch for word_embeddings.weight: copying a param with shape torch.Size([3403, 128]) from checkpoint, the shape in current model is torch.Size([12386, 128]).
在不改变原来网络结构的条件下,通过修改预训练模型中的参数,修改维度大小,从而似其适合网络的加载。
解决:
def change_feature(check_point, num_class):
device = torch.device("cuda" if torch.cuda.is_available() else