PDARTS 网络结构搜索程序分析

PDARTS 即 Progressive Differentiable Architecture Search: Bridging the Depth Gap between Search and Evaluation,是对 DARTS 的改进。DARTS 内存占用过高,训练不了较大的模型;PDARTS 将训练划分为3个阶段,逐步搜索,在增加网络深度的同时缩减操作种类。构造3次网络拉长了训练周期,过程如下图所示:

pipeline2
此外,算法还对筛选细节进行了控制。chenxin061/pdarts 修改自 quark0/darts,主函数逻辑稍显复杂。

train_search.py

    start_time = time.time()
    main() 
    end_time = time.time()
    duration = end_time - start_time
    logging.info('Total searching time: %ds', duration)

main()

Created with Raphaël 2.3.0 main args utils._data_transforms_cifar100 torchvision.datasets.CIFAR100 torch.utils.data.DataLoader torch.nn.CrossEntropyLoss Network optim.lr_scheduler.CosineAnnealingLR optim.Optimizer.step optim.lr_scheduler.CosineAnnealingLR.get_lr Network.update_p train infer utils.save Network.arch_parameters torch.nn.functional.softmax last stage? get_min_k_no_zero logging_switches parse_network check_sk_number delete_min_sk_prob keep_1_on keep_2_branches End get_min_k yes no
    if not torch.cuda.is_available():
        logging.info('No GPU device available')
        sys.exit(1)
    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled=True
    torch.cuda.manual_seed(args.seed)
    logging.info('GPU device = %d' % args.gpu)
    logging.info("args = %s", args)

没有将阶段内的处理封装为函数,流程不太直观。

_data_transforms_cifar100 包括随机截取、翻转、标准化和随机裁剪。
CIFAR100CIFAR10 的子类。
torch.utils.data.sampler.SubsetRandomSampler 从给定的索引列表中随机抽取元素样本,不替换。

    #  prepare dataset
    if args.cifar100:
        train_transform, valid_transform = utils._data_transforms_cifar100(args)
    else:
        train_transform, valid_transform = utils._data_transforms_cifar10(args)
    if args.cifar100:
        train_data = dset.CIFAR100(root=args.tmp_data_dir, train=True, download=True, transform=train_transform)
    else:
        train_data = dset.CIFAR10(root=args.tmp_data_dir, train=True, download=True, transform=train_transform)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True, num_workers=args.workers)

    valid_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
        pin_memory=True, num_workers=args.workers)

PRIMITIVES 定义了网络可用的原语,共8种。经3轮丢弃num_to_drop后,操作位置上剩1种或无操作。
switches_normalswitches_reduce为操作名称列表。单元内的连接数量为14。

    # build Network
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    switches = []
    for i in range(14):
        switches.append([True for j in range(len(PRIMITIVES))])
    switches_normal = copy.deepcopy(switches)
    switches_reduce = copy.deepcopy(switches)
    # To be moved to args
    num_to_keep = [5, 3, 1]
    num_to_drop = [3, 2, 2]
    if len(args.add_width) == 3:
        add_width = args.add_width
    else:
        add_width = [0, 0, 0]
    if len(args.add_layers) == 3:
        add_layers = args.add_layers
    else:
        add_layers = [0, 6, 12]
    if len(args.dropout_rate) ==3:
        drop_rate = args.dropout_rate
    else:
        drop_rate = [0.0, 0.0, 0.0]
    eps_no_archs = [10, 10, 10]

