搭建LeNet网络—代码详解
首先是构建模型的文件model.py
根据上面的结构图,有如下代码:
import torch.nn as nn
import torch.nn.functional as F
class LeNet(nn.Module):
    """LeNet-style CNN for 3-channel 32x32 images with 10 output classes.

    Two conv + max-pool stages are followed by three fully connected layers.
    ``forward`` returns raw logits (no softmax is applied here).
    """

    def __init__(self):
        """Create the convolution, pooling and fully connected layers."""
        super().__init__()
        # Conv2d arguments are (in_channels, out_channels, kernel_size).
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        # 32 channels of 5x5 feature maps remain after the second pooling.
        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Run a forward pass.

        Args:
            x: input images; the per-sample shape comments below assume
                (3, 32, 32) inputs.

        Returns:
            Tensor with 10 logits per sample.
        """
        x = F.relu(self.conv1(x))    # input (3, 32, 32) -> output (16, 28, 28)
        x = self.pool1(x)            # output (16, 14, 14)
        x = F.relu(self.conv2(x))    # output (32, 10, 10)
        x = self.pool2(x)            # output (32, 5, 5)
        x = x.view(-1, 32 * 5 * 5)   # flatten to (32*5*5,) per sample
        x = F.relu(self.fc1(x))      # output (120)
        x = F.relu(self.fc2(x))      # output (84)
        x = self.fc3(x)              # output (10)
        return x
然后是用于训练的文件train.py
代码如下:
import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
# Preprocessing pipeline.
# ToTensor converts an (H, W, C) image with values in 0-255 into a
# (C, H, W) tensor with values in 0-1. Normalize then applies
# (x - mean) / std per channel; its two arguments are the per-channel mean
# and standard deviation (not the variance).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# 50000 training images from the CIFAR10 dataset bundled with torchvision.
# It is downloaded on first use; set download=False once it is on disk.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
# Load the training data 36 images per batch. On Linux num_workers may be
# set to a value greater than 0.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=36, shuffle=True, num_workers=0
)

# Load the 10000 test images the same way.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=False, transform=transform
)
# Test images are loaded 1000 per batch; only the first batch is taken below
# and used as a fixed evaluation set.
testloader = torch.utils.data.DataLoader(
    testset, batch_size=1000, shuffle=False, num_workers=0
)
# Turn the loader into an iterator and pull the first batch. The builtin
# next() is used because loader iterators no longer expose a .next() method.
test_data_iter = iter(testloader)
test_image, test_label = next(test_data_iter)

# CIFAR10 class names indexed by label id (a tuple, hence immutable).
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
# Display an image
def imshow(img):
    """Show an image tensor in a matplotlib window.

    Args:
        img: tensor exposing ``.numpy()``, laid out channel-first. It is
            rescaled with ``img / 2 + 0.5``, which undoes a normalization
            with mean 0.5 and std 0.5.
    """
    img = img / 2 + 0.5  # undo the normalization
    npimg = img.numpy()
    # Move the first axis last, since plt.imshow expects (H, W, C).
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
# Print the labels of the first four test images
print(''.join('%5s' % classes[test_label[j]] for j in range(4)))
# Show those same four images, so the grid matches the printed labels
# instead of tiling every image in the batch.
imshow(torchvision.utils.make_grid(test_image[:4]))
net = LeNet()
loss_function = nn.CrossEntropyLoss()  # this loss already includes softmax
optimizer = optim.Adam(net.parameters(), lr=0.001)

# Training loop
for epoch in range(5):
    running_loss = 0.0
    for step, data in enumerate(trainloader, start=0):  # step index starts at 0
        inputs, labels = data

        # Reset the gradients left over from the previous step
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = loss_function(outputs, labels)  # compute the loss
        loss.backward()                        # back-propagate
        optimizer.step()                       # update the parameters

        # Accumulate statistics
        running_loss += loss.item()
        if step % 500 == 499:  # evaluate on the test batch every 500 training batches
            with torch.no_grad():  # no gradients are needed for evaluation
                test_outputs = net(test_image)
                predict_y = torch.max(test_outputs, dim=1)[1]  # index of the largest logit
                accuracy = (predict_y == test_label).sum().item() / test_label.size(0)
            print('[%d,%5d] train_loss:%.3f test_accuracy:%.3f' % (epoch + 1, step + 1, running_loss / 500, accuracy))
            running_loss = 0.0

print('Finished Training')

save_path = './Lenet.pth'
# Save only the parameters so they can be loaded directly later on
torch.save(net.state_dict(), save_path)
保存好模型参数之后,会多出一个Lenet.pth的文件,我们自己下载一张图片,这里下载一张飞机的图片用来做预测
代码如下:
import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNet
# Same preprocessing as in training, preceded by a resize to the 32x32
# input size the network expects.
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

net = LeNet()
net.load_state_dict(torch.load('Lenet.pth'))  # load the trained parameters

# Open the downloaded image used for the prediction. Converting to RGB
# guarantees the 3 channels conv1 expects, even for grayscale or RGBA files.
with Image.open('1.jpg') as raw_image:
    im = raw_image.convert('RGB')
im = transform(im)                 # apply the pipeline: (C, H, W) tensor
im = torch.unsqueeze(im, dim=0)    # add the batch dimension: (1, C, H, W)

with torch.no_grad():
    outputs = net(im)
    # Index of the largest logit, as a plain Python int
    predict = torch.max(outputs, dim=1)[1].item()
print(classes[predict])