pytorch学习笔记(六)
一、多分类问题相关知识
多分类问题实战:MNIST数据集是经典图像数据集,包括10个类别(0到9)。每一张图片拉成向量表示。
MNIST 数据集(手写数字数据集)来自美国国家标准与技术研究所. 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50%来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。MNIST数据集下载地址: http://yann.lecun.com/exdb/mnist/。手写数字的MNIST数据库包括60,000个的训练集样本,以及10,000个测试集样本。
二、代码实现
import torch
import torch.nn as nn
import torchvision as tv
# 超参数
batch_size=200
learning_rate=0.01
epochs=10
# 训练集
train_loader = torch.utils.data.DataLoader(
tv.datasets.MNIST('../data', train=True, download=True, # train=True则得到的是训练集
transform=tv.transforms.Compose([ # transform进行数据预处理
tv.transforms.ToTensor(), # 转成Tensor类型的数据
tv.transforms.Normalize((0.1307,), (0.3081,)) # 进行数据标准化(减去均值除以方差)
])),
batch_size=batch_size, shuffle=True) # 按batch_size分出一个batch维度在最前面,shuffle=True打乱顺序
# 测试集
test_loader = torch.utils.data.DataLoader(
tv.datasets.MNIST('../data', train=False, transform=tv.transforms.Compose([
tv.transforms.ToTensor(),
tv.transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=batch_size, shuffle=True)
# 设定参数w和b
w1, b1 = torch.randn(200, 784, requires_grad=True),\
torch.zeros(200, requires_grad=True) # w1(out,in)
w2, b2 = torch.randn(200, 200, requires_grad=True),\
torch.zeros(200, requires_grad=True)
w3, b3 = torch.randn(10, 200, requires_grad=True),\
torch.zeros(10, requires_grad=True)
torch.nn.init.kaiming_normal_(w1)
torch.nn.init.kaiming_normal_(w2)
torch.nn.init.kaiming_normal_(w3)
def forward(x):
x = x@w1.t() + b1
x = torch.nn.function.relu(x)
x = x@w2.t() + b2
x = torch.nn.function.relu(x)
x = x@w3.t() + b3
x = torch.nn.function.relu(x)
return x
#定义sgd优化器,指明优化参数、学习率
optimizer = torch.optim.SGD([w1, b1, w2, b2, w3, b3], lr=learning_rate)
criteon = nn.CrossEntropyLoss()
for epoch in range(epochs):
for batch_idx, (data, target) in enumerate(train_loader):
data = data.view(-1, 28*28)
logits = forward(data) # 把数据放入神经网络得出pred的值
loss = criteon(logits, target) # 用loss函数计算pred和target的差
optimizer.zero_grad() # 清零梯度
loss.backward() # 重新计算梯度
optimizer.step() # 用新的梯度计算新的w,b,然后迭代
if batch_idx % 100 == 0: #每100个batch输出一次信息
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
test_loss = 0
correct = 0 #correct记录正确分类的样本数
for data, target in test_loader:
data = data.view(-1, 28 * 28)
logits = forward(data)
test_loss += criteon(logits, target).item() #其实就是criteon(logits, target)的值,标量
pred = logits.data.max(dim=1)[1] # 得出pred的最大值,就是网络识别出的数字
correct += pred.eq(target.data).sum()
test_loss /= len(test_loader.dataset)
print(' {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))