由于从 GitHub 上找到的一个 VGG 代码无法与原项目的 ResNet 训练代码兼容(训练结果不收敛),
又没有想到其他捷径来解决这个问题,所以决定把代码完整地看一遍。
类的初始化
class resnet20(nn.Module):
    """ResNet-20 for 32x32 inputs: a 3x3 conv stem, three stages of residual
    blocks (16 -> 32 -> 64 channels), global average pooling and a linear
    classifier with `num_class` outputs."""

    def __init__(self, num_class):
        super(resnet20, self).__init__()
        # Stem: 3x3 conv keeping the input resolution, then BN + ReLU.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU()
        # Three stages of 3 residual blocks each; a channel increase makes the
        # stage's first block downsample (see resblock: stride-2 conv + 1x1 shortcut).
        self.res1 = self.make_layer(resblock, 3, 16, 16)
        self.res2 = self.make_layer(resblock, 3, 16, 32)
        self.res3 = self.make_layer(resblock, 3, 32, 64)
        # Global 8x8 average pool collapses the final feature map to 1x1.
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64, num_class)

        # Kaiming init for conv weights, unit-gamma/zero-beta for batch norms.
        for mod in self.modules():
            if isinstance(mod, nn.Conv2d):
                nn.init.kaiming_normal_(mod.weight, mode='fan_out', nonlinearity='relu')
                if mod.bias is not None:
                    nn.init.constant_(mod.bias, 0)
            elif isinstance(mod, nn.BatchNorm2d):
                nn.init.constant_(mod.weight, 1)
                nn.init.constant_(mod.bias, 0)

        self.num_class = num_class
Resblock结构
在类的初始化中,值得一提的是resblock这个结构,它是resnet的核心——一个残差块(residual block)。
class resblock(nn.Module):
def __init__(self, in_channels, out_channels, return_before_act):
super(resblock, self).__init__()
self.return_before_act = return_before_act
self.downsample = (in_channels != out_channels)
if self.downsample:
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False)
self.ds = nn.Sequential(*[
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2, bias=False),
nn.BatchNorm2d(out_channels)
])
else:
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.ds = None
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
def forward(self, x):
residual = x
pout = self.conv1(x) # pout: pre out before activation
pout = self.bn1(pout)
pout = self.relu(pout)
pout = self.conv2(pout)
pout = self.bn2(pout)
if self.downsample:
residual = self.ds(x)
pout += residual
out = self.relu(pout)
if not self.return_before_act:
return out
else:
return pout, out
前向传播
resnet的前向传播也很容易理解,就是把包括resblock在内的各个结构串联到一起。
def forward(self, x):
# print(f"input shape {x.shape}")
pstem = self.conv1(x) # pstem: pre stem before activation
pstem = self.bn1(pstem)
stem = self.relu(pstem)
stem = (pstem, stem)
rb1 = self.res1(stem[1])
rb2 = self.res2(rb1[1])
rb3 = self.res3(rb2[1])
feat = self.avgpool(rb3[1])
feat = feat.view(feat.size(0), -1)
out = self.fc(feat)
# print(f"out shape {out.shape}")
return stem, rb1, rb2, rb3, feat, out
训练
定义网络
在train_base.py中,可以看到main()函数中,通过函数define_tsnet,可以定义要使用的网络
def define_tsnet(name, num_class, cuda=True):
    """Build a network by name and wrap it in DataParallel.

    Args:
        name: one of 'resnet20', 'resnet110', 'vgg16'.
        num_class: number of output classes (not passed to the vgg16 builder).
        cuda: when True, move the wrapped model onto the GPU(s).

    Returns:
        The model wrapped in torch.nn.DataParallel.

    Raises:
        Exception: if `name` is not a recognised model name.
    """
    # Lazy builders: only the requested constructor name is ever resolved.
    builders = {
        'resnet20': lambda: resnet20(num_class=num_class),
        'resnet110': lambda: resnet110(num_class=num_class),
        'vgg16': lambda: vgg16_bn(),
    }
    if name not in builders:
        raise Exception('model name does not exist.')
    net = torch.nn.DataParallel(builders[name]())
    return net.cuda() if cuda else net
参数初始化
这一部分主要定义损失函数,并对训练数据做归一化等预处理:
# define loss functions (moved onto the GPU when training with CUDA)
criterion = torch.nn.CrossEntropyLoss()
if args.cuda:
    criterion = criterion.cuda()

# define transforms: pick the dataset class and its channel-wise
# normalisation statistics (mean/std computed on the training split)
if args.data_name == 'cifar10':
    dataset, mean, std = dst.CIFAR10, (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
elif args.data_name == 'cifar100':
    dataset, mean, std = dst.CIFAR100, (0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)
else:
    raise Exception('Invalid dataset name...')

normalize = transforms.Normalize(mean=mean, std=std)
# training: reflect-pad then random 32x32 crop + horizontal flip (augmentation)
train_transform = transforms.Compose([
    transforms.Pad(4, padding_mode='reflect'),
    transforms.RandomCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
# evaluation: deterministic center crop only
test_transform = transforms.Compose([
    transforms.CenterCrop(32),
    transforms.ToTensor(),
    normalize,
])
train函数
在进行了一系列设置后,接下来就可以进行训练了
def train(train_loader, net, optimizer, criterion, epoch):
    """Run one training epoch: forward, loss, backward, optimizer step,
    and periodic logging of timing / loss / top-1 / top-5 meters."""
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    net.train()  # training mode (batch-norm updates, dropout active)
    tic = time.time()
    for step, (img, target) in enumerate(train_loader, start=1):
        # time spent waiting on the data loader
        data_time.update(time.time() - tic)
        if args.cuda:
            img = img.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
        # resnet forward returns (stem, rb1, rb2, rb3, feat, out);
        # only the final logits feed the base-training loss.
        # (A plain VGG forward would return just `out` — hence the
        # incompatibility discussed above.)
        _, _, _, _, _, out = net(img)
        loss = criterion(out, target)
        prec1, prec5 = accuracy(out, target, topk=(1, 5))
        n_samples = img.size(0)
        losses.update(loss.item(), n_samples)
        top1.update(prec1.item(), n_samples)
        top5.update(prec5.item(), n_samples)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_time.update(time.time() - tic)
        tic = time.time()
        if step % args.print_freq == 0:
            log_str = ('Epoch[{0}]:[{1:03}/{2:03}] '
                       'Time:{batch_time.val:.4f} '
                       'Data:{data_time.val:.4f} '
                       'loss:{losses.val:.4f}({losses.avg:.4f}) '
                       'prec@1:{top1.val:.2f}({top1.avg:.2f}) '
                       'prec@5:{top5.val:.2f}({top5.avg:.2f})'.format(
                           epoch, step, len(train_loader), batch_time=batch_time,
                           data_time=data_time, losses=losses, top1=top1, top5=top5))
            logging.info(log_str)
其中,criterion使用交叉熵(CrossEntropy)作为损失函数。其他的基本为常规操作。