一、导入相关库
import torch
import numpy as np
import sys
sys.path.append("..") # 为了导入上层目录的d2lzh_pytorch
import d2lzh_pytorch as d2l
print(torch.__version__)
二、读取数据
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
三、定义模型参数
输入数据(256,784),W1=(784,256),b1=(1,256),W2=(256,10),b2=(1,10)
num_inputs, num_outputs, num_hiddens = 784, 10, 256
W1 = torch.tensor(np.random.normal(0, 0.01, (num_inputs, num_hiddens)), dtype=torch.float)
b1 = torch.zeros(num_hiddens, dtype=torch.float)
W2 = torch.tensor(np.random.normal(0, 0.01, (num_hiddens, num_outputs)), dtype=torch.float)
b2 = torch.zeros(num_outputs, dtype=torch.float)
params = [W1, b1, W2, b2]
for param in params:
param.requires_grad_(requires_grad=True)
三、定义激活函数和模型
1、定义激活函数
在torch.max()函数中,input是输入张量,other参数是用来比较的,作用是张量中每个值变成该值和other的最大值。
def relu(X):
return torch.max(input=X, other=torch.tensor(0.0))
2、定义模型
def net(X):
X = X.view((-1, num_inputs))
H = relu(torch.matmul(X, W1) + b1)
return torch.matmul(H, W2) + b2
四、定义损失函数
使用方法和之前类似
loss = torch.nn.CrossEntropyLoss()
五、训练模型
num_epochs, lr = 5, 100.0
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params, lr)