LeNet网络搭建与基本训练流程

模型

1_vUJ-XilD6_WECeQlOMThMQ

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()   # 解决继承父类中出现的一系列问题
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))       # 输入(3,32,32) 输出(16,28,28)
        x = self.pool1(x)               # 输出(16,14,14)
        x = F.relu(self.conv2(x))       # 输出(32,10,10)
        x = self.pool2(x)               # 输出(32,5,5)
        x = x.view(-1, 32*5*5)          # 输出(32*5*5),batch=-1设为动态调整这个维度上的元素的个数,以保证元素的总数不变
        x = F.relu(self.fc1(x))         # 输出(120)
        x = F.relu(self.fc2(x))         # 输出(84)
        x = self.fc3(x)                 # 输出(10)
        return x

预处理

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
  1. 转换图片数据为pytorch中的Tensor格式
  2. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))将数据转换为标准正态分布,即逐个c h a n n e l channelchannel的对图像进行标准化(均值变为0 00,标准差为1 11),可以加快模型的收敛

加载数据集

# 50000张训练图片
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=36, shuffle=True, num_workers=0)

# 10000张测试图片
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=10000, shuffle=False, num_workers=0)

test_data_iter = iter(testloader)
test_image, test_label = test_data_iter.next()

classes = ('plane', 'car', 'bird', 'cat', 'deer'
           'dog', 'frog', 'horse', 'ship', 'truck')

加载cifar10数据集,然后取出测试集的图片和标签

image-20220707211351582

image-20220707211309072

训练

1.加载模型、定义损失函数、优化器

net = LeNet()
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)
  • nn.CrossEntropyLoss():使用交叉熵损失函数
  • optim.Adam(net.parameters(), lr=0.001):net.parameters()是将net的参数都丢进优化器里
  • net = LeNet():注意LeNet后面有一个()

136aff55f926e455

2.训练循环

def train_process():
    for epoch in range(10):
        running_loss = 0.0
        for step, data in enumerate(trainloader, start=0):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if step % 500 ==499:
                with torch.no_grad():
                    outputs = net(test_image)
                    predict_y = torch.max(outputs, dim=1)[1]
                    accuracy = (predict_y == test_label).sum().item() / test_label.size(0)

                    print('[%d, %5d] train_loss: %.3f test_accuracy:%.3f'
                          %(epoch+1, step+1, running_loss/500, accuracy))
                    running_loss = 0.0
    print("Finished Training")

train_process()

save_path = 'Lenet.pth'
torch.save(net.state_dict(),save_path)

  • for step, data in enumerate(trainloader, start=0)

enumerate()函数

  • 示例:for step, data in enumerate(trainloader, start=0):

  • 作用:将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标

  • 参数:

    1. sequence:一个序列、数组或其它对象
    2. start:下标起始位置的值
  • 实例

image-20220627142851807

  • predict_y = torch.max(outputs, dim=1)[1]:dim=1选择行中最大的概率值,dim=0选择列最大的概率值。[1]表示切片,因为torch.max会返回两个数值,第一个是这个概率值,第二个是🌈序号。正常可以这么写: _, predicted = torch.max(outputs.data,dim=1)

  • accuracy = (predict_y == test_label).sum().item() / test_label.size(0):如果正确的预测累加,通过item转换为数值,除以总的测试长度,得到正确的结果

  • torch.save(net.state_dict(),save_path):保存权重文件(.pth)

测试

import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNet

transform = transforms.Compose([transforms.Resize((32,32)),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

classes = ('plane', 'car', 'bird', 'cat', 'deer'
           'dog', 'frog', 'horse', 'ship', 'truck')

net = LeNet()
net.load_state_dict(torch.load('Lenet.pth'))

img = Image.open('data/plane.png')
img = transform(img)
img = torch.unsqueeze(img, dim=0)

with torch.no_grad():
    outputs = net(img)
    # predict = torch.max(outputs,dim=1)[1].data.numpy()
    predict = torch.softmax(outputs, dim=1)

print(predict)
# print(classes[int(predict)])
  • net.load_state_dict(torch.load('Lenet.pth')):加载权重
  • with torch.no_grad()::不反向传播计算梯度,减少计算量
  • print(classes[int(predict)]):与class联用,通过索引直接输出标签
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

落叶随峰

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值