依次构建每个阶段的网络进行训练。sp即 search phase。
P-DARTS 网络深度为5->11->17,DARTS 为7。
Network 构建网络。
count_parameters_in_MB 统计模型大小。
train 传入两种优化器,搜索结构用 Adam,训练模型用 SGD。
最后5个 epoch 调用 infer 在验证集上测试模型。

    for sp in range(len(num_to_keep)):
        model = Network(args.init_channels + int(add_width[sp]), CIFAR_CLASSES, args.layers + int(add_layers[sp]), criterion, switches_normal=switches_normal, switches_reduce=switches_reduce, p=float(drop_rate[sp]))
        
        model = model.cuda()
        logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
        network_params = []
        for k, v in model.named_parameters():
            if not (k.endswith('alphas_normal') or k.endswith('alphas_reduce')):
                network_params.append(v)       
        optimizer = torch.optim.SGD(
                network_params,
                args.learning_rate,
                momentum=args.momentum,
                weight_decay=args.weight_decay)
        optimizer_a = torch.optim.Adam(model.arch_parameters(),
                    lr=args.arch_learning_rate, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, float(args.epochs), eta_min=args.learning_rate_min)
        sm_dim = -1
        epochs = args.epochs
        eps_no_arch = eps_no_archs[sp]
        scale_factor = 0.2
        for epoch in range(epochs):
            scheduler.step()
            lr = scheduler.get_lr()[0]
            logging.info('Epoch: %d lr: %e', epoch, lr)
            epoch_start = time.time()
            # training
            if epoch < eps_no_arch:
                model.p = float(drop_rate[sp]) * (epochs - epoch - 1) / epochs
                model.update_p()
                train_acc, train_obj = train(train_queue, valid_queue, model, network_params, criterion, optimizer, optimizer_a, lr, train_arch=False)
            else:
                model.p = float(drop_rate[sp]) * np.exp(-(epoch - eps_no_arch) * scale_factor) 
                model.update_p()                
                train_acc, train_obj = train(train_queue, valid_queue, model, network_params, criterion, optimizer, optimizer_a, lr, train_arch=True)
            logging.info('Train_acc %f', train_acc)
            epoch_duration = time.time() - epoch_start
            logging.info('Epoch time: %ds', epoch_duration)
            # validation
            if epochs - epoch < 5:
                valid_acc, valid_obj = infer(valid_queue, model, criterion)
                logging.info('Valid_acc %f', valid_acc)

utils.save 保存阶段训练的结果。问题是名字一样会覆盖。
switches_normal_2switches_reduce_2为第2阶段处理前的操作列表。

        utils.save(model, os.path.join(args.save, 'weights.pt'))
        print('------Dropping %d paths------' % num_to_drop[sp])
        # Save switches info for s-c refinement. 
        if sp == len(num_to_keep) - 1:
            switches_normal_2 = copy.deepcopy(switches_normal)
            switches_reduce_2 = copy.deepcopy(switches_reduce)

arch_parameters 返回 ( α n o r m a l , α r e d u c e ) (\alpha_{normal}, \alpha_{reduce}) (αnormal,αreduce)
计算normal_prob
e x p ( α o ( i , j ) ) ∑ o ′ ∈ O e x p ( α o ′ ( i , j ) ) \begin{aligned} \frac{\mathrm{exp}(\alpha_o^{(i,j)})}{\sum_{o'\in\mathcal{O}}\mathrm{exp}(\alpha_{o'}^{(i,j)})} \end{aligned} oOexp(αo(i,j))exp(αo(i,j))
idxs记录处于活跃状态的操作符的类型索引。
get_min_k 返回最小的num_to_drop[sp]个索引。
get_min_k_no_zero 先检查idxs是否有0。

在最后一个阶段丢弃所有空操作,否则丢弃指定数量的小权重操作。

        # drop operations with low architecture weights
        arch_param = model.arch_parameters()
        normal_prob = F.softmax(arch_param[0], dim=sm_dim).data.cpu().numpy()        
        for i in range(14):
            idxs = []
            for j in range(len(PRIMITIVES)):
                if switches_normal[i][j]:
                    idxs.append(j)
            if sp == len(num_to_keep) - 1:
                # for the last stage, drop all Zero operations
                drop = get_min_k_no_zero(normal_prob[i, :], idxs, num_to_drop[sp])
            else:
                drop = get_min_k(normal_prob[i, :], num_to_drop[sp])
            for idx in drop:
                switches_normal[i][idxs[idx]] = False

