pytorch保存准确率_Pytorch轻松学构建浅层神经网络

621ec440e08cf5178ebb5e96899b5868.gif

点击上方蓝字关注我们

微信公众号:OpenCV学堂

关注获取更多计算机视觉与深度学习知识

关键知识点

前面我们刚刚组队完毕,更新了第一篇,我说我会坚持写下去,这个是我的第二篇,使用pytorch实现简单神经网络完成手写数字识别。这个是所有深度学习框架入门标配的例子,但是从这个例子上我们可以学到pytorch的很多基础知识点,我罗列一下,大致有如下:

1.开始用torch.nn包里面的函数搭建网络
2.模型保存为pt文件与加载调用
3.Torchvision.transofrms来做数据预处理
4.DataLoader简单调用处理数据集

只有理解和看清以上四点才算入门了这个例子。

数据集:

Mnist数据集,数字为0~9、大小为28x28的灰度图像。

加载数据集代码实现:

train_ts = tv.datasets.MNIST(root='./data', train=True, download=True, transform=transform) test_ts = tv.datasets.MNIST(root='./data', train=False, download=True, transform=transform) train_dl = DataLoader(train_ts, batch_size=32, shuffle=True, drop_last=False) test_dl = DataLoader(test_ts, batch_size=64, shuffle=True, drop_last=False)

预处理数据方式

transform = tv.transforms.Compose([tv.transforms.ToTensor(),tv.transforms.Normalize((0.5,), (0.5,)),])

其中

Totensor表示把灰度图像素值从0~255转化为0~1之间

Normalize表示对输入的减去0.5, 除以0.5

网络结构如下:

8c640df1673bc7ba353405d90e3bed00.png

输入层:784个神经元

隐藏层:100个神经元

输出层:10个神经元

model = t.nn.Sequential(     t.nn.Linear(784, 100),     t.nn.ReLU(),     t.nn.Linear(100, 10),     t.nn.LogSoftmax(dim=1) )

定义损失函数与优化函数

8c640df1673bc7ba353405d90e3bed00.png
"mean")

开启训练

8c640df1673bc7ba353405d90e3bed00.png
for s 

测试模型准确率

8c640df1673bc7ba353405d90e3bed00.png
0;

打印准确率与保存模型

8c640df1673bc7ba353405d90e3bed00.png
print("total acc : %.2f\n"%(correct_count / total)) t.save(model, './nn_mnist_model.pt')

完整演示代码

import torch 

运行结果:

02311dfa55d96b6066c21434f9bc3377.png

 推荐阅读 

轻松学Pytorch–环境搭建与基本语法

不积硅步,无以至千里

d0c8c649292eb649cde698c25bf69a5d.png

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值