(1)数据集
FashionMNIST数据集,作为经典的MNIST数据集的现代替代品的数据集,是衣物分类数据集,由Zalando(一家德国的在线时尚零售商)发布。FashionMNIST数据集和MNIST相比。图片尺寸相同,数据量相同,也同样是10分类,是MNIST的困难版本
(2)DNN(深度神经网络)
DNN可以理解为有多个隐藏层的神经网络,叫做深度神经网络(Deep Neural Network),DNN按不同层的位置划分,内部的神经网络层可以分为三类,输入层、隐藏层和输出层,如下图示例,一般来说第一层是输入层,最后一层是输出层,而中间的层数都是隐藏层。
(3)代码与结果
import torch
import torchvision
from torch import nn
from torch.utils import data
from torchvision import transforms
from torch.optim import Adam
class DNN(nn.Module):
def __init__(self):
super(DNN, self).__init__()
self.linear1=nn.Linear(784,1024)
self.linear2=nn.Linear(1024,512)
self.linear3=nn.Linear(512,256)
self.linear4=nn.Linear(256,10)
self.dropout=nn.Dropout(0.5)
self.ReLu=nn.ReLU()
self.flatten=nn.Flatten()
self.softmax=nn.Softmax()
def forward(self,x):
x=self.flatten(x)
x=self.linear1(x)
x=self.ReLu(x)
x=self.dropout(x)
x=self.linear2(x)
x=self.ReLu(x)
x=self.dropout(x)
x=self.linear3(x)
x=self.ReLu(x)
x=self.dropout(x)
x=self.linear4(x)
x=self.softmax(x)
return x
batch_size = 256
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
root="../data", train=False, transform=trans, download=True)
traindata = data.DataLoader(mnist_train, batch_size, shuffle=True)
testdata=data.DataLoader(mnist_test,100)
net=DNN()
net.to(device)
lr=0.0001
epoch=25
Optim=Adam(net.parameters(),lr=lr)
criterion=nn.CrossEntropyLoss()
totalaccurate=0
net.train()
for i in range(epoch):
total_train=0
for data in traindata:
img,label=data
with torch.no_grad():
img = img.to(device)
label = label.to(device)
Optim.zero_grad()
predict=net(img)
train_loss=criterion(predict,label)
train_loss.backward()
Optim.step()
total_train+=train_loss
print("训练集上的损失:{}".format(total_train))
net.eval()
for data in testdata:
img,lable=data
with torch.no_grad():
img=img.to(device)
lable=lable.to(device)
result=net(img)
accurate=(result.argmax(1)==lable).sum()
totalaccurate+=accurate
print("验证集的正确率为:{:.1%}".format(totalaccurate/10000))