这里采用ResNet网络对CIFAR-10数据集进行分类训练,最终在测试集上进行测试(代码的主要解释以及网络结构已在注释中给出),具体代码如下:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
#简要介绍本次ResNet网络,说在前面,这里定义的conv3x3都是padding=1,因此当stride=1时卷积前后图像宽高保持一样大,stride=2时宽高减半
'''
图片为3,32,32
通过第一层卷积(输入输出通道为3,16)变为16,32,32
下一步就是layer1
第一个ResBlock(输入输出通道为16,16)的stride设置为1,所以卷积后图片长宽不变,为16,32,32
第二个ResBlock(输入输出通道为16,16)的stride设置的也为1,所以卷积之后图片长宽也不变,为16,32,32
下一步是layer2
第一个ResBlock(输入输出通道为16,32)的stride设置为2,所以卷积后图片长宽减半、通道变为32,即32,16,16(捷径分支downsample同样用stride=2的卷积把16,32,32变为32,16,16,才能与主分支相加)
第二个ResBlock(输入输出通道为32,32)的stride设置为1,所以卷积之后图片长宽不变,仍为32,16,16
下一步是layer3
第一个ResBlock(输入输出通道为32,64)的stride设置为2,所以卷积后图片长宽减半、通道变为64,即64,8,8
第二个ResBlock(输入输出通道为64,64)的stride设置为1,所以卷积之后图片长宽不变,仍为64,8,8
平均池化层(核为8)所以变为64,1,1
最后定义了一个线性层(输入为64,输出为类别数num_classes=10),记得先用view函数把64,1,1展平成长度为64的向量
'''
# Hyper-parameters and compute device.
# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Training epochs, images per mini-batch, and the initial learning rate for the optimizer.
num_epochs, batch_size, learning_rate = 80, 100, 0.001
# Training-time preprocessing / data augmentation pipeline.
transform = transforms.Compose(
    [
        transforms.Pad(4),                  # pad 4 pixels on every side of the image
        transforms.RandomHorizontalFlip(),  # randomly mirror left-right (p=0.5 by default)
        transforms.RandomCrop(32),          # cut a 32x32 window at a random position of the padded image
        transforms.ToTensor(),              # convert the image to a float tensor
    ]
)
# CIFAR-10 dataset: the training split goes through the augmentation pipeline `transform`,
# the test split is only converted to tensors (no augmentation at evaluation time).
train_dataset = torchvision.datasets.CIFAR10(root='data/',
                                             train=True,
                                             transform=transform,
                                             download=True)

# download=True here as well, so building the test split does not rely on the call above
# having fetched the data first; torchvision does not re-download files already present.
test_dataset = torchvision.datasets.CIFAR10(root='data/',
                                            train=False,
                                            transform=transforms.ToTensor(),
                                            download=True)