缩减单元的处理与之相同。

        reduce_prob = F.softmax(arch_param[1], dim=-1).data.cpu().numpy()
        for i in range(14):
            idxs = []
            for j in range(len(PRIMITIVES)):
                if switches_reduce[i][j]:
                    idxs.append(j)
            if sp == len(num_to_keep) - 1:
                drop = get_min_k_no_zero(reduce_prob[i, :], idxs, num_to_drop[sp])
            else:
                drop = get_min_k(reduce_prob[i, :], num_to_drop[sp])
            for idx in drop:
                switches_reduce[i][idxs[idx]] = False
        logging.info('switches_normal = %s', switches_normal)
        logging_switches(switches_normal)
        logging.info('switches_reduce = %s', switches_reduce)
        logging_switches(switches_reduce)

在阶段的末尾,读取结构参数。
normal_finalreduce_final记录每个单元中非空操作选中的最大概率。

        if sp == len(num_to_keep) - 1:
            arch_param = model.arch_parameters()
            normal_prob = F.softmax(arch_param[0], dim=sm_dim).data.cpu().numpy()
            reduce_prob = F.softmax(arch_param[1], dim=sm_dim).data.cpu().numpy()
            normal_final = [0 for idx in range(14)]
            reduce_final = [0 for idx in range(14)]
            # remove all Zero operations
            for i in range(14):
                if switches_normal_2[i][0] == True:
                    normal_prob[i][0] = 0
                normal_final[i] = max(normal_prob[i])
                if switches_reduce_2[i][0] == True:
                    reduce_prob[i][0] = 0
                reduce_final[i] = max(reduce_prob[i])       

单元中的第1层为两个操作,start = 2跳过。2-4,5-8,9-13。
tbsntbsr为标准和缩减单元当前层供选择的位置。根据操作概率的大小排序。keep_normalkeep_reduce记录需要保持的连接的索引。
过滤得到最终的switches_normalswitches_reduce,每层两个操作。

            # Generate Architecture, similar to DARTS
            keep_normal = [0, 1]
            keep_reduce = [0, 1]
            n = 3
            start = 2
            for i in range(3):
                end = start + n
                tbsn = normal_final[start:end]
                tbsr = reduce_final[start:end]
                edge_n = sorted(range(n), key=lambda x: tbsn[x])
                keep_normal.append(edge_n[-1] + start)
                keep_normal.append(edge_n[-2] + start)
                edge_r = sorted(range(n), key=lambda x: tbsr[x])
                keep_reduce.append(edge_r[-1] + start)
                keep_reduce.append(edge_r[-2] + start)
                start = end
                n = n + 1
            # set switches according the ranking of arch parameters
            for i in range(14):
                if not i in keep_normal:
                    for j in range(len(PRIMITIVES)):
                        switches_normal[i][j] = False
                if not i in keep_reduce:
                    for j in range(len(PRIMITIVES)):
                        switches_reduce[i][j] = False

parse_network 根据编码列表解析得到网络基因型。
check_sk_number 检查网络标准单元中skip_connect的数量,对应 PRIMITIVES 的索引3。
delete_min_sk_prob 删除最小权重的跳跃连接。
keep_1_on 丢2留一。
keep_2_branches 修剪连接,每层仅保留两个。

逐渐减少网络标准单元中skip_connect的数量并记录。


            # translate switches into genotype
            genotype = parse_network(switches_normal, switches_reduce)
            logging.info(genotype)
            ## restrict skipconnect (normal cell only)
            logging.info('Restricting skipconnect...')
            # generating genotypes with different numbers of skip-connect operations
            for sks in range(0, 9):
                max_sk = 8 - sks                
                num_sk = check_sk_number(switches_normal)               
                if not num_sk > max_sk:
                    continue
                while num_sk > max_sk:
                    normal_prob = delete_min_sk_prob(switches_normal, switches_normal_2, normal_prob)
                    switches_normal = keep_1_on(switches_normal_2, normal_prob)
                    switches_normal = keep_2_branches(switches_normal, normal_prob)
                    num_sk = check_sk_number(switches_normal)
                logging.info('Number of skip-connect: %d', max_sk)
                genotype = parse_network(switches_normal, switches_reduce)
                logging.info(genotype)              

