# MNIST classification task:
# - Basic network construction and training; walkthrough of common functions
# - The torch.nn.functional module
# - The nn.Module module
import torch
# Convert the raw arrays (loaded elsewhere) into torch tensors.
x_train, y_train, x_valid, y_valid = (
    torch.tensor(arr) for arr in (x_train, y_train, x_valid, y_valid)
)
import torch.nn.functional as F
loss_func = F.cross_entropy#stateless loss from torch.nn.functional — nothing here needs training
def model(xb):
    """Plain linear model: logits = xb @ weights + bias.

    NOTE(review): relies on module-level globals `weights` and `bias`,
    which are defined elsewhere in the file — confirm before use.
    """
    return xb @ weights + bias
from torch import nn
#Fully-connected network; inheriting from nn.Module means autograd supplies backward() for free
class Mnist_NN(nn.Module):
    """Two-hidden-layer MLP for MNIST: 784 -> 128 -> 256 -> 10 logits."""

    def __init__(self):
        super().__init__()
        # 784 = flattened 28x28 input image; 10 = digit classes.
        self.hidden1 = nn.Linear(784, 128)
        self.hidden2 = nn.Linear(128, 256)
        self.out = nn.Linear(256, 10)

    def forward(self, x):
        """Return raw (unnormalized) class logits for a batch of flattened images."""
        h = F.relu(self.hidden1(x))
        h = F.relu(self.hidden2(h))
        return self.out(h)
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
#TensorDataset + DataLoader handle batching of the data for us
# NOTE(review): `bs` and the x/y tensors must be defined earlier — not visible in this chunk.
train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)
valid_ds = TensorDataset(x_valid, y_valid)
# Validation uses a larger batch (no gradients held, so more fits in memory).
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)
def get_data(train_ds, valid_ds, bs):
    """Wrap the datasets in DataLoaders.

    The training loader shuffles and uses batch size `bs`; the validation
    loader uses 2*bs (no gradients are held during validation).
    Returns a (train_dl, valid_dl) tuple.
    """
    train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)
    valid_dl = DataLoader(valid_ds, batch_size=bs * 2)
    return train_dl, valid_dl
from torch import optim
#Build the model and its optimizer
def get_model(lr=0.001):
    """Instantiate Mnist_NN and an SGD optimizer over its parameters.

    Args:
        lr: learning rate for SGD. Defaults to 0.001, the value that was
            previously hard-coded, so existing callers are unaffected.

    Returns:
        (model, optimizer) tuple.
    """
    model = Mnist_NN()
    return model, optim.SGD(model.parameters(), lr=lr)
#Compute one batch's loss; optionally backprop and take an optimizer step
def loss_batch(model, loss_func, xb, yb, opt=None):
    """Return (loss_value, batch_size) for one batch.

    When `opt` is provided, also runs backward(), opt.step() and
    opt.zero_grad(); otherwise only evaluates the loss.
    """
    loss = loss_func(model(xb), yb)
    if opt is None:
        return loss.item(), len(xb)
    loss.backward()
    opt.step()
    opt.zero_grad()
    return loss.item(), len(xb)
import numpy as np
#Training loop
def fit(steps, model, loss_func, opt, train_dl, valid_dl):
    """Train for `steps` epochs, printing the validation loss after each.

    Each epoch: one optimizing pass over train_dl, then a no-grad pass
    over valid_dl whose per-batch losses are averaged (weighted by batch
    size, since the last batch may be smaller).
    """
    for step in range(steps):
        # Training pass — enables train-mode behavior (dropout, BN, etc.).
        model.train()
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)

        # Validation pass — no gradients needed.
        model.eval()
        with torch.no_grad():
            results = [loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
        losses, nums = zip(*results)
        # Size-weighted mean of the per-batch validation losses.
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
        print('当前step:'+str(step), '验证集损失:'+str(val_loss))
#Main routine: build loaders, model and optimizer, then train for 10 epochs
# NOTE(review): `train_ds`, `valid_ds` and `bs` must be defined earlier — not visible in this chunk.
train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(10, model, loss_func, opt, train_dl, valid_dl)