算法思路:
这段代码是一个简单的基于PyTorch的Fashion-MNIST分类模型训练过程。下面对代码进行逐行解释
算法流程:
import torch
import torchvision
from torchvision import transforms
from torch.utils import data
首先导入了需要使用的库和模块,这些库和模块包含了构建和训练模型所需的功能。
def load_data_fashion_mnist(batch_size, resize=None):
"""下载Fashion-MNIST数据集,然后将其加载到内存中。"""
trans = [transforms.ToTensor()]
if resize:
trans.insert(0, transforms.Resize(resize)) # 对图片进行扩充
trans = transforms.Compose(trans)
mnist_train = torchvision.datasets.FashionMNIST(root="../data", # 下载数据集中的训练集
train=True,
transform=trans,
download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data", # 下载数据集中的测试集
train=False,
transform=trans,
download=True)
return (data.DataLoader(mnist_train, batch_size, shuffle=True, ),
data.DataLoader(mnist_test, batch_size, shuffle=False, ))
Doad_data_fashion_mnist
函数用于下载和加载Fashion-MNIST数据集, 这个函数返回两个DataLoader
对象,分别用于训练集和测试集的数据加载。
def softmax(X):
X_exp = torch.exp(X)
partition = X_exp.sum(1, keepdim=True)
return X_exp / partition
softmax
函数用于计算Softmax激活函数,这个函数首先计算参数的指数值,然后进行归一化,并进行返回。
def net(X):
return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)
net
函数定义了模型的前向传播过程。这个函数接收输入X
,将其与权重W
进行矩阵乘法和偏置b
相加,然后通过Softmax函数得到预测结果。
def cross_entropy(y_hat, y):
return -torch.log(y_hat[range(len(y_hat)), y])
cross_entropy
函数计算交叉熵损失:这个函数接收预测结果y_hat
和真实标签y
,并计算它们之间的交叉熵损失。
def accuracy(y_hat, y):
if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
y_hat = y_hat.argmax(axis=1)
cmp = y_hat.type(y.dtype) == y
return float(cmp.type(y.dtype).sum())
accuracy
函数计算模型的准确率:这个函数接收预测结果y_hat
和真实标签y
,根据预测结果的最大值判断准确与否,并返回准确率。
def evaluate_accuracy(net, data_iter):
"""计算在指定数据集上模型的精度。"""
if isinstance(net, torch.nn.Module):
net.eval()
metric = Accumulator(2)
for X, y in data_iter:
metric.add(accuracy(net(X), y), y.numel())
return metric[0] / metric[1]
class Accumulator:
"""在`n`个变量上累加。"""
def __init__(self, n):
self.data = [0.0] * n
def add(self, *args):
self.data = [a + float(b) for a, b in zip(self.data, args)]
def reset(self):
self.data = [0.0] * len(self.data)
def __getitem__(self, idx):
return self.data[idx]
该函数适用于计算一般模型的准确度,在后续流程中,我们使用该准确率函数进行度量。
def train_epoch_ch3(net, train_iter, loss, updater):
"""训练模型一个迭代周期(定义见第3章)。"""
if isinstance(net, torch.nn.Module):
net.train()
metric = Accumulator(3)
for X, y in train_iter:
y_hat = net(X)
l = loss(y_hat, y)
if isinstance(updater, torch.optim.Optimizer):
updater.zero_grad()
l.backward()
updater.step()
metric.add(
float(l) * len(y), accuracy(y_hat, y),
y.size().numel())
else:
l.sum().backward()
updater(X.shape[0])
metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())
return metric[0] / metric[2], metric[1] / metric[2]
train_epoch_ch3
函数定义了一个训练模型的迭代周期:这个函数接收模型net
、训练数据加载器train_iter
、损失函数loss
和更新器updater
,在每个迭代周期中进行模型训练,并返回训练损失和准确率。
def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):
"""训练模型(定义见第3章)。"""
for epoch in range(num_epochs):
train_metrics = train_epoch_ch3(net, train_iter, loss, updater)
test_acc = evaluate_accuracy(net, test_iter)
print(epoch + 1, train_metrics + (test_acc,))
train_loss, train_acc = train_metrics
assert train_loss < 0.5, train_loss
assert train_acc <= 1 and train_acc > 0.7, train_acc
assert test_acc <= 1 and test_acc > 0.7, test_acc
train_ch3
函数用于训练模型:这个函数接收神经网络net
、训练数据加载器train_iter
、测试数据加载器test_iter
、损失函数loss
、训练的迭代周期数num_epochs
和更新器updater
,在指定的迭代周期内进行模型训练,并输出训练过程中的损失和准确率。
def sgd(params,lr,batch_size):
with torch.no_grad():
for param in params:
param -= lr * param.grad / batch_size
param.grad.zero_()
def updater(batch_size):
return sgd([W,b],lr,batch_size)
sgd
函数实现了随机梯度下降(SGD)优化算法,作为更新器来进行相关优化。
接下来是一些初始化操作和参数设置:
train_iter, test_iter = load_data_fashion_mnist(256)
num_inputs = 784
num_outputs = 10
batch_size = 256
W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)
lr = 0.1
这里使用load_data_fashion_mnist
函数加载Fashion-MNIST数据集,并设置了输入和输出的维度、批次大小、权重W
和偏置b
的初始化方式以及学习率lr
。这里的784是因为图片本身大小为28*28,这里将其降维到一个维度,故变成了784*1的输入。
最后,调用train_ch3
函数进行模型训练:
num_epochs = 10
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)
这里设置了训练的迭代周期数为10,并传入了之前定义的模型、数据加载器、损失函数和更新器进行训练。
算法结果:
1 (0.790229809252421, 0.74525, 0.789)
2 (0.570518872833252, 0.8129166666666666, 0.8138)
3 (0.524523046875, 0.82505, 0.8062)
4 (0.5021363136927287, 0.8314833333333334, 0.8224)
5 (0.4847754037221273, 0.83685, 0.8238)
6 (0.4737337240219116, 0.84045, 0.8131)
7 (0.4652513662338257, 0.8430166666666666, 0.83)
8 (0.4579625264485677, 0.8442166666666666, 0.8056)
9 (0.4516355712890625, 0.8468333333333333, 0.8328)
10 (0.4473370404561361, 0.8480166666666666, 0.8276)
第一个数字为损失,第二个数字为准确度