train

初始化3个指标。

    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()

如果训练结构,从valid_queue中取数据,先行训练。

    for step, (input, target) in enumerate(train_queue):
        model.train()
        n = input.size(0)
        input = input.cuda()
        target = target.cuda(non_blocking=True)
        if train_arch:
            # In the original implementation of DARTS, it is input_search, target_search = next(iter(valid_queue), which slows down
            # the training when using PyTorch 0.4 and above. 
            try:
                input_search, target_search = next(valid_queue_iter)
            except:
                valid_queue_iter = iter(valid_queue)
                input_search, target_search = next(valid_queue_iter)
            input_search = input_search.cuda()
            target_search = target_search.cuda(non_blocking=True)
            optimizer_a.zero_grad()
            logits = model(input_search)
            loss_a = criterion(logits, target_search)
            loss_a.backward()
            nn.utils.clip_grad_norm_(model.arch_parameters(), args.grad_clip)
            optimizer_a.step()

在训练集上训练权重。

        optimizer.zero_grad()
        logits = model(input)
        loss = criterion(logits, target)

        loss.backward()
        nn.utils.clip_grad_norm_(network_params, args.grad_clip)
        optimizer.step()

调用 utils.accuracy 计算训练集上的准确率。

        prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
        objs.update(loss.data.item(), n)
        top1.update(prec1.data.item(), n)
        top5.update(prec5.data.item(), n)

        if step % args.report_freq == 0:
            logging.info('TRAIN Step: %03d Objs: %e R1: %f R5: %f', step, objs.avg, top1.avg, top5.avg)

    return top1.avg, objs.avg

infer

Created with Raphaël 2.3.0 infer valid_queue nn.Module.eval Network nn.CrossEntropyLoss utils.accuracy objs, top1, top5 End
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()
    model.eval()

    for step, (input, target) in enumerate(valid_queue):
        input = input.cuda()
        target = target.cuda(non_blocking=True)
        with torch.no_grad():
            logits = model(input)
            loss = criterion(logits, target)

        prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
        n = input.size(0)
        objs.update(loss.data.item(), n)
        top1.update(prec1.data.item(), n)
        top5.update(prec5.data.item(), n)

        if step % args.report_freq == 0:
            logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

    return top1.avg, objs.avg

_data_transforms_cifar10

相比原有变换多了 Cutout

  CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
  CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

  train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
  ])
  if args.cutout:
    train_transform.transforms.append(Cutout(args.cutout_length))

  valid_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])
  return train_transform, valid_transform

