# 用torch实现手写数字识别 (handwritten digit recognition with PyTorch)
import os
import torch
import torchvision
from torchvision.transforms import Compose,ToTensor,Normalize
from torch.utils.data import DataLoader
from torch import nn, optim
import torch.nn.functional as F
import numpy as np
# Batch sizes: TRAIN_BATCH_SIZE is get_dataloader()'s default,
# TEST_BATCH_SIZE is what test() requests for the evaluation pass.
TRAIN_BATCH_SIZE, TEST_BATCH_SIZE = 128, 1000
def get_dataloader(train=True, batch_size=TRAIN_BATCH_SIZE, root="./data", shuffle=None):
    """Build a DataLoader over the MNIST training or test split.

    The dataset is downloaded into ``root`` if it is not already there.

    Args:
        train: True for the training split, False for the test split.
        batch_size: Number of samples per batch.
        root: Directory where torchvision stores/looks for MNIST.
        shuffle: Whether to shuffle each epoch. ``None`` (default) shuffles
            only the training split, so evaluation runs in a fixed,
            reproducible order.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, labels)`` batches.
    """
    if shuffle is None:
        shuffle = train
    transform_fn = Compose([
        ToTensor(),
        # presumably the MNIST training-set mean/std; verify if the data changes
        Normalize((0.1307,), (0.3081,)),
    ])
    dataset = torchvision.datasets.MNIST(
        root=root, train=train, download=True, transform=transform_fn
    )
    return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle)
class ImageNet(nn.Module):
    """Two-layer fully connected classifier for 1x28x28 images.

    Note: unrelated to the ImageNet dataset; the name is kept so existing
    callers and saved state dicts keep working.

    Args:
        hidden_size: Width of the hidden layer. The default of 28 reproduces
            the original architecture, so previously saved checkpoints load.
        num_classes: Number of output classes (default 10).
    """

    def __init__(self, hidden_size=28, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(1 * 28 * 28, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, data):
        """Return log-probabilities of shape (batch, num_classes).

        Args:
            data: Tensor with 1*28*28 = 784 values per sample along the
                non-batch dimensions, e.g. (batch, 1, 28, 28).
        """
        # reshape (unlike view) also works on non-contiguous inputs,
        # copying only when necessary; it still checks 784 values per sample.
        features = data.reshape(data.size(0), 1 * 28 * 28)
        features = F.relu(self.fc1(features))
        # log_softmax output pairs with F.nll_loss used by the training loop.
        return F.log_softmax(self.fc2(features), dim=-1)
# Global model and optimizer shared by train() and test().
model = ImageNet()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Resume from the checkpoints written by train(), if they exist.
# The optimizer file is checked separately: a model checkpoint without a
# matching optimizer checkpoint resumes the weights instead of crashing.
# The optimizer state is restored only together with the model weights it
# belongs to. map_location='cpu' lets GPU-saved checkpoints load on CPU.
if os.path.exists('./models/model.pkl'):
    model.load_state_dict(torch.load('./models/model.pkl', map_location='cpu'))
    if os.path.exists('./models/optimizer.pkl'):
        optimizer.load_state_dict(
            torch.load('./models/optimizer.pkl', map_location='cpu')
        )
def _save_checkpoint():
    """Write the global model and optimizer state dicts under ./models/."""
    # torch.save does not create missing parent directories; without this
    # the very first save fails on a fresh checkout.
    os.makedirs('./models', exist_ok=True)
    torch.save(model.state_dict(), './models/model.pkl')
    torch.save(optimizer.state_dict(), './models/optimizer.pkl')


def train(epoch):
    """Run one training epoch of the global model over the MNIST training split.

    Prints the loss every 10 batches and checkpoints every 100 batches and
    once more at the end of the epoch.

    Args:
        epoch: Epoch index, used only in the progress message.
    """
    model.train()
    train_dataloader = get_dataloader(train=True)
    for idx, (data, target) in enumerate(train_dataloader):
        optimizer.zero_grad()
        output = model(data)
        # The model returns log-probabilities, so NLL loss gives cross-entropy.
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if idx % 10 == 0:
            print('第%d轮次,损失值为%f' % (epoch, loss.item()))
        if idx % 100 == 0:
            _save_checkpoint()
    # Also save after the last batch so the tail of the epoch isn't lost.
    _save_checkpoint()
def _evaluate(net, dataloader):
    """Return ``(mean per-sample NLL loss, accuracy)`` of *net* over *dataloader*.

    Loss and correct counts are summed per sample and divided by the total
    sample count. Averaging per-batch means would over-weight a smaller
    final batch. Returns ``(nan, nan)`` if the loader yields no samples.
    """
    net.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for data, target in dataloader:
            output = net(data)
            loss_sum += F.nll_loss(output, target, reduction='sum').item()
            correct += output.argmax(dim=1).eq(target).sum().item()
            seen += target.size(0)
    if seen == 0:
        return float('nan'), float('nan')
    return loss_sum / seen, correct / seen


def test():
    """Evaluate the global model on the MNIST test split and print the result.

    Returns:
        ``(average loss, accuracy)`` as Python floats. Callers that ignore
        the return value are unaffected.
    """
    test_dataloader = get_dataloader(train=False, batch_size=TEST_BATCH_SIZE)
    avg_loss, accuracy = _evaluate(model, test_dataloader)
    print('模型损失%f,平均准确率%f' % (avg_loss, accuracy))
    return avg_loss, accuracy
if __name__ == '__main__':
    # Alternate a full training epoch with an evaluation pass on the test split.
    num_epochs = 5
    for epoch in range(num_epochs):
        train(epoch)
        test()