1.将数据集封装到DataLoader(可迭代对象)中,并设置batch_size
from torch.utils.data import DataLoader
# Wrap the training set in a DataLoader that yields mini-batches of `batch_size` samples.
train_dataloader = DataLoader(dataset=train_data, batch_size=batch_size)
2.查看数据格式
# Inspect only the first batch. Printing every batch would dump the entire
# training set to stdout; shapes and dtypes are what describe the data format.
for X, y in train_dataloader:
    print(f"Shape of X: {X.shape} {X.dtype}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break
3.选择计算设备:有可用的CUDA时使用GPU,否则使用CPU
# Prefer the GPU when PyTorch reports a usable CUDA device; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
4.根据数据格式设置对应的模型
class NeuralNetwork(nn.Module):
    """Fully connected classifier: flatten -> Linear/ReLU -> Linear/ReLU -> Linear.

    The default arguments reproduce the original fixed architecture
    (28*28 input features, two hidden layers of width 512, 10 output logits),
    so existing callers and saved ``state_dict`` keys are unaffected.

    Args:
        in_features: number of values per sample after flattening.
        hidden_features: width of each of the two hidden layers.
        num_classes: number of output logits.
    """

    def __init__(self, in_features=28 * 28, hidden_features=512, num_classes=10):
        super().__init__()
        # nn.Flatten() keeps dim 0 (batch) and flattens all remaining dims.
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),  # activation function
            nn.Linear(hidden_features, hidden_features),  # linear layer
            nn.ReLU(),
            nn.Linear(hidden_features, num_classes),
        )

    def forward(self, x):
        """Return raw (un-normalized) logits of shape ``(batch, num_classes)``."""
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
5.将模型移动到所选设备(device)上
# Build the network, then move its parameters and buffers onto the selected device.
model = NeuralNetwork()
model = model.to(device)
6.定义损失函数
# Cross-entropy over raw logits; reduction="mean" is PyTorch's default, spelled out explicitly.
loss_fn = nn.CrossEntropyLoss(reduction="mean")
7.定义优化器,传入模型参数和学习率
# Plain stochastic gradient descent over all of the model's parameters.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.001)
8.定义训练函数
8.1 model.train():将模型切换到训练模式
8.2 迭代数据集,将特征值和标签值都移动到device上
8.3 前向传播得到预测值:pred = model(X)
8.4 将预测值和真实值传入损失函数,得到误差:loss = loss_fn(pred, y)
8.5 梯度置0:optimizer.zero_grad()
8.6 误差反向传播:loss.backward()
8.7 更新参数:optimizer.step()
8.8 (可选)定期打印训练损失和进度
def train(dataloader, model, loss_fn, optimizer, log_interval=None):
    """Run one training epoch over ``dataloader``.

    For each batch: move the data to the module-level ``device``, compute the
    loss, backpropagate, and take one optimizer step.

    Args:
        dataloader: iterable of ``(X, y)`` batches (e.g. a ``DataLoader``).
        model: the network to train; it is switched to training mode here.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: optimizer that holds ``model``'s parameters.
        log_interval: if a positive int, print the current batch loss and
            progress every ``log_interval`` batches. ``None`` (the default)
            keeps training silent. Logging reads ``len(dataloader.dataset)``.
    """
    model.train()  # put the model in training mode
    # Only needed for the progress message; avoid touching .dataset otherwise.
    size = len(dataloader.dataset) if log_interval else None
    seen = 0  # running count of samples processed so far, for progress output
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation: clear stale gradients, compute new ones, update weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        seen += len(X)
        if log_interval and batch % log_interval == 0:
            print(f"loss: {loss.item():>7f} [{seen:>5d}/{size:>5d}]")
9.定义测试函数
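9.1 model.eval():将模型切换到评估模式
9.2 关闭梯度计算,不再记录反向传播所需的计算图:with torch.no_grad()
9.3 迭代数据集,将特征值和标签值都移动到device上
9.4 前向推理得到预测值:pred = model(X)
9.5 累加测试误差:test_loss += loss_fn(pred, y).item()
9.6 累加预测正确的样本数:correct += (pred.argmax(1) == y).type(torch.float).sum().item()
9.7 评估:计算平均误差和准确率并打印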
def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader``, print the metrics and return them.

    The model is put in evaluation mode and gradient tracking is disabled.
    Accuracy is the fraction of evaluated samples whose arg-max logit equals
    the label. The reported loss is the mean of the per-batch ``loss_fn`` values.

    Args:
        dataloader: iterable of ``(X, y)`` batches (e.g. a ``DataLoader``).
        model: the network to evaluate.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.

    Returns:
        ``(accuracy, avg_loss)`` as floats, with accuracy in ``[0, 1]``.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no batches.
    """
    model.eval()
    test_loss, correct = 0, 0
    # Count what is actually evaluated instead of trusting len(dataloader.dataset)
    # and len(dataloader): samplers, drop_last or iterable datasets can make those
    # differ from (or be unavailable for) the samples really seen.
    num_batches, size = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            num_batches += 1
            size += len(y)
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return correct, test_loss