Cutout

    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)
        y = np.random.randint(h)
        x = np.random.randint(w)

        y1 = np.clip(y - self.length // 2, 0, h)
        y2 = np.clip(y + self.length // 2, 0, h)
        x1 = np.clip(x - self.length // 2, 0, w)
        x2 = np.clip(x + self.length // 2, 0, w)

        mask[y1: y2, x1: x2] = 0.
        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img *= mask
        return img

count_parameters_in_MB

  return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name)/1e6

parse_network

嵌套定义函数_parse_switches。解析两种类型的单元,记录操作类型和所在层次,得到 Genotype 类型的元组。

    def _parse_switches(switches):
        n = 2
        start = 0
        gene = []
        step = 4
        for i in range(step):
            end = start + n
            for j in range(start, end):
                for k in range(len(switches[j])):
                    if switches[j][k]:
                        gene.append((PRIMITIVES[k], j - start))
            start = end
            n = n + 1
        return gene
    gene_normal = _parse_switches(switches_normal)
    gene_reduce = _parse_switches(switches_reduce)
    
    concat = range(2, 6)
    
    genotype = Genotype(
        normal=gene_normal, normal_concat=concat, 
        reduce=gene_reduce, reduce_concat=concat
    )
    
    return genotype

Network

C为通道数量,layers为层数,steps为内部所划分的层次,multiplier为输出通道的乘数,stem_multiplier为柄通道乘数。
switch_ons记录每个操作位置可选操作的数量。self.switch_on直接取第一个位置的操作数。

    def __init__(self, C, num_classes, layers, criterion, steps=4, multiplier=4, stem_multiplier=3, switches_normal=[], switches_reduce=[], p=0.0):
        super(Network, self).__init__()
        self._C = C
        self._num_classes = num_classes
        self._layers = layers
        self._criterion = criterion
        self._steps = steps
        self._multiplier = multiplier
        self.p = p
        self.switches_normal = switches_normal
        switch_ons = []
        for i in range(len(switches_normal)):
            ons = 0
            for j in range(len(switches_normal[i])):
                if switches_normal[i][j]:
                    ons = ons + 1
            switch_ons.append(ons)
            ons = 0
        self.switch_on = switch_ons[0]

网络起始未下采样,在1/3和2/3处插入缩减单元。

        C_curr = stem_multiplier*C
        self.stem = nn.Sequential(
            nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
            nn.BatchNorm2d(C_curr)
        )
    
        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
        self.cells = nn.ModuleList()
        reduction_prev = False
        for i in range(layers):
            if i in [layers//3, 2*layers//3]:
                C_curr *= 2
                reduction = True
                cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, switches_reduce, self.p)
            else:
                reduction = False
                cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, switches_normal, self.p)
#            cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, switches)
            reduction_prev = reduction
            self.cells += [cell]
            C_prev_prev, C_prev = C_prev, multiplier*C_curr

_initialize_alphas 初始化结构参数,类型为Variable,而不是 torch.nn.Parameter

        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, num_classes)

        self._initialize_alphas()

forward

同类型的不同单元公用结构参数。

        s0 = s1 = self.stem(input)
        for i, cell in enumerate(self.cells):
            if cell.reduction:
                if self.alphas_reduce.size(1) == 1:
                    weights = F.softmax(self.alphas_reduce, dim=0)
                else:
                    weights = F.softmax(self.alphas_reduce, dim=-1)
            else:
                if self.alphas_normal.size(1) == 1:
                    weights = F.softmax(self.alphas_normal, dim=0)
                else:
                    weights = F.softmax(self.alphas_normal, dim=-1)
            s0, s1 = s1, cell(s0, s1, weights)
        out = self.global_pooling(s1)
        logits = self.classifier(out.view(out.size(0),-1))
        return logits

update_p

update_p 给数据并行带来了麻烦。

        for cell in self.cells:
            cell.p = self.p
            cell.update_p()

_loss

函数没有用到。

        logits = self(input)
        return self._criterion(logits, target) 

_initialize_alphas

k为单元中 MixedOp 的数量,self.switch_onMixedOp 中候选操作的种类。

        k = sum(1 for i in range(self._steps) for n in range(2+i))
        num_ops = self.switch_on
        self.alphas_normal = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
        self.alphas_reduce = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
        self._arch_parameters = [
            self.alphas_normal,
            self.alphas_reduce,
        ]

arch_parameters

        return self._arch_parameters

Cell

preprocess0
MixedOp0
preprocess1
MixedOp1
add0
MixedOp2
MixedOp3
MixedOp4
add1
MixedOp5
MixedOp6
MixedOp7
MixedOp8
add2
MixedOp9
MixedOp10
MixedOp11
MixedOp12
MixedOp13
add3
concatenate

FactorizedReduce 采用位置交错的两组卷积。
NASNetAmoebaNetPNAS 一样卷积采用 ReLUConvBN

没有手动初始化权重。

steps=4,使得 Cell 中包含 2+3+4+5=14 个 MixedOp,即len(self.cell_ops)=14。每层多2个用于处理输入。

    def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, switches, p):
        super(Cell, self).__init__()
        self.reduction = reduction
        self.p = p
        if reduction_prev:
            self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)
        else:
            self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
        self._steps = steps
        self._multiplier = multiplier

        self.cell_ops = nn.ModuleList()
        switch_count = 0
        for i in range(self._steps):
            for j in range(2+i):
                stride = 2 if reduction and j < 2 else 1
                op = MixedOp(C, stride, switch=switches[switch_count], p=self.p)
                self.cell_ops.append(op)
                switch_count = switch_count + 1

