保存模型或权重参数的后缀问题:
pytorch保存数据的格式为.t7文件或者.pth文件,或者.pkl格式,t7文件是沿用torch7中读取模型权重的方式。而pth文件是python中存储文件的常用格式。而在keras中则是使用.h5文件。
来自:https://blog.csdn.net/weixin_43216883/article/details/89792312
两种方式:
(1)保存模型参数
#保存
torch.save( model.state_dict(), path)
#加载
the_model = CNN()
the_model.load_state_dict(torch.load(path))
这种方法在加载模型的时候,必须在代码中重新将CNN的结构重新建立一遍,才可以将保存好的参数(w和b)放在模型中,进行训练用。
下面是利用上述方法,加载保存好的模型的具体代码,实现随便一张图片的识别(利用的手写体的数据集),CNN结构是用的莫烦里面的结构。
from PIL import Image
import torch.nn as nn
import torch
import numpy as np
def img2vec_img(img):
#将jpg等格式的图片转为向量
im = Image.open(img)
im = im.resize((28,28))
tmp = np.array(im)
#vec = tmp.ravel()
return tmp
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(
in_channels=1,
out_channels=16,
kernel_size=5,
stride=1,
padding=2,
),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
)
self.conv2 = nn.Sequential(
nn.Conv2d(16, 32, 5, 1, 2),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Dropout(0.1)
)
self.out = nn.Linear(32 * 7 * 7, 10)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1)
output = self.out(x)
return output
#模型加载
the_model = CNN()
the_model.load_state_dict(torch.load('C:\\Users\\happy\\Desktop\\train\\model.pth'))
def detect(path):
#the_model.eval()
test_picture = img2vec_img(path)
data = torch.from_numpy(test_picture).type ( 'torch.FloatTensor' )
data = torch.unsqueeze(data, 0) #给torch更加一个维度,以便于训练
data = torch.unsqueeze(data, 0) #给torch更加一个维度,以便于训练
the_model(data)
_, pred = torch.max(the_model(data) , 1)
#print(pred.int())
print(pred.int())
detect('C:\\Users\\happy\\Desktop\\t.bmp')
(2)保存模型整体
#保存
torch.save(model,path)
#加载
the_model = torch.load(path)