Final model training step: train.py + new_model.py + decoder.py

Main function: train.py

import os
import pdb
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
# Repo-local imports (dataloaders, obtain_retrain_autodeeplab_args, Retrain_Autodeeplab,
# build_criterion, Iter_LR_Scheduler, AverageMeter) are omitted in this excerpt.

def main():
    warnings.filterwarnings('ignore')
    assert torch.cuda.is_available()
    torch.backends.cudnn.benchmark = True
    # Parse the training arguments.
    args = obtain_retrain_autodeeplab_args()
    model_fname = 'data/deeplab_{0}_{1}_v3_{2}_epoch%d.pth'.format(args.backbone, args.dataset, args.exp)
    # Pascal is not supported by this retraining script.
    if args.dataset == 'pascal':
        raise NotImplementedError
    # Cityscapes.
    elif args.dataset == 'cityscapes':
        # DataLoader settings.
        kwargs = {'num_workers': args.workers, 'pin_memory': True, 'drop_last': True}
        # The AutoDeeplab stage here is train; unlike search mode, make_data_loader
        # returns the data loader together with the total number of classes.
        dataset_loader, num_classes = dataloaders.make_data_loader(args, **kwargs)
        # Record the number of classes in args.
        args.num_classes = num_classes
    else:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))

    # The backbone argument must be 'autodeeplab'.
    if args.backbone == 'autodeeplab':
        model = Retrain_Autodeeplab(args)
    else:
        raise ValueError('Unknown backbone: {}'.format(args.backbone))
    
    # Build the loss criterion.
    if args.criterion == 'Ohem':
        args.thresh = 0.7
        args.crop_size = [args.crop_size, args.crop_size] if isinstance(args.crop_size, int) else args.crop_size
        # Keep at least 1/16 of the pixels of a per-GPU batch as hard examples.
        args.n_min = int((args.batch_size / len(args.gpu) * args.crop_size[0] * args.crop_size[1]) // 16)
    criterion = build_criterion(args)
    
    # Wrap the model for multi-GPU training and switch to train mode.
    model = nn.DataParallel(model).cuda()
    model.train()
    # Optionally freeze the BatchNorm layers: fix their running statistics
    # and exclude their affine parameters from backpropagation.
    if args.freeze_bn:
        for m in model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
                m.weight.requires_grad = False
                m.bias.requires_grad = False
    # Use SGD with momentum and weight decay as the optimizer.
    optimizer = optim.SGD(model.module.parameters(), lr=args.base_lr, momentum=0.9, weight_decay=0.0001)
    
    # Total number of iterations: each epoch runs len(dataset_loader) iterations,
    # so multiplying by the number of epochs gives the total.
    max_iteration = len(dataset_loader) * args.epochs
    scheduler = Iter_LR_Scheduler(args, max_iteration, len(dataset_loader))
    start_epoch = 0

    # When resuming from a checkpoint, restore the saved state.
    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {0}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print('=> loaded checkpoint {0} (epoch {1})'.format(args.resume, checkpoint['epoch']))
        else:
            raise ValueError('=> no checkpoint found at {0}'.format(args.resume))
    
    # Start training: loop over the epochs.
    for epoch in range(start_epoch, args.epochs):
        # Reset the running loss meter.
        losses = AverageMeter()
        # Iterate over the sampled batches.
        for i, sample in enumerate(dataset_loader):
            # Global iteration index, used to step the LR schedule.
            cur_iter = epoch * len(dataset_loader) + i
            scheduler(optimizer, cur_iter)
            # Move the inputs and labels to the GPU.
            inputs = sample['image'].cuda()
            target = sample['label'].cuda()
            # Forward pass through the current model.
            outputs = model(inputs)
            # Compute the loss; drop into the debugger if it diverges.
            loss = criterion(outputs, target)
            if np.isnan(loss.item()) or np.isinf(loss.item()):
                pdb.set_trace()
            # Update the running loss for this epoch.
            losses.update(loss.item(), args.batch_size)

            # Backpropagate.
            loss.backward()
            # Take one optimizer step.
            optimizer.step()
            # Clear the accumulated gradients before the next iteration.
            optimizer.zero_grad()
            # Log the current iteration.
            print('epoch: {0}\t''iter: {1}/{2}\t''lr: {3:.6f}\t''loss: {loss.val:.4f} ({loss.ema:.4f})'.format(
                epoch + 1, i + 1, len(dataset_loader), scheduler.get_lr(optimizer), loss=losses))
        
        if epoch < args.epochs - 50:
            # Until the last 50 epochs, only save a checkpoint
            # every 50 epochs.
            if epoch % 50 == 0:
                torch.save({
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, model_fname % (epoch + 1))
        else:
            # During the final 50 epochs, save a checkpoint after every epoch.
            torch.save({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, model_fname % (epoch + 1))

        print('reset local total loss!')
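
build_criterion itself is not shown in this post. As a reference, a minimal OHEM cross-entropy loss consistent with the thresh / n_min arguments prepared above could look like the following sketch (an assumed implementation in the common OHEM style, not necessarily the repo's exact code):

import torch
import torch.nn as nn
import torch.nn.functional as F

class OhemCELoss(nn.Module):
    """Online hard example mining: average the loss over the hardest pixels only."""
    def __init__(self, thresh, n_min, ignore_index=255):
        super(OhemCELoss, self).__init__()
        # Convert the probability threshold into a loss threshold: -log(p).
        self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float))
        self.n_min = n_min
        self.ignore_index = ignore_index

    def forward(self, logits, labels):
        # Per-pixel cross entropy, flattened and sorted hardest first.
        loss = F.cross_entropy(logits, labels, reduction='none',
                               ignore_index=self.ignore_index).view(-1)
        loss, _ = torch.sort(loss, descending=True)
        # Keep every pixel whose loss exceeds the threshold,
        # but never fewer than n_min pixels.
        if loss[self.n_min] > self.thresh:
            loss = loss[loss > self.thresh]
        else:
            loss = loss[:self.n_min]
        return torch.mean(loss)

Averaging only the hardest pixels keeps the gradient focused on difficult regions, which matters for class-imbalanced street scenes like Cityscapes.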

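Similarly, AverageMeter and Iter_LR_Scheduler come from the repo's utilities. Minimal sketches that match how they are called above (the .val / .ema fields in the log line, and the scheduler(optimizer, cur_iter) / scheduler.get_lr(optimizer) calls) might be the following hypothetical, simplified versions:

class AverageMeter(object):
    """Tracks the latest value and an exponential moving average."""
    def __init__(self, alpha=0.98):
        self.alpha = alpha
        self.val = 0.0
        self.ema = 0.0
        self.started = False

    def update(self, val, n=1):
        self.val = val
        # Initialize the EMA with the first value, then decay it.
        self.ema = val if not self.started else self.alpha * self.ema + (1.0 - self.alpha) * val
        self.started = True


class PolyLR(object):
    """Poly schedule stepped per iteration: lr = base_lr * (1 - it / max_it) ** power."""
    def __init__(self, base_lr, max_iteration, power=0.9):
        self.base_lr = base_lr
        self.max_iteration = max_iteration
        self.power = power

    def __call__(self, optimizer, cur_iter):
        lr = self.base_lr * (1.0 - float(cur_iter) / self.max_iteration) ** self.power
        for group in optimizer.param_groups:
            group['lr'] = lr

    def get_lr(self, optimizer):
        return optimizer.param_groups[0]['lr']
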
The Retrain_Autodeeplab class

class Retrain_Autodeeplab(nn.Module):
    def __init__(self, args):
        super(Retrain_Autodeeplab, self).__init__()
        # Per-level filter-width multipliers and the BatchNorm flavor.
        filter_param_dict = {0: 1, 1: 2, 2: 4, 3: 8}
        BatchNorm2d = ABN if args.use_ABN else NaiveBN
        if (not args.dist and args.use_ABN) or (args.dist and args.use_ABN and dist.get_rank() == 0):
            print("=> use ABN!")
        
        # Choose the overall network structure (network_arch), the cell structure
        # (cell_arch) and the per-layer path of the network (network_path).
        if args.net_arch is not None and args.cell_arch is not None:
            network_arch, cell_arch = np.load(args.net_arch), np.load(args.cell_arch)
            # Recover the per-layer level path from the one-hot architecture tensor.
            network_path = np.argmax(np.sum(network_arch, axis=2), axis=1)
        else:
            network_arch, cell_arch, network_path = get_default_arch()
        # Build the encoder.
        self.encoder = newModel(network_arch, cell_arch, args.num_classes, 12, args.filter_multiplier, BatchNorm=BatchNorm2d, args=args)
        # Build the final ASPP head; its input width depends on the level of the last layer.
        self.aspp = ASPP(args.filter_multiplier * args.block_multiplier * filter_param_dict[network_path[-1]],
                         256, args.num_classes, conv=nn.Conv2d, norm=BatchNorm2d)
        # Build the decoder.
        self.decoder = Decoder(args.num_classes, filter_multiplier=args.filter_multiplier * args.block_multiplier,
                               args=args, last_level=network_path[-1])

    def forward(self, x):
        encoder_output, low_level_feature = self.encoder(x)
        high_level_feature = self.aspp(encoder_output)
        decoder_output = self.decoder(high_level_feature, low_level_feature)
        # Bilinearly upsample the decoder output back to the input resolution.
        return nn.Upsample((x.shape[2], x.shape[3]), mode='bilinear', align_corners=True)(decoder_output)

    def get_params(self):
        back_bn_params, back_no_bn_params = self.encoder.get_params()
        tune_wd_params = list(self.aspp.parameters()) \
                         + list(self.decoder.parameters()) \
                         + back_no_bn_params
        no_tune_wd_params = back_bn_params
        return tune_wd_params, no_tune_wd_params
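
get_params is not used by the training script above (main() passes model.module.parameters() directly to SGD), but it enables per-group weight decay: BN parameters are typically trained without it. A hypothetical optimizer setup using the two groups:

# Hypothetical: apply weight decay only to the non-BN parameters.
tune_wd_params, no_tune_wd_params = model.get_params()
optimizer = optim.SGD(
    [{'params': tune_wd_params, 'weight_decay': 0.0001},
     {'params': no_tune_wd_params, 'weight_decay': 0.0}],
    lr=args.base_lr, momentum=0.9)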

new_model.py: the newModel class

class newModel(nn.Module):
    def __init__(self, network_arch, cell_arch, num_classes, num_layers, filter_multiplier=20, lock_multiplier=5, step=5, cell=Cell,
                 BatchNorm=NaiveBN, args=None):
        super(newModel, self).__init__()
        # Store the configuration.
        self.args = args
        self._step = step
        self.cells = nn.ModuleList()
        self.network_arch = torch.from_numpy(network_arch)
        self.cell_arch = torch.from_numpy(cell_arch)
        self._num_layers = num_layers
        self._num_classes = num_classes
        self._block_multiplier = args.block_multiplier
        self._filter_multiplier = args.filter_multiplier
        self.use_ABN = args.use_ABN
        initial_fm = 128 if args.initial_fm is None else args.initial_fm
        half_initial_fm = initial_fm // 2
        # Stem 0: 3x3 convolution, stride 2.
        self.stem0 = nn.Sequential(
            nn.Conv2d(3, half_initial_fm, 3, stride=2, padding=1),
            BatchNorm(half_initial_fm)
        )
        # Stem 1: 3x3 convolution, stride 1.
        self.stem1 = nn.Sequential(
            nn.Conv2d(half_initial_fm, half_initial_fm, 3, padding=1),
            BatchNorm(half_initial_fm)
        )
        # Stem 2: 3x3 convolution, stride 2.
        ini_initial_fm = half_initial_fm
        self.stem2 = nn.Sequential(
            nn.Conv2d(half_initial_fm, initial_fm, 3, stride=2, padding=1),
            BatchNorm(initial_fm)
        )
        
        filter_param_dict = {0: 1, 1: 2, 2: 4, 3: 8}
        # Iterate over all layers.
        for i in range(self._num_layers):
            # network_arch has shape 12 x 4 x 3; network_arch[i] is the 4 x 3 structure of layer i.
            # Summing over dim=1 (the branch axis) gives a length-4 vector that is 1 only at the selected level.
            level_option = torch.sum(self.network_arch[i], dim=1)
            prev_level_option = torch.sum(self.network_arch[i - 1], dim=1)
            prev_prev_level_option = torch.sum(self.network_arch[i - 2], dim=1)
            
            # level_option is the one-hot level vector of layer i; argmax recovers the active level.
            level = torch.argmax(level_option).item()
            prev_level = torch.argmax(prev_level_option).item()
            prev_prev_level = torch.argmax(prev_prev_level_option).item()
            
            # In essence, the operation at every layer is a single cell.
            if i == 0:
                downup_sample = - torch.argmax(torch.sum(self.network_arch[0], dim=1))
                _cell = cell(self._step, self._block_multiplier, 
                             ini_initial_fm / args.block_multiplier,
                             initial_fm / args.block_multiplier,
                             self.cell_arch, self.network_arch[i],
                             self._filter_multiplier *
                             filter_param_dict[level],
                             downup_sample, self.args)
            else:
                three_branch_options = torch.sum(self.network_arch[i], dim=0)
                downup_sample = torch.argmax(three_branch_options).item() - 1
                if i == 1:
                    _cell = cell(self._step, self._block_multiplier,
                                 initial_fm / args.block_multiplier,
                                 self._filter_multiplier * filter_param_dict[prev_level],
                                 self.cell_arch, self.network_arch[i],
                                 self._filter_multiplier *
                                 filter_param_dict[level],
                                 downup_sample, self.args)
                else:
                    _cell = cell(self._step, self._block_multiplier,
                                 self._filter_multiplier * filter_param_dict[prev_prev_level],
                                 self._filter_multiplier * filter_param_dict[prev_level],
                                 self.cell_arch, self.network_arch[i],
                                 self._filter_multiplier *
                                 filter_param_dict[level], downup_sample, self.args)

            self.cells += [_cell]
    
    # Forward pass.
    def forward(self, x):
        stem = self.stem0(x)
        stem0 = self.stem1(stem)
        stem1 = self.stem2(stem0)
        # Keep the two most recent outputs as the inputs of the next cell.
        two_last_inputs = (stem0, stem1)
        for i in range(self._num_layers):
            two_last_inputs = self.cells[i](two_last_inputs[0], two_last_inputs[1])
            if i == 2:
                low_level_feature = two_last_inputs[1]
        last_output = two_last_inputs[-1]
        # Return the high-level output together with the low-level feature.
        return last_output, low_level_feature
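
To make the argmax decoding above concrete, here is a small standalone check on a toy one-hot architecture tensor (the values below are made up for illustration):

import numpy as np
import torch

# network_arch is layers x levels x branches (12 x 4 x 3), one-hot per layer.
arch = np.zeros((12, 4, 3), dtype=np.float32)
arch[0, 1, 2] = 1  # made-up choice for layer 0
arch = torch.from_numpy(arch)

# Summing over the branch axis (dim=1) leaves a one-hot level vector.
level = torch.argmax(torch.sum(arch[0], dim=1)).item()
print(level)  # 1

# Summing over the level axis (dim=0) isolates the chosen branch;
# newModel maps it to downup_sample = argmax - 1 (here 2 - 1 = 1).
downup_sample = torch.argmax(torch.sum(arch[0], dim=0)).item() - 1
print(downup_sample)  # 1; the sign convention is interpreted inside Cell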

decoder.py: the Decoder class

class Decoder(nn.Module):
    def __init__(self, num_classes, filter_multiplier, BatchNorm=NaiveBN, args=None, last_level=0):
        super(Decoder, self).__init__()
        low_level_inplanes = filter_multiplier
        C_low = 48
        # 1x1 convolution + BN projecting the low-level feature down to 48 channels.
        self.conv1 = nn.Conv2d(low_level_inplanes, C_low, 1, bias=False)
        self.bn1 = BatchNorm(C_low)
        # Final convolution stack; its 304 input channels are 256 (ASPP) + 48 (projected low-level).
        self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       BatchNorm(256),
                                       nn.Dropout(0.5),
                                       nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       BatchNorm(256),
                                       nn.Dropout(0.1),
                                       nn.Conv2d(256, num_classes, kernel_size=1, stride=1))
        self._init_weight()

    def forward(self, x, low_level_feat):
        # Project the low-level feature so the channel counts line up.
        low_level_feat = self.conv1(low_level_feat)
        low_level_feat = self.bn1(low_level_feat)
        # Upsample the high-level feature to the low-level spatial size.
        x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True)
        # Concatenate the low-level and high-level features.
        x = torch.cat((x, low_level_feat), dim=1)
        # Apply the final convolution stack.
        x = self.last_conv(x)
        return x
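
A quick shape check for the decoder: the 304-channel input of last_conv is the concatenation of the 256-channel ASPP output and the 48-channel projected low-level feature. A hypothetical smoke test (using nn.BatchNorm2d in place of NaiveBN, a made-up filter_multiplier, and assuming the _init_weight helper elided from this excerpt is defined):

import torch
import torch.nn as nn

decoder = Decoder(num_classes=19, filter_multiplier=40, BatchNorm=nn.BatchNorm2d)
x = torch.randn(2, 256, 32, 32)      # high-level feature from ASPP
low = torch.randn(2, 40, 128, 128)   # low-level feature from the encoder
out = decoder(x, low)
print(out.shape)                     # torch.Size([2, 19, 128, 128])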