update_p

        for op in self.cell_ops:
            op.p = self.p
            op.update_p()

forward

每个中间节点都基于其所有先前节点计算:

x ( j ) = ∑ i < j o ( i , j ) ( x ( i ) ) \begin{aligned} x^{(j)} = \sum_{i<j} o^{(i, j)}(x^{(i)}) \end{aligned} x(j)=i<jo(i,j)(x(i))

还包括一个特殊的 z e r o \mathit{zero} zero 操作,表示两个节点之间缺少连接。 因此,学习单元的任务减少了学习其边缘的操作。

对于每一步,累加所有操作的输出。offset不断累加意味着self.cell_ops的数量为2+3+4+5=14。

        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)
        states = [s0, s1]
        offset = 0
        for i in range(self._steps):
            s = sum(self.cell_ops[offset+j](h, weights[offset+j]) for j, h in enumerate(states))
            offset += len(states)
            states.append(s)

        return torch.cat(states[-self._multiplier:], dim=1)

MixedOp

OPS 为操作字典。affine=False设置 nn.BatchNorm2d 屏蔽可学习参数,等效于 Caffe 中的 BN 层。

DARTS 的 A.1.1 中指出由于架构在整个搜索过程中会有所不同,因此其始终使用批量特定的统计信息进行批量标准化而不是全局移动平均值。在搜索过程中禁用所有批量标准化中可学习的仿射参数,以避免重新调整候选操作的输出。然而,代码中并未设置track_running_stats=False

switch为操作的掩码,len(switch)=len(PRIMITIVES)PRIMITIVES 共有8种操作,存储到self.m_ops

    def __init__(self, C, stride, switch, p):
        super(MixedOp, self).__init__()
        self.m_ops = nn.ModuleList()
        self.p = p
        for i in range(len(switch)):
            if switch[i]:
                primitive = PRIMITIVES[i]
                op = OPS[primitive](C, stride, False)
                if 'pool' in primitive:
                    op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False))
                if isinstance(op, Identity) and p > 0:
                    op = nn.Sequential(op, nn.Dropout(self.p))
                self.m_ops.append(op)

update_p

如果第一个操作是Identity,则在后面添加操作。

        for op in self.m_ops:
            if isinstance(op, nn.Sequential):
                if isinstance(op[0], Identity):
                    op[1].p = self.p

forward

O \mathcal{O} O 为一组候选操作(例如卷积、最大合并、 z e r o \mathit{zero} zero),其中每个操作代表应用于 x ( i ) x^{(i)} x(i) 的函数 o ( ⋅ ) o(\cdot) o()

为了使搜索空间连续,DARTS 将特定操作的分类选择放宽为所有可能操作的 softmax:
o ˉ ( i , j ) ( x ) = ∑ o ∈ O exp ⁡ ( α o ( i , j ) ) ∑ o ′ ∈ O exp ⁡ ( α o ′ ( i , j ) ) o ( x ) \begin{aligned} \bar{o}^{(i,j)}(x) = \sum_{o \in \mathcal{O}} \frac{\exp(\alpha_o^{(i,j)})}{\sum_{o' \in \mathcal{O}} \exp(\alpha_{o'}^{(i,j)})} o(x) \end{aligned} oˉ(i,j)(x)=oOoOexp(αo(i,j))exp(αo(i,j))o(x)
其中一对节点 ( i , j ) (i,j) (i,j) 的操作混合权重由维数 ∣ O ∣ |\mathcal{O}| O 的向量 α ( i , j ) \alpha^{(i,j)} α(i,j) 参数化。

