手写数字识别 python pytorch 二 之手写数字预测
上一篇: 手写数字识别 python pytorch 一 之手写数字预测
直接在上一篇代码下复制或书写
本片和上一篇代码直接复制即可运行
注意
手写数字必须在图片中间,预测效果才好
预处理手写数字函数:
imread()
def shouxiefix(path_tu):
#path_tu = 'D:\\002.png'
#这个函数可以看我这篇文章的详细介绍 ref https://blog.csdn.net/weixin_44021118/article/details/107769561)
img = cv2.imread(path_tu,0)# 蓝95 绿200 红147通道数(-1) 数组 高 宽 通道459 362 3
img = cv2.resize(img,(28,28))
"""cv2.imshow('img',img)#查看读入图片
cv2.waitKey(0)#按任意键继续
cv2.destroyAllWindows()#关闭图片窗口"""
#gao=img.shape[0]
#kuan=img.shape[1]
gao=28#指定
kuan=28#
img=cv2.resize(img,(kuan,gao))#看图片上的 kuan gao
#tongdao=img.shape[2]
img = img.astype(np.float32)
#进行归一化
if img[1][1]>100:#如果是白底黑字 转成黑底白字 并进行归一化
for i in range(gao):#tensor 蓝95 绿200 红 147 ->红 绿 蓝
for j in range(kuan):
if img[i][j]<100:
img[i][j] = 0.95
else:
img[i][j] = -1
else:#否者 进行归一化
for i in range(gao):#tensor 蓝95 绿200 红 147 ->红 绿 蓝
for j in range(kuan):
if img[i][j]>100:
img[i][j]=0.95
else:
img[i][j]=-1
"""cv2.imshow('img2',img)#查看归一化后的图片
cv2.waitKey(0)#按任意键继续
cv2.destroyAllWindows()"""
torch_img = torch.from_numpy(img)#将ndarray 数组转成Tensor
torch_img = torch .squeeze(torch_img)#函数可以删除数组形状中的单维度条目
torch_img = torch_img.unsqueeze(0)#给第0维添加一个维度变为三维
torch_img = torch_img.unsqueeze(0)#给第0维添加一个维度变为四维
torch_img=torch_img.float()#将unit8类型转为float32
return torch_img
读取模型
model.conv1.load_state_dict(torch.load('model.pkl')['conv1'])
model.dense.load_state_dict(torch.load('model.pkl')['dense'])
测试代码
#验证单个图片
#输入 使用cpu预测
#model = model.cpu()
x_test=shouxiefix(path_tu= '006.png')#图片地址'006.png' 和这个文件一个文件夹
#x_test = x_test.cpu()
x_test = Variable(x_test)#将Tensor 转为 模型输入类型
y_pred = model(x_test.data)
_, pred = torch.max(y_pred,1)
#print(y_pred)
print(int(pred[0]))
输出预测结果
6
输入图片