Explaining Common PyTorch Operations with LeNet and the CIFAR-10 Dataset

Reference: https://github.com/kuangliu/pytorch-cifar

1. Building the Neural Network

Create lenet.py:

import torch.nn as nn
import torch.nn.functional as F

# LeNet subclasses nn.Module
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)   # 3 input channels, 6 output channels, 5x5 kernel
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(400, 120)    # 400 = 16 * 5 * 5 after two conv + pool stages on a 32x32 input
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)      # 10 output classes for CIFAR-10

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)   # flatten to (batch_size, 400)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out
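
As a quick shape check (a minimal sketch, not part of lenet.py itself), a random CIFAR-10-sized batch can be pushed through the model to confirm that it produces 10 logits per image:

import torch
import lenet

# Minimal sketch: verify the forward pass on a dummy CIFAR-10 batch
net = lenet.LeNet()
x = torch.randn(4, 3, 32, 32)   # a batch of 4 RGB 32x32 images
y = net(x)
print(y.shape)                  # torch.Size([4, 10]): one logit per class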

2. Applying the Built Network

When running the script from the command line, the following arguments can be accepted:

import argparse
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
args = parser.parse_args()
# The parsed values are available as args.lr and args.resume
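
To illustrate what parse_args returns, here is a self-contained sketch that passes an explicit argument list instead of reading sys.argv; it mimics an invocation such as python main.py --lr 0.01 --resume (the script name main.py is a hypothetical example):

import argparse

# Minimal sketch (not the training script): mirror the parser defined above
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')

args = parser.parse_args(['--lr', '0.01', '--resume'])
print(args.lr)      # 0.01
print(args.resume)  # True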

Data augmentation and loading:

import torch
import torchvision
import torchvision.transforms as transforms

cifar_norm_mean = (0.49139968, 0.48215827, 0.44653124)
cifar_norm_std = (0.24703233, 0.24348505, 0.26158768)
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar_norm_mean, cifar_norm_std),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar_norm_mean, cifar_norm_std),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
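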

Defining the network and the device

import torch
import lenet
import torch.backends.cudnn as cudnn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = lenet.LeNet()
net = net.to(device) # move the network onto the chosen device
if device == 'cuda':
    net = torch.nn.DataParallel(net) # single-machine multi-GPU data parallelism; has no effect with only one GPU
    cudnn.benchmark = True # let cuDNN benchmark and pick the fastest algorithms, speeding up training
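
As a minimal check (sketch only, assuming the code above has run), you can confirm where the parameters were actually placed:

# Minimal sketch: confirm the device the parameters live on
print(device)                         # 'cuda' or 'cpu'
print(next(net.parameters()).device)  # e.g. cuda:0 when a GPU is available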

Training the network

import torch
import torch.nn as nn
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

# Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    # Loop over the training set: the standard gradient-descent training skeleton
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()  # clear the gradients so they are not accumulated across batches
        outputs = net(inputs)  # forward pass: feed the training batch through the network
        loss = criterion(outputs, targets)  # compute the cross-entropy loss
        loss.backward()   # backward pass: compute the gradients
        optimizer.step()  # update the parameters using the computed gradients

        # Track the loss and prediction statistics
        train_loss += loss.item()  # .item() turns a tensor like tensor(0.98) into the scalar 0.98
        _, predicted = outputs.max(1)  # reduce over dim 1; _ is the max value, predicted is its index
        total += targets.size(0)  # .size() returns the shape; size(0) is the batch size
        correct += predicted.eq(targets).sum().item()  # .eq returns a boolean tensor of the same shape
        print('Train accuracy:', correct / total)

Testing on the test set

def test(epoch):
    global best_acc
    net.eval()  # put the network in evaluation mode; dropout and batchnorm switch to their eval behaviour
    test_loss = 0
    correct = 0
    total = 0
    # torch.no_grad() disables the autograd engine, which speeds up computation and saves memory
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            print('Test accuracy:', correct / total)
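
With both functions defined, an outer loop alternates training and evaluation. The sketch below is minimal; the epoch count of 200 and the starting values are assumptions for illustration, not taken from the text above:

# Minimal sketch of the driving loop; 200 epochs is an arbitrary choice
best_acc = 0
start_epoch = 0
for epoch in range(start_epoch, start_epoch + 200):
    train(epoch)
    test(epoch)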

Saving parameters and setting up checkpoints

Save the relevant state as a dictionary in a .pkl file:

import os

state = {
    'net': net.state_dict(),  # the network's current parameter dict (an OrderedDict)
    'acc': acc,               # current accuracy
    'epoch': epoch,           # current epoch
    'optimizer': optimizer.state_dict()
}
if not os.path.isdir('checkpoint'):
    os.mkdir('checkpoint')
# Save the state dictionary to a .pkl file
torch.save(state, './checkpoint/ckpt.pkl')
# Alternatively, save just the parameter dict net.state_dict() to a .pkl file
torch.save(net.state_dict(), './checkpoint/params.pkl')
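
If only the bare parameter dict was saved (the second call above), it can be restored directly into the network, as in this minimal sketch:

# Minimal sketch: load a bare parameter dict back into the network
net.load_state_dict(torch.load('./checkpoint/params.pkl'))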

Load the checkpoint so the network can continue from where it stopped:

checkpoint = torch.load('./checkpoint/ckpt.pkl')
net.load_state_dict(checkpoint['net'])
optimizer.load_state_dict(checkpoint['optimizer'])
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch']
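
Putting this together with the --resume flag from the argparse section, a minimal sketch (the default values of 0 are assumptions) only restores the checkpoint when the flag was passed:

# Minimal sketch: resume only when --resume was given on the command line
best_acc = 0
start_epoch = 0
if args.resume:
    checkpoint = torch.load('./checkpoint/ckpt.pkl')
    net.load_state_dict(checkpoint['net'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']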