然后,架构搜索的任务化简为学习一组连续变量 α = { α ( i , j ) } \alpha = \big\{ \alpha^{(i,j)} \big\} α={α(i,j)}。在搜索结束时,可以通过用最可能的操作替换每个混合操作 o ˉ ( i , j ) \bar{o}^{(i,j)} oˉ(i,j) 来获得离散体系结构,即
o ( i , j ) = a r g m a x o ∈ O   α o ( i , j ) o^{(i,j)} = \mathrm{argmax}_{o \in \mathcal{O}} \, \alpha^{(i,j)}_o o(i,j)=argmaxoOαo(i,j).

        return sum(w * op(x) for w, op in zip(weights, self.m_ops))

模型中定义forward之外的函数,导致不能正常使用 torch.nn.DataParallel

delete_min_sk_prob

嵌套定义_get_sk_idx函数。如果输入的列表里没有跳跃连接则返回-1;否则返回原列表switches_bk中的跳跃连接索引。

    def _get_sk_idx(switches_in, switches_bk, k):
        if not switches_in[k][3]:
            idx = -1
        else:
            idx = 0
            for i in range(3):
                if switches_bk[k][i]:
                    idx = idx + 1
        return idx

避免修改输入,sk_prob记录每个位置上跳跃连接的权重。从中取最小的置为0。

    probs_out = copy.deepcopy(probs_in)
    sk_prob = [1.0 for i in range(len(switches_bk))]
    for i in range(len(switches_in)):
        idx = _get_sk_idx(switches_in, switches_bk, i)
        if not idx == -1:
            sk_prob[i] = probs_out[i][idx]
    d_idx = np.argmin(sk_prob)
    idx = _get_sk_idx(switches_in, switches_bk, d_idx)
    probs_out[d_idx][idx] = 0.0
    
    return probs_out

keep_1_on

keep_1_on
get_min_k_no_zero

对于每个操作位,idxs记录可选操作的索引。get_min_k_no_zero 查找操作位概率最小且非空的2个操作,丢弃掉。

    switches = copy.deepcopy(switches_in)
    for i in range(len(switches)):
        idxs = []
        for j in range(len(PRIMITIVES)):
            if switches[i][j]:
                idxs.append(j)
        drop = get_min_k_no_zero(probs[i, :], idxs, 2)
        for idx in drop:
            switches[i][idxs[idx]] = False            
    return switches

keep_2_branches

final_prob为每个操作位上操作最大概率。

    switches = copy.deepcopy(switches_in)
    final_prob = [0.0 for i in range(len(switches))]
    for i in range(len(switches)):
        final_prob[i] = max(probs[i])

第1层只有两个操作位,所以直接保留。
后续3层依次取出其最大概率,排序后取最大的两个位置。

    keep = [0, 1]
    n = 3
    start = 2
    for i in range(3):
        end = start + n
        tb = final_prob[start:end]
        edge = sorted(range(n), key=lambda x: tb[x])
        keep.append(edge[-1] + start)
        keep.append(edge[-2] + start)
        start = end
        n = n + 1

遍历位置,在switches屏蔽未选中的位置。

    for i in range(len(switches)):
        if not i in keep:
            for j in range(len(PRIMITIVES)):
                switches[i][j] = False  
    return switches  

logging_switches

    for i in range(len(switches)):
        ops = []
        for j in range(len(switches[i])):
            if switches[i][j]:
                ops.append(PRIMITIVES[j])
        logging.info(ops)

参考资料:

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值