Pytorch搭建LeNet网络---代码详解

搭建LeNet网络—代码详解

img

首先是构建模型的文件model.py

根据上面的结构图,有如下代码:

import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet,self).__init__()
        self.conv1=nn.Conv2d(3,16,5)#卷积层的参数依次是(输入channels,输出channels,卷积核的大小)
        self.pool1=nn.MaxPool2d(2,2)
        self.conv2=nn.Conv2d(16,32,5)
        self.pool2=nn.MaxPool2d(2,2)
        self.fc1=nn.Linear(32*5*5,120)
        self.fc2=nn.Linear(120,84)
        self.fc3=nn.Linear(84,10)

    def forward(self,x):
        x=F.relu(self.conv1(x))   #input(3,32,32)   output(16,28,28)
        x=self.pool1(x)           #output(16,14,14)
        x=F.relu(self.conv2(x))   #output(32,10,10)
        x=self.pool2(x)           #output(32,5,5)
        x=x.view(-1,32*5*5)       #output(32*5*5)
        x=F.relu(self.fc1(x))     #output(120)
        x=F.relu(self.fc2(x))     #output(84)
        x=self.fc3(x)             #output(10)
        return x



然后是用于训练的文件train.py

代码如下:

import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np


transform=transforms.Compose(
    [transforms.ToTensor(),#将(H*W*C),范围(0-255) 转换成(C*H*W),范围0-1
     transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]#标准化,分别表示均值和方差
)

#50000张训练图片,用torchvision中自带的CIFAR10数据集,直接下载(如果已经下载好的就将download设为False)
trainset=torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=transform)
#用trainloader加载训练数据,batch_size代表一次加载36张图片,linux下可以将num_worker设为大于0的数
trainloader=torch.utils.data.DataLoader(trainset,batch_size=36,shuffle=True,num_worker=0)

#按同样的方法导入测试集的10000张测试图片
testset=torchvision.datasets.CIFAR10(root='./data',train=False,download=False,transform=transform)
#一次就测试完所有的测试图片
testloader=torch.utils.data.DataLoader(testset,batch_size=1000,shuffle=False,num_workers=0)

#转换成可迭代的迭代器
test_data_iter=iter(testloader)
test_image,test_label=test_data_iter.next()

classes=('plane','cat','bird','cat','deer','dog','frog','horse','ship','truck')#元组类型,不可改变

#显示图片

def imshow(img):
    img=img/2+0.5
    npimg=img.numpy()
    plt.imshow(np.transpose(npimg,(1,2,0)))
    plt.show()

#打印标签
print(''.join('%5s'% classes[test_label[j]] for j in range(4)))
#展示图片
imshow(torchvision.utils.make_grid(test_image))

net=LeNet()
loss_function=nn.CrossEntropyLoss()#这个损失器中包含了softmax
optimizer=optim.Adam(net.parameters(),lr=0.001)

#训练过程
for epoch in range(5):
    running_loss=0.0
    for step,data in enumerate(trainloader,start=0):#索引从0开始
        inputs,labels=data

        #梯度归零
        optimizer.zero_grad()

        #forward+backward+optimize
        outputs=net(inputs)
        loss=loss_function(outputs,labels)#计算损失
        loss.backward()#反向传播
        optimizer.step()#更新参数

        #print statistics
        running_loss+=loss.item()
        if step %500==499:#每训练500个batch就测试 测试图片
            with torch.no_grad():
                outputs=net(test_image)
                predict_y=torch.max(outputs,dim=1)[1]
                accuracy=(predict_y==test_label).sum().item/test_label.size(0)

                print('[%d,%5d] train_loss:%.3f test_accuracy:%.3f' %(epoch+1,step+1,running_loss/500,accuracy))
                running_loss=0.0
print('Finished Training')
save_path='./Lenet.pth'
torch.save(net.state(dict(),save_path))#将参数保存起来,之后用的时候可以直接加载

保存好模型参数之后,会多出一个Lenet.pth的文件,我们自己下载一张图片,这里下载一张飞机的图片用来做预测

代码如下:

import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNet

transform=transforms.Compose(
    [transforms.Resize((32,32)),
     transforms.ToTensor(),
     transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]
)

classes=('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

net=LeNet()
net.load_state_dict(torch.load('Lenet.pth'))#加载预训练好的模型

im=Image.open('1.jpg')#将我们下载好的图片用来做预测                                   
im=transforms(im)
im=torch.unsqueeze(im,dim=0)

with torch.no_grad():
    outputs=net(im)
    predict=torch.max(outputs,dim=1)[1].data.numpy()

print(classes[int(predict)])

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值