# 使用Fashion-MNIST数据集构建了一个多层感知机(MLP)模型
import numpy as np
import torch
import torchvision
from torch.utils import data
from d2l import torch as d2l
from torchvision import transforms
import torch.nn as nn
# Number of samples per mini-batch; passed to the data-loading helper, which
# uses it for both the training and the test DataLoader.
batch_size = 256
def load_data_fashion_mnist(batch_size, resize=None):
    """Download Fashion-MNIST (if not already under ./data) and wrap it in loaders.

    Args:
        batch_size: number of samples per mini-batch for both loaders.
        resize: optional size forwarded to ``transforms.Resize``. When falsy,
            images are only converted to tensors.

    Returns:
        A ``(train_loader, test_loader)`` tuple of ``DataLoader`` objects; only
        the training loader shuffles its samples.
    """
    # Resize (when requested) must run before the tensor conversion.
    steps = [transforms.Resize(resize)] if resize else []
    steps.append(transforms.ToTensor())
    pipeline = transforms.Compose(steps)

    def _split(is_train):
        # Both splits share the same root directory and preprocessing.
        return torchvision.datasets.FashionMNIST(
            root="./data", train=is_train, transform=pipeline, download=True)

    train_set = _split(True)
    test_set = _split(False)
    return (data.DataLoader(train_set, batch_size, shuffle=True),
            data.DataLoader(test_set, batch_size, shuffle=False))
# Mini-batch iterators over the Fashion-MNIST training and test splits.
train_iter, test_iter = load_data_fashion_mnist(batch_size=batch_size)

# Multilayer perceptron:
#   - flatten each input into a vector;
#   - one hidden layer of 256 ReLU units;
#   - a linear output layer producing 10 logits.
# The first Linear expects 784 (= 28 * 28) features per sample, which matches
# Fashion-MNIST images when no ``resize`` is passed to the loader.
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)
def init_werights(m, *, std=0.01):
    """Initialise the weight of a linear layer from a zero-mean normal distribution.

    Intended to be passed to ``nn.Module.apply``, which calls it once per
    submodule. Modules that are not ``nn.Linear`` (or a subclass of it) are
    left untouched, and biases keep their default initialisation.

    The (misspelled) name is kept unchanged because callers reference it.

    Args:
        m: the module to (possibly) initialise.
        std: standard deviation of the normal distribution. Keyword-only and
            defaults to 0.01, the value previously hard-coded.
    """
    # isinstance (instead of ``type(m) == nn.Linear``) also matches Linear
    # subclasses, which the exact type comparison silently skipped.
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, std=std)
# Module.apply calls the initialiser on every submodule of ``net`` (children
# first, then ``net`` itself), so each linear layer gets re-initialised.
net.apply(fn=init_werights)
def accuracy(X, y):
    """Return the number of correct predictions in a batch, as a float.

    This is a *count*, not a ratio; callers divide by the number of samples.

    Args:
        X: either per-class scores with at least two dimensions and more than
            one column (the predicted class is the arg-max along dim 1), or a
            tensor that already holds predicted class labels.
        y: tensor of true class labels.

    Returns:
        The number of positions where the predicted label equals ``y``.
    """
    if X.dim() > 1 and X.shape[1] > 1:
        X = X.argmax(dim=1)
    # Bug fix: the original bound ``pred`` only inside the branch above, so
    # passing label predictions raised UnboundLocalError.
    cmp = X.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())
# Loss function, optimiser and number of passes over the training set.
cross_n = torch.nn.CrossEntropyLoss()
trainer = torch.optim.SGD(net.parameters(), lr=0.02)
epochs = 10


def _evaluate_accuracy(model, data_iter):
    """Return the fraction of correctly classified samples in ``data_iter``.

    Switches ``model`` to evaluation mode and disables autograd while iterating.
    """
    model.eval()
    correct, total = 0.0, 0
    with torch.no_grad():
        for X, y in data_iter:
            correct += accuracy(model(X), y)
            total += len(X)
    return correct / total


# Model training
for epoch in range(epochs):
    # Bug fix: the original guard ``net == torch.nn.Module`` compared an
    # instance with a class and was always False. After the first evaluation
    # pass, the model was therefore never switched back to training mode.
    net.train()
    true_data, data_len = 0.0, 0
    for X, y in train_iter:
        preb = net(X)
        loss = cross_n(preb, y)
        trainer.zero_grad()
        loss.backward()
        trainer.step()
        # Training accuracy uses this batch's predictions from before the update.
        true_data += accuracy(preb.detach(), y)
        data_len += len(X)
    # Evaluation on the test split
    test_acc = _evaluate_accuracy(net, test_iter)
    print("epoch =", epoch + 1, ' train acc=', true_data / data_len, ' test acc=', test_acc)
# 运行结果 (run results):
# epoch = 1 train acc= 0.4238 test acc= 0.5785
# epoch = 2 train acc= 0.6531666666666667 test acc= 0.6685
# epoch = 3 train acc= 0.7029833333333333 test acc= 0.7203
# epoch = 4 train acc= 0.7425333333333334 test acc= 0.7493
# epoch = 5 train acc= 0.7692 test acc= 0.7678
# epoch = 6 train acc= 0.78745 test acc= 0.7827
# epoch = 7 train acc= 0.8011833333333334 test acc= 0.7911
# epoch = 8 train acc= 0.8106833333333333 test acc= 0.7996
# epoch = 9 train acc= 0.8171833333333334 test acc= 0.8012
# epoch = 10 train acc= 0.8227833333333333 test acc= 0.808