# Data loaders: reshuffle the training batches every epoch; keep the test order fixed.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)
# 3x3 convolution helper used by the stem, the residual blocks and the shortcut projections.
def conv3x3(in_channels, out_channels, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    With ``stride=1`` the spatial size is preserved; with ``stride=2`` an
    H x W map becomes ceil(H/2) x ceil(W/2). The bias is omitted because every
    use in this file feeds the convolution straight into a BatchNorm layer.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        stride: convolution stride (default 1).
    """
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return conv
# Basic residual block: two 3x3 conv + BatchNorm layers plus a shortcut connection.
class ResidualBlocks(nn.Module):
    """conv3x3 -> BN -> ReLU -> conv3x3 -> BN, add the shortcut, then ReLU.

    Args:
        in_channels: channel count of the input feature map.
        out_channels: channel count produced by both convolutions.
        stride: stride of the first convolution; 2 halves height and width.
        downsample: optional module applied to the input on the shortcut path.
            It must map the input to the main path's output shape, which is
            required whenever ``stride != 1`` or ``in_channels != out_channels``.
            ``None`` means the input is added unchanged (identity shortcut).
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlocks, self).__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        residual = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Compare to the None singleton by identity, not with `!=`.
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out
# ResNet for small images: a 3x3 stem, three stages of residual blocks (16, 32, 64 channels),
# global average pooling and a linear classifier.
class ResNet(nn.Module):
    """Three-stage residual network producing class logits.

    Args:
        block: residual block class. The first block of each stage is built as
            ``block(in_channels, out_channels, stride, downsample)``, the others as
            ``block(out_channels, out_channels)``.
        layers: number of blocks in each of the three stages (``layers[0..2]``).
        num_classes: number of output logits (default 10).
    """

    def __init__(self, block, layers, num_classes=10):
        super(ResNet, self).__init__()
        # Channel count entering the next stage; make_layer advances it stage by stage.
        self.in_channels = 16
        self.conv = conv3x3(3, 16)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self.make_layer(block, 16, layers[0])
        self.layer2 = self.make_layer(block, 32, layers[1], 2)
        self.layer3 = self.make_layer(block, 64, layers[2], 2)
        # Pool each channel down to 1x1. For 32x32 inputs layer3 yields 8x8 maps, where this
        # equals AvgPool2d(8); unlike a fixed kernel it also yields 64 features for other input
        # sizes, so `fc` keeps working. It has no parameters, so the state_dict is unchanged.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64, num_classes)

    def make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks producing ``out_channels`` channels.

        Only the first block uses ``stride`` and, when the shape changes, a projection
        shortcut. The rest keep the shape. At least one block is always created.
        Side effect: sets ``self.in_channels`` to ``out_channels``.
        """
        downsample = None
        if stride != 1 or self.in_channels != out_channels:
            # Project the shortcut so it matches the main path's channels and spatial size.
            # NOTE(review): this is a 3x3 conv, whereas the standard ResNet projection uses
            # 1x1; kept as-is so previously saved 'resnet.ckpt' weights still load.
            downsample = nn.Sequential(
                conv3x3(self.in_channels, out_channels, stride=stride),
                nn.BatchNorm2d(out_channels),
            )
        layers = [block(self.in_channels, out_channels, stride, downsample)]
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(block(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return logits of shape (batch, num_classes).

        x: image batch with 3 channels, (batch, 3, H, W).
        """
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)  # (batch, 64, 1, 1) -> (batch, 64)
        out = self.fc(out)
        return out
# Three stages with two residual blocks each, moved to the selected device.
model = ResNet(ResidualBlocks, [2, 2, 2])
model = model.to(device)

# Loss and optimizer: cross-entropy on the logits, Adam over all model parameters.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
# Manual learning-rate control for the optimizer.
def update_lr(optimizer, lr):
    """Set the learning rate of every parameter group of ``optimizer`` to ``lr``.

    The optimizer is modified in place; nothing is returned.
    """
    for group in optimizer.param_groups:
        group.update(lr=lr)
# Training loop: one optimizer step per mini-batch, a log line every 100 steps,
# and the learning rate divided by 3 after every 20 epochs.
total_step = len(train_loader)
curr_lr = learning_rate

# Enter training mode explicitly so BatchNorm normalizes with batch statistics and
# updates its running estimates, even if model.eval() was called earlier
# (e.g. when this section is re-run after evaluation).
model.train()
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print("Epoch [{}/{}], Step [{}/{}] Loss: {:.4f}".format(epoch+1, num_epochs, i+1, total_step, loss.item()))

    # Step decay of the learning rate at the end of every 20th epoch.
    if (epoch + 1) % 20 == 0:
        curr_lr /= 3
        update_lr(optimizer, curr_lr)
# Evaluate on the test set. eval() makes BatchNorm use its running statistics;
# no_grad() avoids building the autograd graph, so `.data` is not needed.
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Predicted class = index of the largest logit in each row.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))

# Save the trained weights (state_dict only, not the module object).
torch.save(model.state_dict(), 'resnet.ckpt')