Pytorch实现Resnet训练CIFAR10数据集(完整代码,包含resnet-50,resnet-101)
首先之前有写Pytorch的入门教程博客如果没有安装pytorch具体可转链接
废话不多说,直接上代码
这段代码使用 CUDA 训练。如果不想使用 GPU,可以将 `device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')` 这一行注释掉,并把后面所有将 tensor 转到 GPU 的 `.to(device)` 调用一并删掉。
CIFAR10数据集下载不动的这里有下载好的网盘链接(提取码5tk8),直接解压存入代码文件上一层的文件夹中
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
# Pick the compute device: use the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_epochs = 50  # number of passes over the full training set
batch_size = 50  # samples per mini-batch
learning_rate = 0.01  # step size, presumably consumed by an optimizer later in the file -- confirm
# Training-set augmentation: pad each side by 4, random horizontal flip,
# random 32x32 crop back to the original size, then convert to a [0, 1] float tensor.
transform = transforms.Compose([
transforms.Pad(4),
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32),
transforms.ToTensor()])
# CIFAR-10 training split; downloaded into ../data/ on the first run.
train_dataset = torchvision.datasets.CIFAR10(root='../data/',
train=True,
transform=transform,
download=True)
# Test split uses plain ToTensor() only -- no augmentation at evaluation time.
test_dataset = torchvision.datasets.CIFAR10(root='../data/',
train=False,
transform=transforms.ToTensor())
# Mini-batch loaders: reshuffle the training data each epoch; keep test order fixed.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=False)
def conv3x3(in_channels, out_channels, stride=1):
    """Build a 3x3 convolution with padding 1 and no bias.

    Bias is omitted because the following BatchNorm layer supplies
    an affine shift, making a conv bias redundant.
    """
    conv_kwargs = {
        "kernel_size": 3,
        "stride": stride,
        "padding": 1,
        "bias": False,
    }
    return nn.Conv2d(in_channels, out_channels, **conv_kwargs)
# Resnet 的残差块
class ResidualBlock(nn.Module):
    """Basic two-convolution residual block:
    out = relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x)).

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by both 3x3 convolutions.
        stride: stride of the first convolution (spatial downsampling).
        downsample: optional module that projects ``x`` so the shortcut
            matches the main path in shape; ``None`` means an identity
            shortcut.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        # Fix: compare against None explicitly instead of `if self.downsample:`.
        # Truthiness of a module goes through __len__ when defined (e.g. an
        # empty nn.Sequential is falsy), so the identity check is the robust
        # form -- this matches the torchvision ResNet implementation.
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out
# ResNet定义
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=10):
super(ResNet, self).__init__()
self.in_channels = 16
self.conv = conv3x3(3, 16)
self.bn = nn.BatchNorm2d(16)
self.relu = nn.ReLU