# coding=utf-8
"""
author:lei
function: train and evaluate a fully-connected MNIST digit classifier with PyTorch.
"""
import os
import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize
import numpy as np
# Mini-batch sizes. BATCH_SIZE is used for the training loader; by its name,
# TEST_BATCH_SIZE is meant for evaluation.
BATCH_SIZE, TEST_BATCH_SIZE = 128, 1000
# 1. Prepare the dataset
def get_dataloader(train=True):
    """Build a DataLoader over the MNIST training or test split.

    :param train: if True, load the training split with batch size
        ``BATCH_SIZE`` in shuffled order; otherwise load the test split with
        batch size ``TEST_BATCH_SIZE`` in a fixed order.
    :return: a ``DataLoader`` yielding ``(images, labels)`` batches.
    """
    transform_fn = Compose([
        ToTensor(),
        # mean and std each need one entry per image channel (one channel here);
        # the values are presumably the MNIST pixel mean/std — confirm.
        Normalize(mean=(0.1307,), std=(0.3081,)),
    ])
    # No download=True is passed, so the files presumably must already exist
    # under ./data/.
    dataset = MNIST(root="./data/", train=train, transform=transform_fn)
    # Evaluation uses its own (larger) batch size and needs no shuffling.
    batch_size = BATCH_SIZE if train else TEST_BATCH_SIZE
    return DataLoader(dataset, batch_size=batch_size, shuffle=train)
# 2. Build the model
class MnistModel(nn.Module):
    """Two-layer fully-connected classifier for 1x28x28 images.

    Layers: flatten -> Linear(784, 80) -> ReLU -> Linear(80, 10) -> log_softmax.
    """

    def __init__(self):
        # Must be named __init__ so nn.Module construction actually runs it.
        super().__init__()
        # 1*28*28 = 784 inputs, matching the flattened image in forward().
        self.fc1 = nn.Linear(1 * 28 * 28, 80)
        self.fc2 = nn.Linear(80, 10)

    def forward(self, input):
        """
        :param input: tensor of shape [batch_size, 1, 28, 28]
        :return: log-probabilities of shape [batch_size, 10]
        """
        # 1. Flatten each image into a 784-element vector.
        x = input.view([-1, 1 * 28 * 28])
        # 2. Hidden fully-connected layer.
        x = self.fc1(x)
        # 3. ReLU activation; the shape does not change.
        x = F.relu(x)
        # 4. Output layer: one score per class.
        out = self.fc2(x)
        # Log-probabilities, as expected by F.nll_loss.
        return F.log_softmax(out, dim=-1)
# Module-level model and optimizer shared by the training and evaluation code.
model = MnistModel()
optimizer = Adam(model.parameters(), lr=0.001)

# Resume from a previous checkpoint when one exists. The two files are written
# one after the other, so the optimizer file may be missing even when the model
# file is present; only load it if it is actually there.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
if os.path.exists("./model/model.pkl"):
    model.load_state_dict(torch.load("./model/model.pkl", map_location="cpu"))
    if os.path.exists("./model/optimizer.pkl"):
        optimizer.load_state_dict(
            torch.load("./model/optimizer.pkl", map_location="cpu")
        )
def _save_checkpoint():
    """Write the model and optimizer state to ./model/, creating the directory."""
    # torch.save does not create missing parent directories.
    os.makedirs("./model", exist_ok=True)
    torch.save(model.state_dict(), "./model/model.pkl")
    torch.save(optimizer.state_dict(), "./model/optimizer.pkl")


def train(epoch):
    """Run one pass over the MNIST training split.

    Prints progress and checkpoints every 100 batches, and saves a final
    checkpoint at the end so the last batches of the epoch are not lost.

    :param epoch: epoch index; used only in the progress printout.
    """
    model.train()  # training mode (no dropout/batch-norm here, but explicit)
    data_loader = get_dataloader()
    for idx, (images, target) in enumerate(data_loader):
        optimizer.zero_grad()  # clear gradients from the previous step
        output = model(images)  # log-probabilities from the model
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()  # apply the parameter update
        if idx % 100 == 0:
            print(epoch, idx, loss.item())
            _save_checkpoint()
    _save_checkpoint()
def test():
    """Evaluate the model on the MNIST test split and print accuracy and loss.

    Both metrics are averaged over samples rather than over batches, so a
    smaller final batch does not skew the result.

    :return: ``(accuracy, mean_loss)`` as floats (both ``nan`` if the test
        split yields no samples).
    """
    was_training = model.training
    model.eval()
    try:
        test_dataloader = get_dataloader(train=False)
        total_loss = 0.0
        total_correct = 0
        total_samples = 0
        with torch.no_grad():
            for images, target in test_dataloader:
                output = model(images)
                # Sum per-sample losses so each batch is weighted by its size.
                total_loss += F.nll_loss(output, target, reduction="sum").item()
                pred = output.argmax(dim=-1)
                total_correct += pred.eq(target).sum().item()
                total_samples += target.size(0)
    finally:
        # Leave the model in whatever mode it was in before evaluation.
        model.train(was_training)

    if total_samples == 0:
        accuracy = mean_loss = float("nan")
    else:
        accuracy = total_correct / total_samples
        mean_loss = total_loss / total_samples
    print(accuracy, mean_loss)
    return accuracy, mean_loss
if __name__ == "__main__":
    # Uncomment to train for a few epochs before evaluating:
    # for i in range(3):
    #     train(i)
    test()