对 ResNet 不熟悉的读者，可以先看一下这篇：
CIFAR-10 + ResNet
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
from torch import optim
import os
import csv
from PIL import Image
import warnings
warnings.simplefilter('ignore')
from torchvision import datasets
# Load the data.
# MNIST digits are resized to 32x32 (presumably to match the input size of the
# CIFAR-10 ResNet this script is adapted from) and converted to tensors.
trans = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()])
# download=True fetches MNIST into ./num on the first run instead of raising
# "Dataset not found"; once the files exist locally nothing is downloaded.
train_set = datasets.MNIST('./num', train=True, transform=trans, download=True)
# MNIST's test split has 10,000 images: the first 5,000 are used as the
# validation set and the last 5,000 as the test set. The split is read and
# transformed a single time and then sliced, rather than being loaded twice.
mnist_test = list(datasets.MNIST('./num', train=False, transform=trans, download=True))
val_set = mnist_test[:5000]
test_set = mnist_test[5000:]
train_loader = DataLoader(train_set, batch_size=150, shuffle=True)
# The validation/test loaders are presumably used only for evaluation, where
# shuffling adds nothing; a fixed order keeps evaluation runs reproducible.
val_loader = DataLoader(val_set, batch_size=50, shuffle=False)
test_loader = DataLoader(test_set, batch_size=50, shuffle=False)
#构建resblock
class resblock(nn.Module):
def __init__(self,ch_in,ch_out,stride=1)