PyTorch:加载预训练模型中的部分参数,并固定(冻结)该部分参数
https://www.jianshu.com/p/d67d62982a24
# Build a DenseNet-121 feature extractor, wrap it for multi-GPU use, and
# initialise it from a checkpoint. Only checkpoint entries whose names AND
# shapes match the current model are copied; every other parameter keeps its
# fresh initialisation.
import warnings

initial_cnn = models.densenet121(pretrained=False)
# Keep every child module except the last one. nn.Sequential(*modules) names
# its children "0", "1", ..., so parameter names here differ from those of the
# original torchvision model. NOTE(review): a checkpoint saved from the full
# torchvision DenseNet will therefore not match by name — confirm that
# epoch_11.pth was saved from this same Sequential.
self.cnn = torch.nn.Sequential(*(list(initial_cnn.children())[:-1]))
device = torch.device('cuda')
cnn = self.cnn.to(device)  # Module.to() moves in place and returns the same module
cnn = nn.DataParallel(cnn)
# DataParallel prefixes every state_dict key with "module.".
model_dict = cnn.state_dict()

checkpoint_path = 'epoch_11.pth'
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
# map_location puts the tensors on the target device regardless of the device
# they were saved from.
raw_pretrained = torch.load(checkpoint_path, map_location=device)

pretrained_dict = {}  # checkpoint entries that will actually be loaded
skipped = []          # checkpoint keys with no same-named, same-shaped target
for ckpt_key, value in raw_pretrained.items():
    name = ckpt_key
    if name not in model_dict and 'module.' + name in model_dict:
        # Checkpoint was saved without the DataParallel wrapper; without this
        # remapping no key would match and nothing would be loaded.
        name = 'module.' + name
    target = model_dict.get(name)
    # A same-named tensor with a different shape (e.g. a resized layer) would
    # make load_state_dict raise, so it is skipped instead of loaded.
    if target is None or getattr(value, 'shape', None) != target.shape:
        skipped.append(ckpt_key)
        continue
    pretrained_dict[name] = value

if not pretrained_dict:
    warnings.warn(f'{checkpoint_path}: no entries matched the model; nothing was loaded')
elif skipped:
    warnings.warn(
        f'{checkpoint_path}: skipped {len(skipped)} unmatched entries, e.g. {skipped[:5]}'
    )

# Start from the model's own state and overwrite only the matched entries, so
# load_state_dict sees a complete dict.
model_dict.update(pretrained_dict)
cnn.load_state_dict(model_dict)

# Optionally freeze the parameters that came from the checkpoint. This is off
# by default so training behaviour is unchanged. If enabled, build the
# optimizer from [p for p in cnn.parameters() if p.requires_grad].
# BatchNorm running statistics are buffers, not parameters, so they still
# update in train() mode.
freeze_loaded = False
if freeze_loaded:
    for param_name, param in cnn.named_parameters():
        if param_name in pretrained_dict:
            param.requires_grad = False