比如说加载训练好的cnn模型,我把下面这段复制到新的模型中,再加载就没有报错了,不知道对不对。
class CNN(nn.Module):
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
tsne_feature = x
x = self.flatten(x)
return self.linear(x), tsne_feature
cnn=torch.load('model.pkl')
Net类继承nn.Module,super(Net, self).__init__()就是对继承自父类nn.Module的属性进行初始
化。而且是用nn.Module的初始化方法来初始化继承的属性。
另外记录保存模型和加载模型的代码:
torch.save(cnn, 'model.pkl') # 保存整个模型
torch.save(cnn.state_dict(), 'model_params.pth') # 只保存网络中的参数
new_model = torch.load('./data/model.pkl') # 加载模型
torch.load_state_dict(torch.load('model_params.pth'))#提取所有的参数