参考书籍:《深度学习框架PyTorch快速开发与实战》
1、导入常用包
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
2、设置超参数
# Training hyperparameters.
BATCH_SIZE = 50  # number of samples per mini-batch
EPOCH = 3        # number of passes over the training set
LR = 1e-3        # learning rate handed to the optimizer
3、数据集下载及处理
(转化为PyTorch可处理的Tensor格式)
# Build the MNIST training split first, then the test split. Both are stored
# under ./mnist (downloaded when requested via download=True) and use the
# ToTensor transform so samples come back as tensors.
train_data, test_data = (
    torchvision.datasets.MNIST(
        root='./mnist',
        train=is_train,
        transform=torchvision.transforms.ToTensor(),
        download=True,
    )
    for is_train in (True, False)
)

# Show the size of the raw image tensor of each split, one per line.
print(train_data.data.size(), test_data.data.size(), sep='\n')
DataLoader可以把数据集分割为batch_size大小的小批量(mini-batch)
# Serve the training set in shuffled mini-batches of BATCH_SIZE samples.
train_loader = Data.DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Build the evaluation tensors directly from the raw test split.
test_y = test_data.targets
# Insert a channel axis of size 1 -> (N, 1, 28, 28), then divide by 255
# (presumably the maximum raw pixel value) to rescale the inputs.
test_x = torch.true_divide(test_data.data.reshape(-1, 1, 28, 28), 255)
4、模型构建
class CNN(nn.Module):
    """Two-stage convolutional classifier for 28x28 images.

    Each stage is ``Conv2d(kernel_size=5, stride=1, padding=2) -> ReLU ->
    MaxPool2d(2)``. The padding keeps the spatial size through the convolution
    and every pooling halves it, so a 28x28 input reaches the linear layer as
    a 32 x 7 x 7 feature map. The linear layer's input size is fixed to that
    shape, so other input resolutions are not supported.

    Args:
        in_channels: Number of channels in the input images (default 1).
        num_classes: Number of scores produced per sample (default 10).
    """

    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        # Stage 1: in_channels -> 16 feature maps, 28x28 -> 14x14 after pooling.
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=2,
            ),  # bias is enabled by default
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # Stage 2: 16 -> 32 feature maps, 14x14 -> 7x7 after pooling.
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2),  # same argument order as in conv1
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Fully connected read-out over the flattened 32 x 7 x 7 features.
        self.out = nn.Linear(32 * 7 * 7, num_classes)

    def forward(self, x):
        """Return raw class scores of shape (batch, num_classes).

        Args:
            x: Input batch of shape (batch, in_channels, 28, 28).
        """
        x = self.conv1(x)
        x = self.conv2(x)
        # Keep the batch axis and flatten everything else for the linear layer.
        x = x.view(x.size(0), -1)
        return self.out(x)
5、实例化模型
# Model, loss and optimizer used by the training loop.
cnn = CNN()
loss_function = nn.CrossEntropyLoss()  # cross-entropy over the class scores
optimizer = torch.optim.Adam(params=cnn.parameters(), lr=LR)
6、训练模型、评价指标
# Train for EPOCH passes over train_loader; every 1000 steps report the loss
# of the current batch and the accuracy on the whole test set.
for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_loader):
        output = cnn(b_x)
        loss = loss_function(output, b_y)

        optimizer.zero_grad()  # clear gradients left over from the last step
        loss.backward()
        optimizer.step()       # apply the parameter update

        if step % 1000 == 0:
            # Evaluation needs no gradients; skipping them avoids building an
            # autograd graph over the full test set.
            with torch.no_grad():
                test_output = cnn(test_x)
            pred_y = test_output.argmax(dim=1)  # index of the highest score
            accuracy = torch.true_divide((pred_y == test_y).sum(), test_y.size(0))
            print('Epoch:', epoch, '|step:', step, '|train loss:%.4f' % loss.item(),
                  'test accuracy:%.4f' % accuracy)
7、看一下预测结果
# Compare the predictions for the first 20 test images with their labels.
# Inference only, so gradient tracking is switched off.
with torch.no_grad():
    test_output = cnn(test_x[:20])
# print(test_output)
pred_y = test_output.argmax(dim=1)  # index of the largest score per sample
print(pred_y[:20], 'prediction number')
print(test_y[:20], 'real number')