目录
不难,但是一些实现方法比较有特色
softmax从零开始实现
导包
import torch
from torch import nn
import numpy as np
from IPython import display
import torchvision.transforms as transforms
import torchvision
from torch.utils import data
import sys
导入数据
def load_data_fashion_mnist(batch_size, resize=None):
    """Download Fashion-MNIST and return (train, test) DataLoaders.

    Args:
        batch_size: Minibatch size used by both loaders.
        resize: Optional target size; when given, images are resized
            before being converted to tensors.

    Returns:
        Tuple of (training DataLoader with shuffling,
        test DataLoader without shuffling).
    """
    transform_steps = [transforms.ToTensor()]
    if resize:
        # Resize must run before ToTensor, hence insertion at the front.
        transform_steps.insert(0, transforms.Resize(resize))
    composed = transforms.Compose(transform_steps)
    train_set = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=composed, download=True)
    test_set = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=composed, download=True)
    workers = get_dataloader_workers()
    train_loader = data.DataLoader(train_set, batch_size, shuffle=True,
                                   num_workers=workers)
    test_loader = data.DataLoader(test_set, batch_size, shuffle=False,
                                  num_workers=workers)
    return (train_loader, test_loader)
def get_dataloader_workers():
    """Return the number of DataLoader worker processes.

    DataLoader worker processes are problematic on Windows, so data is
    read in the main process there; on other platforms 4 workers are used.
    """
    if sys.platform.startswith('win'):
        return 0
    return 4
# Load Fashion-MNIST as iterators of minibatches of 256 examples.
batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size)
初始化模型参数
展平每个图像,把它们看作长度为784的向量
# Each 28x28 image is flattened into a length-784 vector; 10 output classes.
num_inputs = 784
num_outputs = 10
# Weights drawn from N(0, 0.01); bias starts at zero. Both require gradients
# so they can be updated by SGD.
W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)
定义softmax操作
def softmax(X):
    """Compute the row-wise softmax of a 2-D tensor.

    The per-row maximum is subtracted before exponentiating for numerical
    stability; this does not change the result because softmax is invariant
    to shifting each row by a constant.

    Args:
        X: 2-D tensor of shape (batch, num_classes).

    Returns:
        Tensor of the same shape whose rows are non-negative and sum to 1.
    """
    # Per-row maximum, kept as a column vector so broadcasting lines up.
    X_max = torch.max(X, axis=1)[0].reshape(-1, 1)
    # Out-of-place subtraction: the original `X -= X_max` mutated the
    # caller's tensor in place (and raises on a leaf tensor that
    # requires grad).
    X_exp = torch.exp(X - X_max)
    partition = X_exp.sum(1, keepdim=True)
    return X_exp / partition  # broadcasting divides each row by its sum
定义模型
w形状:特征数 * 目标类别数
def net(X):
    """Softmax-regression model: flatten, affine transform, then softmax."""
    flattened = X.reshape((-1, W.shape[0]))  # (batch, num_inputs)
    logits = torch.matmul(flattened, W) + b
    return softmax(logits)
损失函数
交叉熵采用真实标签的预测概率的负对数似然。y不是one-hot编码方式,而是标签序号
def cross_entropy(y_hat, y):
    """Per-example cross-entropy loss.

    `y` holds integer class indices (not one-hot vectors); the loss is the
    negative log of the probability the model assigned to the true class.
    """
    rows = range(len(y_hat))
    true_class_prob = y_hat[rows, y]  # probability of each true label
    return -torch.log(true_class_prob)
优化函数
lr = 0.1  # learning rate for minibatch SGD
def updater(batch_size):
    """One optimization step: apply minibatch SGD to W and b."""
    params = [W, b]
    return sgd(params, lr, batch_size)
def sgd(params, lr, batch_size):
    """Minibatch stochastic gradient descent.

    Updates each parameter in place and then resets its gradient to zero.
    Gradients are divided by `batch_size` because the loss is summed over
    the minibatch; dividing gives the average gradient, so the effective
    step size does not depend on the batch size.
    """
    with torch.no_grad():  # parameter updates must not be tracked by autograd
        for p in params:
            p -= lr * p.grad / batch_size
            p.grad.zero_()
分类精度
由于等式运算符 `==` 对数据类型很敏感,因此我们将 `y_hat` 的数据类型转换为与 `y` 的数据类型一致。结果是一个包含 0(错)和 1(对)的张量,最后求和即可得到正确预测的数量。