点击上方
“小白学视觉
”,选择加"星标"或“置顶”
重磅干货,第一时间送达
关键知识点
前面我们刚刚组队完毕,更新了第一篇,我说我会坚持写下去,这个是我的第二篇,使用pytorch实现简单神经网络完成手写数字识别。这个是所有深度学习框架入门标配的例子,但是从这个例子上我们可以学到pytorch的很多基础知识点,我罗列一下,大致有如下:
1.开始用torch.nn包里面的函数搭建网络
2.模型保存为pt文件与加载调用
3.Torchvision.transofrms来做数据预处理
4.DataLoader简单调用处理数据集
只有理解和看清以上四点才算入门了这个例子。
数据集:
Mnist数据集,数字为0~9、大小为28x28的灰度图像。
加载数据集代码实现:
train_ts = tv.datasets.MNIST(root='./data', train=True, download=True, transform=transform) test_ts = tv.datasets.MNIST(root='./data', train=False, download=True, transform=transform) train_dl = DataLoader(train_ts, batch_size=32, shuffle=True, drop_last=False) test_dl = DataLoader(test_ts, batch_size=64, shuffle=True, drop_last=False)
预处理数据方式
transform = tv.transforms.Compose([tv.transforms.ToTensor(),tv.transforms.Normalize((0.5,), (0.5,)),])
其中
Totensor表示把灰度图像素值从0~255转化为0~1之间
Normalize表示对输入的减去0.5, 除以0.5
网络结构如下:
输入层:784个神经元
隐藏层:100个神经元
输出层:10个神经元
model = t.nn.Sequential( t.nn.Linear(784, 100), t.nn.ReLU(), t.nn.Linear(100, 10), t.nn.LogSoftmax(dim=1) )
定义损失函数与优化函数
"mean")
开启训练
for s
测试模型准确率
0;
打印准确率与保存模型
print("total acc : %.2f\n"%(correct_count / total)) t.save(model, './nn_mnist_model.pt')
完整演示代码
import torch
运行结果: