参考地址:https://github.com/kuangliu/pytorch-cifar
一.搭建神经网络
import torch.nn as nn
import torch.nn.functional as F
# 继承了nn.Module类
class LeNet(nn.Module):
    """LeNet-style convolutional classifier.

    Two 5x5 convolutions, each followed by ReLU and 2x2 max pooling, feed
    three fully connected layers. ``fc1`` expects 400 (= 16 * 5 * 5)
    flattened features, which is what a 3-channel 32x32 input produces
    after the two conv/pool stages.

    Args:
        num_classes: Width of the final layer, i.e. the number of scores
            produced per sample. Defaults to 10, the value that used to be
            hard-coded.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(400, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return the raw (un-normalized) class scores for a batch.

        Args:
            x: Input batch with 3 channels; the spatial size must flatten to
                400 features after the conv/pool stages (e.g. 32x32).

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        # Flatten everything except the batch dimension for the linear layers.
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        # No activation on the last layer: the caller receives plain scores.
        out = self.fc3(out)
        return out
二.应用搭建的DNN
以命令行运行时,可以传入以下参数:
import argparse
# Command-line options for the training script.
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument(
    '--lr',
    type=float,
    default=0.1,
    help='learning rate',
)
parser.add_argument(
    '--resume', '-r',
    action='store_true',
    help='resume from checkpoint',
)
args = parser.parse_args()
# The parsed values are available as args.lr and args.resume.
数据增强和导入:
# `torch` is needed here for torch.utils.data.DataLoader; it was previously
# only imported further down, after its first use.
import torch
import torchvision
import torchvision.transforms as transforms

# Per-channel statistics handed to Normalize for both splits.
# NOTE(review): presumably computed over the CIFAR-10 training set - confirm.
cifar_norm_mean = (0.49139968, 0.48215827, 0.44653124)
cifar_norm_std = (0.24703233, 0.24348505, 0.26158768)

# Training pipeline: random 32x32 crop (after 4-pixel padding) and random
# horizontal flip as augmentation, then tensor conversion and normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar_norm_mean, cifar_norm_std),
])

# Test pipeline: no augmentation, only tensor conversion and normalization.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar_norm_mean, cifar_norm_std),
])

# download=True fetches the dataset into ./data when it is not there yet.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
# The test split is not shuffled.
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
定义网络和设备
import torch
import lenet
import torch.backends.cudnn as cudnn
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = lenet.LeNet()
net = net.to(device)  # move the network onto the selected device
if device == 'cuda':
    # Data-parallel execution across the GPUs of one machine; with a single
    # GPU this wrapper brings no benefit.
    net = torch.nn.DataParallel(net)
    # Let cuDNN pick the best-performing algorithms to speed up training.
    cudnn.benchmark = True
训练网络
import torch
import torch.nn as nn
import torch.optim as optim
# Cross-entropy loss applied to the network outputs and the target labels.
criterion = nn.CrossEntropyLoss()
# SGD with momentum and weight decay; the learning rate is taken from args.lr.
optimizer = optim.SGD(
    net.parameters(),
    lr=args.lr,
    momentum=0.9,
    weight_decay=5e-4,
)
# Training
def train(epoch):
    """Run one training pass over ``trainloader`` and print the accuracy.

    Relies on the module-level ``net``, ``trainloader``, ``device``,
    ``optimizer`` and ``criterion``.

    Args:
        epoch: Epoch number; only used in the progress message.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    correct = 0
    total = 0
    # Standard gradient-descent loop over the training set.
    for inputs, targets in trainloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()  # clear gradients so they do not accumulate
        outputs = net(inputs)  # forward pass
        loss = criterion(outputs, targets)  # cross-entropy between outputs and labels
        loss.backward()  # back-propagate to compute gradients
        optimizer.step()  # update the parameters with those gradients

        # max(1) reduces over dimension 1: `_` holds the maximum values and
        # `predicted` holds their indices.
        _, predicted = outputs.max(1)
        total += targets.size(0)
        # eq() yields an element-wise boolean tensor; its sum is the number
        # of correct predictions in this batch.
        correct += predicted.eq(targets).sum().item()
    # Guard against an empty loader instead of raising ZeroDivisionError.
    print('Train accuracy:', correct / total if total else 0.0)
在测试集上测试
def test(epoch):
    """Evaluate ``net`` on ``testloader`` and print the accuracy.

    Relies on the module-level ``net``, ``testloader`` and ``device``.

    Args:
        epoch: Not used by this function.
    """
    # Evaluation mode: layers such as dropout and batch norm switch to their
    # inference behavior.
    net.eval()
    correct = 0
    total = 0
    # no_grad() disables autograd, which speeds up computation and reduces
    # memory use.
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            # Index of the highest score along dimension 1 is the prediction.
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    # Guard against an empty loader instead of raising ZeroDivisionError.
    print('Test accuracy:', correct / total if total else 0.0)
参数保存与网络断点的设置
以字典的形式将信息保存到pkl文件中:
import os  # needed for the directory creation below; was not imported before

# Bundle everything needed to resume training into a single dict.
# `acc` and `epoch` must already be defined by the surrounding code.
state = {
    'net': net.state_dict(),  # parameter dict of the current net (an OrderedDict)
    'acc': acc,  # current accuracy
    'epoch': epoch,  # current epoch number
    'optimizer': optimizer.state_dict(),
}
# exist_ok=True replaces the separate isdir() check followed by mkdir().
os.makedirs('checkpoint', exist_ok=True)
# Save the full state dict to a pkl file.
torch.save(state, './checkpoint/ckpt.pkl')
# Alternatively, save only the parameter dict. This call previously wrote
# `state` a second time instead of net.state_dict().
torch.save(net.state_dict(), './checkpoint/params.pkl')
加载断点,使得DNN在断点状态下继续运行
# map_location remaps tensors that were saved from a GPU onto the current
# `device`, so the checkpoint also loads on a CPU-only machine.
checkpoint = torch.load('./checkpoint/ckpt.pkl', map_location=device)
# NOTE: a state dict saved from a DataParallel-wrapped net has keys prefixed
# with 'module.', so `net` must be wrapped the same way as when it was saved.
net.load_state_dict(checkpoint['net'])
optimizer.load_state_dict(checkpoint['optimizer'])
# Restore the bookkeeping values stored next to the weights.
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch']