ResNet-20 code review

A VGG implementation found on GitHub turned out to be incompatible with the original project's ResNet training code (training did not converge).

With no obvious shortcut around the problem, the next step is to read through the code.

Class initialization

import torch.nn as nn

class resnet20(nn.Module):
    def __init__(self, num_class):
        super(resnet20, self).__init__()
        # stem: a 3x3 conv that keeps CIFAR's 32x32 resolution
        self.conv1   = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1     = nn.BatchNorm2d(16)
        self.relu    = nn.ReLU()

        # three stages of 3 residual blocks each: 16 -> 16 -> 32 -> 64 channels
        self.res1 = self.make_layer(resblock, 3, 16, 16)
        self.res2 = self.make_layer(resblock, 3, 16, 32)
        self.res3 = self.make_layer(resblock, 3, 32, 64)

        # global 8x8 average pooling, then a linear classifier on 64 features
        self.avgpool = nn.AvgPool2d(8)
        self.fc      = nn.Linear(64, num_class)

        # He initialization for convolutions; BN starts as the identity
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        self.num_class = num_class
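
make_layer is called three times above but is not included in the excerpt. A sketch of the method that is consistent with the call sites, and with how forward() later indexes each stage's output, could look like this (the body is my reconstruction, not the project's code):

    def make_layer(self, block, num_blocks, in_channels, out_channels):
        # the first block projects in_channels -> out_channels (and
        # downsamples when the channel count changes, see resblock below)
        layers = [block(in_channels, out_channels, return_before_act=False)]
        for _ in range(num_blocks - 2):
            layers.append(block(out_channels, out_channels, return_before_act=False))
        # only the last block returns (pre-activation, activation), so the
        # whole stage evaluates to that tuple under nn.Sequential
        layers.append(block(out_channels, out_channels, return_before_act=True))
        return nn.Sequential(*layers)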

The resblock structure

Within the initializer, the structure worth highlighting is resblock: it is the core of a ResNet, a residual block.


class resblock(nn.Module):
    def __init__(self, in_channels, out_channels, return_before_act):
        super(resblock, self).__init__()
        self.return_before_act = return_before_act
        # a change in channel count also means a stride-2 downsample, so the
        # shortcut needs a 1x1 projection to match the main path's shape
        self.downsample = (in_channels != out_channels)
        if self.downsample:
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False)
            self.ds    = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2, bias=False),
                nn.BatchNorm2d(out_channels))
        else:
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
            self.ds    = None
        self.bn1   = nn.BatchNorm2d(out_channels)
        self.relu  = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2   = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        residual = x

        pout = self.conv1(x)  # pout: pre-activation output
        pout = self.bn1(pout)
        pout = self.relu(pout)

        pout = self.conv2(pout)
        pout = self.bn2(pout)

        # project the identity when shapes differ, then add the shortcut
        if self.downsample:
            residual = self.ds(x)

        pout += residual
        out  = self.relu(pout)

        if not self.return_before_act:
            return out
        else:
            return pout, out
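
A quick shape check (a hypothetical snippet, not from the post) makes the downsampling behavior concrete: when the channel count changes, both the main path and the 1x1 projection shortcut use stride 2, so the spatial resolution halves:

import torch

blk = resblock(16, 32, return_before_act=True)
x = torch.randn(2, 16, 32, 32)
pout, out = blk(x)
print(pout.shape, out.shape)  # both torch.Size([2, 32, 16, 16])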

Forward pass

The forward pass of the ResNet is also easy to follow: it simply chains the stem, the three resblock stages, the pooling, and the classifier together.

    def forward(self, x):
        pstem = self.conv1(x)  # pstem: stem output before activation
        pstem = self.bn1(pstem)
        stem  = self.relu(pstem)
        stem  = (pstem, stem)

        # each stage returns (pre-activation, activation) of its last block,
        # so the activated tensor is at index 1
        rb1 = self.res1(stem[1])
        rb2 = self.res2(rb1[1])
        rb3 = self.res3(rb2[1])

        feat = self.avgpool(rb3[1])
        feat = feat.view(feat.size(0), -1)  # flatten [N, 64, 1, 1] to [N, 64]
        out  = self.fc(feat)

        return stem, rb1, rb2, rb3, feat, out
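
Tracing a CIFAR-sized batch through the network (a hypothetical check, not part of the original post) makes the stage resolutions explicit. Returning every intermediate like this is typical when the features are reused elsewhere, e.g. for distillation losses:

net = resnet20(num_class=10)
x = torch.randn(4, 3, 32, 32)
stem, rb1, rb2, rb3, feat, out = net(x)
print(rb1[1].shape)  # torch.Size([4, 16, 32, 32])
print(rb2[1].shape)  # torch.Size([4, 32, 16, 16])
print(rb3[1].shape)  # torch.Size([4, 64, 8, 8])
print(feat.shape)    # torch.Size([4, 64])
print(out.shape)     # torch.Size([4, 10])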

Training

Defining the network

In train_base.py, the main() function builds the network to train via define_tsnet:

def define_tsnet(name, num_class, cuda=True):
    if name == 'resnet20':
        net = resnet20(num_class=num_class)
    elif name == 'resnet110':
        net = resnet110(num_class=num_class)
    elif name == 'vgg16':
        net = vgg16_bn()
    else:
        raise Exception('model name does not exist.')

    if cuda:
        net = torch.nn.DataParallel(net).cuda()
    else:
        net = torch.nn.DataParallel(net)

    return net
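
Because both branches wrap the model in DataParallel, saved checkpoints always carry the 'module.' prefix, so loading stays uniform across hardware. Note also that vgg16_bn() is constructed without num_class, and the train loop below expects six outputs per forward pass, which a plain VGG does not return; either mismatch is a plausible source of the incompatibility mentioned at the top. A minimal usage example (hypothetical):

net = define_tsnet(name='resnet20', num_class=10, cuda=False)
print(net.module.num_class)  # 10; the actual model sits under net.module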

Loss function and data transforms

Next come the loss function and the data transforms. Normalization uses per-channel mean/std statistics of each dataset's training set, so inputs are standardized the same way at train and test time (dst here is presumably torchvision.datasets, and transforms is torchvision.transforms):

# define loss functions
if args.cuda:
    criterion = torch.nn.CrossEntropyLoss().cuda()
else:
    criterion = torch.nn.CrossEntropyLoss()

# define transforms
if args.data_name == 'cifar10':
    dataset = dst.CIFAR10
    mean = (0.4914, 0.4822, 0.4465)
    std  = (0.2470, 0.2435, 0.2616)
elif args.data_name == 'cifar100':
    dataset = dst.CIFAR100
    mean = (0.5071, 0.4865, 0.4409)
    std  = (0.2673, 0.2564, 0.2762)
else:
    raise Exception('Invalid dataset name...')

train_transform = transforms.Compose([
    transforms.Pad(4, padding_mode='reflect'),
    transforms.RandomCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])
test_transform = transforms.Compose([
    transforms.CenterCrop(32),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])
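
Note that CIFAR images are already 32x32, so CenterCrop(32) in the test transform is effectively a no-op; the real augmentation is the reflect-pad to 40x40 followed by a random 32x32 crop and a horizontal flip at train time. Wiring the transforms into loaders might look like this (batch size, worker count, and data root are assumptions, not values from the post):

import torch.utils.data

train_loader = torch.utils.data.DataLoader(
    dataset('./data', train=True, transform=train_transform, download=True),
    batch_size=128, shuffle=True, num_workers=4, pin_memory=True)
test_loader = torch.utils.data.DataLoader(
    dataset('./data', train=False, transform=test_transform, download=True),
    batch_size=128, shuffle=False, num_workers=4, pin_memory=True)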

The train function

With all of this set up, training can begin:

def train(train_loader, net, optimizer, criterion, epoch):
    batch_time = AverageMeter()
    data_time  = AverageMeter()
    losses     = AverageMeter()
    top1       = AverageMeter()
    top5       = AverageMeter()

    net.train()

    end = time.time()
    for i, (img, target) in enumerate(train_loader, start=1):
        data_time.update(time.time() - end)

        if args.cuda:
            img = img.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

        # the ResNet returns six values; only the logits are needed here
        _, _, _, _, _, out = net(img)
        # out = net(img)  # for a plain VGG, which returns only the logits
        loss = criterion(out, target)

        prec1, prec5 = accuracy(out, target, topk=(1,5))
        losses.update(loss.item(), img.size(0))
        top1.update(prec1.item(), img.size(0))
        top5.update(prec5.item(), img.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            log_str = ('Epoch[{0}]:[{1:03}/{2:03}] '
                       'Time:{batch_time.val:.4f} '
                       'Data:{data_time.val:.4f}  '
                       'loss:{losses.val:.4f}({losses.avg:.4f})  '
                       'prec@1:{top1.val:.2f}({top1.avg:.2f})  '
                       'prec@5:{top5.val:.2f}({top5.avg:.2f})'.format(
                       epoch, i, len(train_loader), batch_time=batch_time, data_time=data_time,
                       losses=losses, top1=top1, top5=top5))
            logging.info(log_str)

Here criterion is the CrossEntropyLoss applied to the network's logits; everything else is standard PyTorch training boilerplate.
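
AverageMeter and accuracy come from the project's utilities and are not shown in the post. Both are standard helpers; the versions below follow the common PyTorch ImageNet-example pattern and are a sketch, not necessarily the project's exact code:

class AverageMeter(object):
    """Tracks the latest value, running sum, count, and average."""
    def __init__(self):
        self.val, self.sum, self.count, self.avg = 0, 0, 0, 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes top-k precision (in percent) for each requested k."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)  # [N, maxk] class indices
    pred = pred.t()                             # [maxk, N]
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res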
