3.7 softmax回归的简洁实现
我们在3.3节(线性回归的简洁实现)中已经了解了使用Pytorch实现模型的便利。下面,让我们再次使用Pytorch来实现一个softmax回归模型。首先导入所需的包或模块。
import torch
from torch import nn
from torch.nn import init
import numpy as np
import sys
sys.path.append("..")
import d2lzh_pytorch as d2l
Copy to clipboardErrorCopied
3.7.1 获取和读取数据
我们仍然使用Fashion-MNIST数据集和上一节中设置的批量大小。
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
Copy to clipboardErrorCopied
3.7.2 定义和初始化模型
在3.4节(softmax回归)中提到,softmax回归的输出层是一个全连接层,所以我们用一个线性模块就可以了。因为前面我们数据返回的每个batch样本x
的形状为(batch_size, 1, 28, 28), 所以我们要先用view()
将x
的形状转换成(batch_size, 784)才送入全连接层。
num_inputs =