1. 网络结构的搭建（model.py）
import torch.nn as nn
import torch.nn.functional as F
# LeNet-style CNN (subclass of nn.Module) for 3-channel 32x32 images.
#
# Spatial output size of a conv / pool layer:
#     N = floor((W - F + 2P) / S) + 1
# where W is the input size (W x W), F the effective filter size, P the
# padding and S the stride. With dilation: F = dilation * (kernel_size - 1) + 1.
class LeNet(nn.Module):
    """LeNet-5 variant: two conv + max-pool stages, then three linear layers.

    Input is expected to be (batch, 3, 32, 32): the ``32 * 5 * 5`` input size of
    ``fc1`` only matches that spatial resolution. ``forward`` returns raw,
    unnormalized logits of shape (batch, 10); no softmax is applied.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1, padding=0)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Map a (batch, 3, 32, 32) tensor to (batch, 10) class logits."""
        # conv1: F = 1 * (5 - 1) + 1 = 5, N = (32 - 5 + 2 * 0) / 1 + 1 = 28 -> (16, 28, 28)
        x = F.relu(self.conv1(x))
        x = self.pool1(x)  # (16, 14, 14)
        # conv2: N = (14 - 5) / 1 + 1 = 10 -> (32, 10, 10)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)  # (32, 5, 5)
        # Flatten each sample but keep the batch dimension fixed. Unlike
        # view(-1, 32*5*5), a wrongly sized input cannot be silently re-batched
        # into extra rows; the mismatch surfaces as a shape error in fc1.
        x = x.flatten(1)  # (32 * 5 * 5,)
        x = F.relu(self.fc1(x))  # (120,)
        x = F.relu(self.fc2(x))  # (84,)
        x = self.fc3(x)  # (10,)
        return x
2. 模型的训练
import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
def main():
    """Train LeNet on CIFAR-10 and save its weights to ./Lenet.pth.

    The dataset is read from ./data with download disabled (set download=True
    on the first run so torchvision fetches it). Training runs for 5 epochs with
    Adam (lr=0.001) and batch size 36. Every 500 mini-batches it prints the mean
    training loss over those batches and the accuracy on the 10,000-image test
    split, which is used here as the validation set.
    """
    # ToTensor scales pixels to [0, 1]; Normalize with mean = std = 0.5 maps them to [-1, 1].
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # CIFAR-10 training split: 50,000 images.
    # On first use, set download=True so the dataset is downloaded automatically.
    train_set = torchvision.datasets.CIFAR10(
        root='./data',
        train=True,
        download=False,
        transform=transform,
    )
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=36,
        shuffle=True,
        num_workers=0,
    )

    # CIFAR-10 test split: 10,000 images, used for validation.
    val_set = torchvision.datasets.CIFAR10(
        root='./data',
        train=False,
        download=False,
        transform=transform,
    )
    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=10000,
        shuffle=False,
        num_workers=0,
    )
    # batch_size equals the split size, so one batch holds the whole validation set.
    # Use the builtin next(): DataLoader iterators have no .next() method in current PyTorch.
    val_image, val_label = next(iter(val_loader))

    net = LeNet()
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.001)

    log_interval = 500  # report every this many mini-batches
    for epoch in range(5):
        running_loss = 0.0
        for step, (inputs, labels) in enumerate(train_loader):
            # Zero the parameter gradients, then forward + backward + optimize.
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if step % log_interval == log_interval - 1:
                with torch.no_grad():
                    val_outputs = net(val_image)  # [batch, 10]
                    predict_y = val_outputs.argmax(dim=1)
                    accuracy = (predict_y == val_label).sum().item() / val_label.size(0)
                print('[%d,%5d] train_loss:%.3f test_accuracy:%.3f' % (
                    epoch + 1, step + 1, running_loss / log_interval, accuracy))
                running_loss = 0.0

    print("Finished Training!!!")
    save_path = './Lenet.pth'
    torch.save(net.state_dict(), save_path)
# Start training only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
3. 模型的预测与检验
import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNet
def main():
    """Classify '1.jpg' with the LeNet weights in 'Lenet.pth' and print the class name.

    The image is forced to 3-channel RGB, resized to 32x32 and normalized to
    [-1, 1] with the same mean/std as training. The index of the largest logit
    is then mapped to a CIFAR-10 class name.
    """
    transform = transforms.Compose([
        # Resize the PIL image before converting it to a tensor. This is the
        # conventional order, works on all torchvision versions, and PIL
        # resizing is always antialiased when shrinking large photos.
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    net = LeNet()
    # map_location='cpu' lets weights saved from a GPU run load on a CPU-only machine.
    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
    net.load_state_dict(torch.load('Lenet.pth', map_location='cpu'))
    net.eval()

    # convert('RGB') guarantees 3 channels. Grayscale, palette or RGBA images would
    # otherwise break the 3-channel Normalize and conv1. The with-block closes the file.
    with Image.open('1.jpg') as img:
        im = transform(img.convert('RGB'))  # [C, H, W]
    im = torch.unsqueeze(im, dim=0)  # [N, C, H, W]

    with torch.no_grad():
        outputs = net(im)
        # .item() yields a plain int; int() on a 1-element ndarray is deprecated in NumPy.
        predict = outputs.argmax(dim=1).item()
    print(classes[predict])
# Run the prediction only when this file is executed as a script.
if __name__ == "__main__":
    main()