# 使用 Fashion-MNIST 演示 PyTorch 实现多层感知机的创建、训练和测试
# 导入依赖包
import torch
import torch.utils.data as Data
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from torch.nn import init
import sys
# 加载数据集
# Build the Fashion-MNIST training split first, then the test split. Both
# read from the same root directory. download=True lets torchvision fetch
# the files when they are not already present there. Each image is
# converted to a tensor by transforms.ToTensor().
mnist_train, mnist_test = (
    torchvision.datasets.FashionMNIST(
        root='~/Datasets/FashionMNIST',
        train=is_train_split,
        download=True,
        transform=transforms.ToTensor(),
    )
    for is_train_split in (True, False)
)

# Mini-batch size, presumably consumed by the data loaders set up later
# in the script (not visible here) — confirm against the rest of the file.
batch_size = 128
if sys.platform.startswith(&#