# Pytorch学习系列-03-构建卷积神经网络实现手写数字识别(Mnist数据集)

• 涉及到主要知识点是如何使用torch.nn.Module这个基类来实现一个网络结构定义。这个基类中最重要的是实现自己的forward方法，这个也是自定义网络结构的实现方法。
• 保存了模型之后，还可以转化为ONNX格式，把模型送给OpenCV DNN模块调用。

1.网络训练、验证及模型导出与转换代码：

import torch
from torch.utils.data import DataLoader
import torchvision

#预处理数据
transfrom = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.5,),(0.5,))])

#加载数据集

#构建卷积升级网络
#in_channels表示输入通道数目
#out_channels表示输出通道数目
class CNN_NET(torch.nn.Module):
def __init__(self):
super(CNN_NET,self).__init__()
self.cnn_layers = torch.nn.Sequential(
torch.nn.MaxPool2d(kernel_size=2,stride=2),
torch.nn.ReLU(),
torch.nn.MaxPool2d(kernel_size=2,stride=2),
torch.nn.ReLU(),
)
self.fc_layers = torch.nn.Sequential(
torch.nn.Linear(7*7*32,200),
torch.nn.ReLU(),
torch.nn.Linear(200,100),
torch.nn.ReLU(),
torch.nn.Linear(100,10),
torch.nn.LogSoftmax(dim=1)
)

def forward(self,x):
out = self.cnn_layers(x)
out = out.view(-1,7*7*32)
out = self.fc_layers(out)
return out

model = CNN_NET()
#损失函数与优化器
loss_fn = torch.nn.CrossEntropyLoss()

#训练5个epoch
for s in range(5):
print ('run in step:{}'.format(s))
for i,(x_train,y_train) in enumerate(train_dl):
y_pred = model.forward(x_train)
train_loss = loss_fn(y_pred,y_train)
if (i+1)%100 == 0:
print (i+1,train_loss.item())
train_loss.backward()
optimizer.step()

total = 0
correct_count = 0

#模型评估
model.eval()
for test_images,test_labels in test_dl:
pred_labels = model(test_images)
predicted = torch.max(pred_labels,1)[1]
correct_count  += (predicted == test_labels).sum()
total += len(test_labels)
print (correct_count.detach().numpy(),total)

print ('total acc:{}'.format(correct_count.detach().numpy()/total))
torch.save(model,'./cnn_mnist_model.pt')

#转为ONNX格式
dummy_input = torch.randn(1,1,28,28)
torch.onnx.export(model,dummy_input,'cnn_mnist.onnx',verbose=True)


2.利用OpenCV DNN调用导出的模型：

import cv2
import numpy as np

cv2.imshow('img',img)
blob = cv2.dnn.blobFromImage(img,0.00392,(28,28),(0.5))/0.5
print (blob.shape)
mnist_net.setInput(blob)
result = mnist_net.forward()
pred_label = np.argmax(result,1)
print ('predit label:{}'.format(pred_label))
cv2.waitKey(0)
cv2.destroyAllWindows()


• 点赞
• 评论
• 分享
x

海报分享

扫一扫，分享海报

• 收藏
• 手机看

分享到微信朋友圈

x

扫一扫，手机阅读

• 打赏

打赏

骚火棍

你的鼓励将是我创作的最大动力

C币 余额
2C币 4C币 6C币 10C币 20C币 50C币
• 一键三连

点赞Mark关注该博主, 随时了解TA的最新博文

01-10 471
01-24 768
04-30 2933
04-20 320
02-06 283
©️2020 CSDN 皮肤主题: 技术黑板 设计师:CSDN官方博客