- Pytorch搭建简单神经网络完成手写数字识别,其中利用到的知识点如下:
(1)用torch.nn包里面的函数搭建网络
(2)Torchvision.transofrms来做数据预处理
(3)DataLoader简单调用处理数据集
(4)模型保存为pt文件与加载调用
下面看完整的代码:
"""
1.使用torch.nn包里面的搭建网络
2.Torchvision.transofrms来做数据处理
3.DataLoader简单调用处理数据集
4.模型保存为pt文件与加载调用
"""
import torch as t
from torch.utils.data import DataLoader
import torchvision as tv
#预处理数据方式
##ToTensor表示把灰度图像像素值从0-255转为0-1之间
##Normalize表示把输入的减去0.5,除以0.5
transform = tv.transforms.Compose([tv.transforms.ToTensor(),
tv.transforms.Normalize((0.5,),(0.5,)),
])
#数据集(Mnist数据集,数字为0-9,大小为28*28的灰度图像)
##加载此数据集,这里pytorch已经写好调用接口,直接运行,没有此数据集的话,会自动下载
train_ts = tv.datasets.MNIST(root='./data',train=True,download=True,transform=transform)
test_ts = tv.datasets.MNIST(root='./data',train=False,download=True,transform=transform)
train_dl = DataLoader(train_ts,batch_size=32,shuffle=True,drop_last=False)
test_dl = DataLoader(test_ts,batch_size=64,shuffle=True,drop_last=False)
#网络结构
##输入层:784个神经单元,隐藏层:100个神经单元,输出层:10个神经单元
model = t.nn.Sequential(
t.nn.Linear(784,100),
t.nn.ReLU(),
t.nn.Linear(100,10),
t.nn.LogSoftmax(dim=1)
)
#定义损失函数与优化函数
loss_fn = t.nn.NLLLoss(reduction='mean')
optimizer = t.optim.Adam(model.parameters(),lr=1e-3)
#开启训练
for s in range(5):
print ('run in step :{}'.format(s))
for i,(x_train,y_train) in enumerate(train_dl):
x_train = x_train.view(x_train.shape[0],-1)
y_pred = model(x_train)
train_loss = loss_fn(y_pred,y_train)
if (i+1)%100 == 0:
print (i+1,train_loss.item())
model.zero_grad()
train_loss.backward()
optimizer.step()
#测试模型准确性
total = 0
correct_count = 0
for test_images,test_labels in test_dl:
for i in range(len(test_labels)):
image = test_images[i].view(1,784)
with t.no_grad():
pred_labels = model(image)
plabels = t.exp(pred_labels)
probs = list(plabels.numpy()[0])
pred_label = probs.index(max(probs))
true_label = test_labels.numpy()[i]
if pred_label == true_label:
correct_count += 1
total += 1
#打印准确率,保存模型
print('total acc:{}'.format(correct_count/total))
t.save(model,'./nn_mnist_model.pt')
运行结果:
1100 0.11727560311555862
1200 0.08096164464950562
1300 0.23481614887714386
1400 0.017000077292323112
1500 0.03314385190606117
1600 0.10333727300167084
1700 0.04443451762199402
1800 0.024772992357611656
total acc:0.9642