直接上代码:
import os
import torch
import torch.nn as nn
import pandas as pd
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
# Restrict this process to the first GPU. This must run before CUDA is
# initialised. NOTE(review): assumes nothing earlier in the file has touched
# CUDA yet (torch initialises it lazily) — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# ---- Training hyperparameters -------------------------------------------
batch_size = 256    # samples per mini-batch (fed to both DataLoaders)
num_workers = 4     # DataLoader worker processes for data loading
lr = 1e-4           # learning rate — presumably for the optimizer set up later
epochs = 20         # presumably the number of training passes; used later in file
from torchvision import transforms

# Target side length for input images.
# NOTE(review): MNIST images are presumably already 28x28, which would make the
# Resize step a no-op here — confirm before changing image_size.
image_size = 28

# Preprocessing pipeline applied to every sample: resize, then convert the
# PIL image into a tensor.
data_transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
    ]
)
from torchvision import datasets

# Build the MNIST train/test splits under ./data, applying data_transform to
# each sample. download=True fetches the files on a fresh machine instead of
# raising "Dataset not found"; when the files already exist the download is
# skipped, so existing setups behave exactly as before.
train_data = datasets.MNIST(
    "./data", train=True, transform=data_transform, download=True
)
test_data = datasets.MNIST(
    "./data", train=False, transform=data_transform, download=True
)

# Training loader: reshuffle every epoch and drop the final partial batch so
# every step sees exactly `batch_size` samples.
# NOTE(review): num_workers > 0 spawns worker processes; on Windows/macOS the
# script's entry code must sit under `if __name__ == "__main__":` — confirm.
train_loader = DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,
)

# Evaluation loader: fixed order and every sample kept (no drop_last).
test_loader = DataLoader(
    dataset=test_data,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)
import matplotlib.pypl