ResNet34
Code
u"""ResNet34训练学习CIFAR10"""
__author__ = 'zhengbiqing 460356155@qq.com'
import torch as t
import torchvision as tv
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
import torch.backends.cudnn as cudnn
import matplotlib.pyplot as plt
import datetime
import argparse
WORKERS = 4                           # DataLoader worker processes
PARAS_FN = 'cifar_resnet_params.pkl'  # file for saved model parameters
ROOT = './'                           # dataset download directory
loss_func = nn.CrossEntropyLoss()
best_acc = 0                          # best test accuracy seen so far
global_train_acc = []                 # training accuracy, one point per log_interval batches
global_test_acc = []                  # test accuracy, one point per epoch
'''
Residual block.
in_channels, out_channels: input/output channel counts of the block.
In layer1 both are 64; in layer2-4 the first block changes the channel count.
When in/out channels match, stride is 1; when they differ, stride is 2 and a
1x1 convolution downsamples the identity branch to the matching shape.
'''
class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        if in_channels != out_channels:
            # 1x1 convolution so the shortcut matches the main branch in both
            # channel count and spatial size (its stride follows conv1's stride)
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.downsample = None

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out
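# Illustrative shape check (my addition, not part of the original script):
# ResBlock(64, 64) keeps an input of shape [N, 64, 16, 16] unchanged, while
# ResBlock(64, 128, stride=2) returns [N, 128, 8, 8] -- the 1x1/stride-2
# downsample branch is what keeps the shortcut addition shape-compatible, e.g.
#   blk = ResBlock(64, 128, stride=2)
#   y = blk(t.randn(1, 64, 16, 16))   # y.shape == torch.Size([1, 128, 8, 8])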
'''
Network definition: ResNet34, with the stem and pooling sized for 32x32 CIFAR10 images.
'''
class ResNet34(nn.Module):
    def __init__(self, block):
        super(ResNet34, self).__init__()
        # Stem: the 7x7/stride-2 conv halves 32x32 to 16x16; the max-pool uses
        # stride 1 so the feature map stays 16x16 for CIFAR10-sized inputs.
        self.first = nn.Sequential(
            nn.Conv2d(3, 64, 7, 2, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 1, 1)
        )
        # 3 + 4 + 6 + 3 residual blocks, the standard ResNet34 layout
        self.layer1 = self.make_layer(block, 64, 64, 3, 1)
        self.layer2 = self.make_layer(block, 64, 128, 4, 2)
        self.layer3 = self.make_layer(block, 128, 256, 6, 2)
        self.layer4 = self.make_layer(block, 256, 512, 3, 2)
        self.avg_pool = nn.AvgPool2d(2)   # 512 x 2 x 2 -> 512 x 1 x 1
        self.fc = nn.Linear(512, 10)

    def make_layer(self, block, in_channels, out_channels, block_num, stride):
        layers = []
        # the first block may change channels and downsample; the rest keep the shape
        layers.append(block(in_channels, out_channels, stride))
        for i in range(block_num - 1):
            layers.append(block(out_channels, out_channels, 1))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.first(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avg_pool(x)
        x = x.view(x.size()[0], -1)
        x = self.fc(x)
        return x
'''
Train the network for one epoch.
net: the model
train_data_load: training DataLoader
optimizer: the optimizer
epoch: index of the current epoch
log_interval: how often (in batches) the running loss and accuracy are reported
'''
def net_train(net, train_data_load, optimizer, epoch, log_interval):
    net.train()
    begin = datetime.datetime.now()

    total = len(train_data_load.dataset)
    train_loss = 0
    ok = 0

    for i, data in enumerate(train_data_load, 0):
        img, label = data
        img, label = img.cuda(), label.cuda()

        optimizer.zero_grad()
        outs = net(img)
        loss = loss_func(outs, label)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = t.max(outs.data, 1)
        ok += (predicted == label).sum().item()

        if (i + 1) % log_interval == 0:
            trained_total = (i + 1) * len(label)
            acc = 100. * ok / trained_total
            global_train_acc.append(acc)
            print('Train Epoch: {} [{}/{}]\tLoss: {:.4f}\tAcc: {:.2f}%'.format(
                epoch, trained_total, total, train_loss / (i + 1), acc))

    end = datetime.datetime.now()
    print('one epoch spend: ', end - begin)
'''
Evaluate accuracy on the test set.
'''
def net_test(net, test_data_load, epoch):
    net.eval()

    ok = 0
    with t.no_grad():   # no gradients needed during evaluation
        for i, data in enumerate(test_data_load):
            img, label = data
            img, label = img.cuda(), label.cuda()

            outs = net(img)
            _, pre = t.max(outs.data, 1)
            ok += (pre == label).sum().item()

    acc = ok * 100. / (len(test_data_load.dataset))
    print('EPOCH:{}, ACC:{}\n'.format(epoch, acc))
    global_test_acc.append(acc)

    global best_acc
    if acc > best_acc:
        best_acc = acc
'''
Show one image from the dataset.
'''
def img_show(dataset, index):
    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    show = ToPILImage()
    data, label = dataset[index]
    print('img is a ', classes[label])
    # (data + 1) / 2 undoes Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    show((data + 1) / 2).resize((100, 100)).show()
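# Example usage (illustrative only; the original script never calls this function):
# with the CIFAR10 dataset object created in main() you could run, e.g.,
#   img_show(train_data, 5)   # prints the class name and pops up the image at 100x100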
'''
Plot the training and test accuracy curves.
'''
def show_acc_curv(ratio):
    # Training accuracy is recorded every log_interval batches, test accuracy once
    # per epoch, i.e. every `ratio` training points; placing the test points at
    # x = ratio-1, 2*ratio-1, ... lets both curves share one x axis.
    train_x = list(range(len(global_train_acc)))
    train_y = global_train_acc

    test_x = train_x[ratio - 1::ratio]
    test_y = global_test_acc

    plt.title('CIFAR10 RESNET34 ACC')
    plt.plot(train_x, train_y, color='green', label='training accuracy')
    plt.plot(test_x, test_y, color='red', label='testing accuracy')
    plt.legend()
    plt.xlabel('iterations')
    plt.ylabel('accs')
    plt.show()
def main():
    parser = argparse.ArgumentParser(description='PyTorch CIFAR10 ResNet34 Example')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--test-batch-size', type=int, default=100, metavar='N',
                        help='input batch size for testing (default: 100)')
    parser.add_argument('--epochs', type=int, default=90, metavar='N',
                        help='number of epochs to train (default: 90)')
    parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                        help='learning rate (default: 0.1)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status (default: 10)')
    parser.add_argument('--no-train', action='store_true', default=False,
                        help='skip training and only evaluate a saved model')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='save the trained model parameters')
    # args=[] keeps the defaults so the script also runs inside a notebook;
    # drop it to parse real command-line arguments.
    args = parser.parse_args(args=[])
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
Converts a PIL Image or numpy.ndarray (H x W x C) in the range
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
"""
transform = tv.transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
    train_data = tv.datasets.CIFAR10(root=ROOT, train=True, download=True, transform=transform)
    test_data = tv.datasets.CIFAR10(root=ROOT, train=False, download=False, transform=transform)

    train_load = t.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=WORKERS)
    test_load = t.utils.data.DataLoader(test_data, batch_size=args.test_batch_size, shuffle=False, num_workers=WORKERS)

    net = ResNet34(ResBlock).cuda()
    print(net)

    net = nn.DataParallel(net)
    cudnn.benchmark = True

    if args.no_train:
        net.load_state_dict(t.load(PARAS_FN))
        net_test(net, test_load, 0)
        return
    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum)

    start_time = datetime.datetime.now()
    for epoch in range(1, args.epochs + 1):
        net_train(net, train_load, optimizer, epoch, args.log_interval)
        net_test(net, test_load, epoch)
    end_time = datetime.datetime.now()

    global best_acc
    print('CIFAR10 pytorch ResNet34 Train: EPOCH:{}, BATCH_SZ:{}, LR:{}, ACC:{}'.format(args.epochs, args.batch_size, args.lr, best_acc))
    print('train spend time: ', end_time - start_time)

    # number of training-accuracy points recorded per epoch, used by show_acc_curv
    # to line up the once-per-epoch test points on the same x axis
    ratio = len(train_data) / args.batch_size / args.log_interval
    ratio = int(ratio)
    show_acc_curv(ratio)

    if args.save_model:
        t.save(net.state_dict(), PARAS_FN)
if __name__ == '__main__':
    main()
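As a quick sanity check on the sizes assumed above (a minimal sketch of my own, not part of the original script), you can push a fake CIFAR10 batch through an untrained ResNet34 on the CPU and confirm that the stem leaves a 16x16 map, layer1 through layer4 produce 16/8/4/2-pixel maps with 64/128/256/512 channels, and AvgPool2d(2) therefore hands exactly 512 features to the final Linear(512, 10):

import torch as t

# assumes ResBlock and ResNet34 as defined in the script above
net = ResNet34(ResBlock)
x = t.randn(2, 3, 32, 32)              # a fake batch of two CIFAR10 images

feats = net.first(x)
print('stem  ', feats.shape)           # torch.Size([2, 64, 16, 16])
for name in ('layer1', 'layer2', 'layer3', 'layer4'):
    feats = getattr(net, name)(feats)
    print(name, feats.shape)           # 64x16x16, 128x8x8, 256x4x4, 512x2x2
print('logits', net(x).shape)          # torch.Size([2, 10])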
Training results
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz
100% 170498071/170498071 [00:03<00:00, 70439337.53it/s]
Extracting ./cifar-10-python.tar.gz to ./
/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py:554: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
warnings.warn(_create_warning_msg(
ResNet34(
(first): Sequential(
(0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
)
(layer1): Sequential(
(0): ResBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): ResBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): ResBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer2): Sequential(
(0): ResBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2))
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): ResBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): ResBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(3): ResBlock(
(conv1): Conv2d(128,