def load_data(batch_size, resize=None, *, root="../data", download=False,
              shuffle_train=True, num_workers=0):
    """Build Fashion-MNIST train and test DataLoaders.

    Every image is converted with ``transforms.ToTensor()``. When ``resize``
    is truthy, ``transforms.Resize(resize)`` is applied first.

    Args:
        batch_size: Number of samples per batch for both loaders.
        resize: Optional size forwarded to ``transforms.Resize``. Falsy
            values (``None``, ``0``) skip resizing.
        root: Directory the datasets are read from (and downloaded into when
            ``download`` is true).
        download: Forwarded to ``torchvision.datasets.FashionMNIST``. With the
            default ``False``, the data must already exist under ``root``.
        shuffle_train: Whether the training loader reshuffles samples on
            each pass. The test loader is never shuffled, so evaluation
            order stays fixed.
        num_workers: Worker processes per loader. ``0`` loads in the main
            process, the ``DataLoader`` default.

    Returns:
        A ``(train_loader, test_loader)`` tuple of ``data.DataLoader``.
    """
    steps = [transforms.ToTensor()]
    if resize:
        # Put Resize at the front so it sees the raw image before ToTensor.
        steps.insert(0, transforms.Resize(resize))
    # Compose applies the listed transforms in order, like nn.Sequential.
    trans = transforms.Compose(steps)

    mnist_train = torchvision.datasets.FashionMNIST(
        root=root, train=True, transform=trans, download=download)
    mnist_test = torchvision.datasets.FashionMNIST(
        root=root, train=False, transform=trans, download=download)
    return (
        # Training data is shuffled so each epoch sees a fresh sample order.
        data.DataLoader(mnist_train, batch_size, shuffle=shuffle_train,
                        num_workers=num_workers),
        data.DataLoader(mnist_test, batch_size, shuffle=False,
                        num_workers=num_workers),
    )
transforms.Compose 就是把多个 transform 按顺序组合成一个整体，作用类似 nn.Sequential()。