【pytorch学习】训练一个分类器:代码逐行注释

导包

import torch
import torchvision
import torchvision.transforms as transforms
# torchvision是pytorch的一个图形库,它服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型。
# torchvision.datasets: 一些加载数据的函数及常用的数据集接口;
# torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
# torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
# torchvision.utils: 其他的一些有用的方法。

准备数据集

transform = transforms.Compose(
	[transforms.ToTensor(),
	transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# torchvision.transforms.Compose()类:串联多个图片变换的操作
# transforms.ToTensor()函数:将原始的PILImage格式或者numpy.array格式的数据格式化为可被pytorch快速处理的张量类型。将数据归一化到[0,1]。
# transforms.Normalize(mean,std):对数据按通道进行标准化,即减去均值,再除以方差
# mean:(list)长度与输入的通道数相同,代表每个通道上所有数值的平均值。
# std:(list)长度与输入的通道数相同,代表每个通道上所有数值的标准差。

# 训练集
trainset = torchvision.datasets.CIFAR10(root='/path/to/data', train=True,
	download=True, transform=transform)
	trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
	shuffle=True, num_workers=2)
# shuffle设置为True时会在每个epoch重新打乱数据(默认: False)
# num_workers用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)

# 测试集
testset = torchvision.datasets.CIFAR10(root='/path/to/data', train=False,
	download=True, transform=transform)
	testloader = torch.utils.data.DataLoader(testset, batch_size=4,
	shuffle=False, num_workers=2)

# 类别
classes = ('plane', 'car', 'bird', 'cat',
	'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

定义卷积神经网络

import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
	def __init__(self):
		super(Net, self).__init__()
		self.conv1 = nn.Conv2d(3, 6, 5)
		self.pool = nn.MaxPool2d(2, 2)
		self.conv2 = nn.Conv2d(6, 16, 5)
		self.fc1 = nn.Linear(16 * 5 * 5, 120)
		self.fc2 = nn.Linear(120, 84)
		self.fc3 = nn.Linear(84, 10)
	
	def forward(self, x):
		x = self.pool(F.relu(self.conv1(x)))
		x = self.pool(F.relu(self.conv2(x)))
		x = x.view(-1, 16 * 5 * 5)
		x = F.relu(self.fc1(x))
		x = F.relu(self.fc2(x))
		x = self.fc3(x)
		return x

net = Net()

定义损失函数和optimizer

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
# 使用交叉熵损失函数
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# Optimizer使用带冲量的SGD

训练网络

for epoch in range(200):  # 迭代次数:200
	running_loss = 0.0 #每次初始loss
	for i, data in enumerate(trainloader, 0): # 遍历数据集
		# 获取输入
		inputs, labels = data
		
		# 梯度清零 
		optimizer.zero_grad()
		
		# forward + backward + optimize
        # 输入到网络得到输出
		outputs = net(inputs)
        # 计算loss
		loss = criterion(outputs, labels)
        # 反向传播,计算当前梯度
		loss.backward()
        # 根据梯度更新网络参数
		optimizer.step()
		
		# 定义统计信息
        # loss.item():取一个元素张量里面的具体元素值并返回该值。防止tensor无限叠加导致的显存爆炸
		running_loss += loss.item()
		if i % 2000 == 1999:# 每2000次输出平均loss
			print('[%d, %5d] loss: %.3f' %
				(epoch + 1, i + 1, running_loss / 2000))
		    running_loss = 0.0 # 一个2000次结束后,就把running_loss归零,下一个2000次继续使用

print('Finished Training')

Tips:

① pytorch利用Autograd模块进行自动求导,反向传播

② 默认的梯度会累加,因此我们通常在backward之前清除掉之前的梯度值

在测试数据集上进行测试

# dataloader本质上是一个可迭代对象,可以使用iter()进行访问,采用iter(dataloader)返回的是一个迭代器
# iter()函数用于将可迭代对象转换为迭代器
# iter(dataloader)访问时,imgs在前,labels在后,分别表示:图像转换0~1之间的值,labels为标签值
dataiter = iter(testloader)

# 使用next()访问iter(dataloader)返回的迭代器
# next()函数用于获取迭代器的下一个元素
# 一般来说,我们会在一个循环中多次调用 dataiter.next() 来获取训练数据,直到遍历完整个数据集
# 每次调用 dataiter.next(),我们都会得到一个大小为批量大小的数据集合,其中包含了图像和对应的标签
images, labels = dataiter.next()

# torchvision.utils.make_grid 将一个batch的图片在一张图中显示
imshow(torchvision.utils.make_grid(images))

# 输出4张图片的类别标签
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))

测试结果示例

查看每个分类的准确率

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad(): # 所有计算得出的tensor的requires_grad都自动设置为False,反向传播时不再自动求导
	for data in testloader: # 按batch遍历数据
		images, labels = data
		outputs = net(images)
		_, predicted = torch.max(outputs, 1) # 各类别数值可视为概率,预测结果为输出中的最大概率对应的类别
		c = (predicted == labels).squeeze() # 对比预测类别是否与真实类别一致(一致则1,不一致则0),且将获得的结果降低到一维。
		for i in range(4): # 遍历该batch
			label = labels[i] # 真实类别标签
			class_correct[label] += c[i].item() # 正确数加1。c[i]保存对应预测是否正确的信息。
			class_total[label] += 1 # 标签图片数加一


for i in range(10): # 10个类别
	print('Accuracy of %5s : %2d %%' % (
		classes[i], 100 * class_correct[i] / class_total[i])) 
Accuracy of plane : 52 %
Accuracy of   car : 66 %
Accuracy of  bird : 49 %
Accuracy of   cat : 34 %
Accuracy of  deer : 30 %
Accuracy of   dog : 45 %
Accuracy of  frog : 72 %
Accuracy of horse : 71 %
Accuracy of  ship : 76 %
Accuracy of truck : 55 %

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值