PDARTS (Progressive Differentiable Architecture Search) Code Review

I have been reading up on DARTS recently and came across the PDARTS paper, so I went through its code. These notes are both a record for myself and something I hope is useful to others; corrections are welcome.

PDARTS builds directly on the earlier DARTS work: it addresses several problems in DARTS and achieves better experimental results.

Problem: the DARTS architecture is obtained by searching on simple datasets (e.g. CIFAR10), and it generalizes poorly to more complex datasets.

PDARTS solution: progressively (the P) increase the depth of the network being searched, which alleviates the depth gap to a good extent: 5 cells -> 11 cells -> 17 cells.

PDARTS paper: Progressive Differentiable Architecture Search: Bridging the Depth Gap between Search and Evaluation
Code: https://github.com/chenxin061/pdarts

Huawei certainly moves fast: the same year DARTS was published at ICLR 2019, Huawei got PDARTS accepted at ICCV 2019. With that said, on to the code review.

Coding tips

- copy.copy & copy.deepcopy: shallow copy vs. deep copy
- next(): returns the next item of an iterator
- sorted(range(n), key=lambda x: tbsn[x]): returns the indices that sort tbsn in ascending order
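A quick standalone illustration of the three tips:

```python
import copy

# copy.copy vs copy.deepcopy: a shallow copy shares nested objects
a = [[1, 2], [3, 4]]
shallow, deep = copy.copy(a), copy.deepcopy(a)
a[0][0] = 99
print(shallow[0][0], deep[0][0])   # 99 1

# next() returns the next item of an iterator
it = iter([10, 20, 30])
print(next(it))                    # 10

# sorted over indices: the indices that order tbsn ascending
tbsn = [0.3, 0.1, 0.5]
print(sorted(range(len(tbsn)), key=lambda x: tbsn[x]))   # [1, 0, 2]
```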

1. train_search.py

1.1 Prepare dataset

| Variable | Value |
| --- | --- |
| dataset | CIFAR10 or CIFAR100 |
| args.train_portion | fraction of the training data used to train the weights (see the sketch below) |
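train_search.py splits the official CIFAR training set in two according to args.train_portion: the first part trains the weights w, the rest serves as the validation set that updates alpha. A sketch following the DARTS convention (`args` and `train_transform` come from the surrounding script):

```python
import numpy as np
import torch
import torchvision.datasets as dset

train_data = dset.CIFAR10(root=args.data, train=True,
                          download=True, transform=train_transform)
num_train = len(train_data)
indices = list(range(num_train))
split = int(np.floor(args.train_portion * num_train))   # e.g. 0.5 -> 25000

# the first `split` samples update the weights w ...
train_queue = torch.utils.data.DataLoader(
    train_data, batch_size=args.batch_size, num_workers=2, pin_memory=True,
    sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]))
# ... the remaining samples act as the validation set that updates alpha
valid_queue = torch.utils.data.DataLoader(
    train_data, batch_size=args.batch_size, num_workers=2, pin_memory=True,
    sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]))
```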

1.2 Preparation

| Variable | Value |
| --- | --- |
| 14 | total number of edges in one cell |
| loss | nn.CrossEntropyLoss() |
| args.train_portion | fraction of the training data used to train the weights |
| switches | for each edge of the cell, which operations stay in the search; True/False flags (14 edges x 8 ops, initialization sketched below) |
| PRIMITIVES | the 8 basic candidate operations |
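The switches start with every operation enabled on every edge; roughly:

```python
import copy

switches = []
for i in range(14):      # 14 edges per cell
    switches.append([True for j in range(len(PRIMITIVES))])   # 8 candidate ops each
switches_normal = copy.deepcopy(switches)   # deepcopy so the two cell types
switches_reduce = copy.deepcopy(switches)   # can be pruned independently
```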

1.3 Train

1.3.1 Model

Class Network

| Variable | Value |
| --- | --- |
| C | channels of each cell in the current stage, default: 16 |
| layers | number of cells: 5 -> 11 -> 17 across the three stages |
| steps | number of nodes per cell, default: 4 |
| stem_multiplier | channel multiplier of the stem, the first processing of the input image |
| multiplier | number of node outputs concatenated into the cell output, i.e. the factor by which each cell widens its output channels |
| p | drop rate, changes with the epoch |
| weight | the alphas after softmax |
| self.stem | the network in front of C0 (the stem) |
| reduction | if True, the current cell halves the spatial size of its output |
| Cell | builds the network inside one cell |
| self._arch_parameters | self.alphas_normal & self.alphas_reduce (14 x switch_on); initialization sketched after the forward pass below |
```python
class Network(nn.Module):
    def forward(self, input):
        s0 = s1 = self.stem(input)      # s0 plays Ck-2, s1 plays Ck-1
        for i, cell in enumerate(self.cells):
            # pick the alphas of the matching cell type and softmax them into weights
            if cell.reduction:
                if self.alphas_reduce.size(1) == 1:   # only one op left per edge
                    weights = F.softmax(self.alphas_reduce, dim=0)
                else:
                    weights = F.softmax(self.alphas_reduce, dim=-1)
            else:
                if self.alphas_normal.size(1) == 1:
                    weights = F.softmax(self.alphas_normal, dim=0)
                else:
                    weights = F.softmax(self.alphas_normal, dim=-1)
            s0, s1 = s1, cell(s0, s1, weights)        # shift: this cell's output becomes Ck-1
        out = self.global_pooling(s1)
        logits = self.classifier(out.view(out.size(0), -1))
        return logits
```
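The architecture parameters in the table above hold one row per edge over the operations still enabled on that edge. A minimal sketch of their initialization, assuming a `switch_on` attribute that counts the enabled operations per edge (the repo wraps the tensors in the older `Variable` API; `nn.Parameter` is used here for brevity):

```python
def _initialize_alphas(self):
    k = sum(1 for i in range(self._steps) for n in range(2 + i))   # 14 edges
    num_ops = self.switch_on        # operations still enabled on each edge
    self.alphas_normal = nn.Parameter(1e-3 * torch.randn(k, num_ops))
    self.alphas_reduce = nn.Parameter(1e-3 * torch.randn(k, num_ops))
    self._arch_parameters = [self.alphas_normal, self.alphas_reduce]
```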
Class Cell

| Variable | Value |
| --- | --- |
| self.preprocess0 | FactorizedReduce when the previous cell was a reduction cell: Ck-2 passes through two stride-2 convolutions, each producing half of the output channels, whose results are concatenated (full class below); ReLUConvBN for the non-reduction case |
| MixedOp | the weighted mixture of all candidate operations on one edge between two nodes (sketched below) |
| switch_count | 0-13, index over all edges |
| OPS | dict mapping operation names to the actual operations |
| PRIMITIVES | list of operation names |
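MixedOp instantiates only the operations whose switch is still True and mixes their outputs with the softmaxed weights. A simplified sketch (the real class additionally applies the drop rate p to the parameter-free ops):

```python
class MixedOp(nn.Module):
    def __init__(self, C, stride, switch):
        super(MixedOp, self).__init__()
        self.m_ops = nn.ModuleList()
        for i in range(len(switch)):
            if switch[i]:                          # operation still in the search
                op = OPS[PRIMITIVES[i]](C, stride, False)
                if 'pool' in PRIMITIVES[i]:        # pooling ops get a BN on top
                    op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False))
                self.m_ops.append(op)

    def forward(self, x, weights):
        # weighted sum over all enabled candidate operations of this edge
        return sum(w * op(x) for w, op in zip(weights, self.m_ops))
```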
```python
class FactorizedReduce(nn.Module):
  def __init__(self, C_in, C_out, affine=True):
    super(FactorizedReduce, self).__init__()
    assert C_out % 2 == 0
    self.relu = nn.ReLU(inplace=False)
    self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
    self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
    self.bn = nn.BatchNorm2d(C_out, affine=affine)

  def forward(self, x):
    x = self.relu(x)
    # second path sees the input shifted by one pixel; concat the halves on channels
    out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1)
    return self.bn(out)
```
```python
class Cell(nn.Module):
    def forward(self, s0, s1, weights):
        s0 = self.preprocess0(s0)   # preprocess the outputs of the two previous cells
        s1 = self.preprocess1(s1)
        states = [s0, s1]           # will grow to s0, s1, n0, n1, n2, n3
        offset = 0
        for i in range(self._steps):
            # node i sums the results coming from all earlier states;
            # weights balance the contribution of each edge's mixed op
            s = sum(self.cell_ops[offset + j](h, weights[offset + j]) for j, h in enumerate(states))
            offset += len(states)
            states.append(s)
        # the cell output concatenates the last `multiplier` node outputs on channels
        return torch.cat(states[-self._multiplier:], dim=1)
```
```python
for i in range(layers):
    # the cells at 1/3 and 2/3 of the depth are reduction cells
    if i in [layers//3, 2*layers//3]:
        C_curr *= 2
        reduction = True
        cell = Cell(steps, multiplier, C_prev_prev, C_prev,
                    C_curr, reduction, reduction_prev, switches_reduce, self.p)
    else:
        reduction = False
        cell = Cell(steps, multiplier, C_prev_prev, C_prev,
                    C_curr, reduction, reduction_prev, switches_normal, self.p)
    # bookkeeping for the next iteration
    reduction_prev = reduction
    self.cells += [cell]
    C_prev_prev, C_prev = C_prev, multiplier * C_curr
```
1.3.2 Train process

Pipeline

- sp: which of the three stages we are in, values 0, 1, 2
- Train model: epochs default 25
  - epochs 1-10: train only the weights w on the training set; alpha is frozen
  - epochs 11-25: train w, and additionally update alpha on the validation set
  - epochs 21-25: also run validation
- Drop operations, once all epochs of the stage are finished
  - get_min_k: fetch the operations with the smallest alpha on the current edge (sketched below)
  - get_min_k_no_zero: in the last stage, remove the zero operation first, then drop the operations with the smallest alpha
- Last stage
  - normal_final & reduce_final: store, for each edge of the cell, the idx of the operation to select
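A minimal sketch of how get_min_k can work: repeatedly take the argmin and mask it out so it is not found again (the deepcopy keeps the caller's list intact):

```python
import copy
import numpy as np

def get_min_k(input_in, k):
    input = copy.deepcopy(input_in)
    index = []
    for i in range(k):
        idx = np.argmin(input)
        index.append(idx)
        input[idx] = 1      # mask the found minimum
    return index
```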
"""
Generate Architecture
选择权重最大的两个前驱节点,再从这两个节点里确定应该选哪个edge
"""
for i in range(3):
    end = start + n     					# 每个node的前驱edge,n0:0,1;  n2:2,3,4;...
    tbsn = normal_final[start:end]          # n*num_ops,normal_final和reduce_final都是排过序的
    tbsr = reduce_final[start:end]
    edge_n = sorted(range(n), key=lambda x: tbsn[x])
    keep_normal.append(edge_n[-1] + start)
    keep_normal.append(edge_n[-2] + start)
    edge_r = sorted(range(n), key=lambda x: tbsr[x])
    keep_reduce.append(edge_r[-1] + start)
    keep_reduce.append(edge_r[-2] + start)
    start = end
    n = n + 1
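Once the kept edges are known, every operation on the discarded edges is switched off, so that parse_network below only emits the surviving (edge, operation) pairs; roughly:

```python
for i in range(14):
    if i not in keep_normal:
        for j in range(len(PRIMITIVES)):
            switches_normal[i][j] = False
    if i not in keep_reduce:
        for j in range(len(PRIMITIVES)):
            switches_reduce[i][j] = False
```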
"""
translate switches into genotype
"""
def parse_network(switches_normal, switches_reduce):

    def _parse_switches(switches):
        n = 2
        start = 0
        gene = []
        step = 4
        for i in range(step):
            end = start + n
            for j in range(start, end):
                for k in range(len(switches[j])):
                    if switches[j][k]:
                        gene.append((PRIMITIVES[k], j - start))
            start = end
            n = n + 1
        return gene
    gene_normal = _parse_switches(switches_normal)
    gene_reduce = _parse_switches(switches_reduce)
    
    concat = range(2, 6)
    
    genotype = Genotype(
        normal=gene_normal, normal_concat=concat, 
        reduce=gene_reduce, reduce_concat=concat
    )
    
    return genotype		# 把(操作,前驱节点序号)放到list gene中,[('sep_conv_3x3', 1),...,]
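Genotype here is the namedtuple defined in genotypes.py, the same structure DARTS uses:

```python
from collections import namedtuple

# a discovered architecture: (op, predecessor) pairs for both cell types
# plus the node outputs to concatenate
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
```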

"""
E.G.
"""
DARTS_V1 = Genotype(normal=[('sep_conv_3x3', 1), # node0的前驱node以及相关operation
                            ('sep_conv_3x3', 0), # node0的前驱node以及相关operation
                    		('skip_connect', 0), # node1的前驱node以及相关operation
                   	 		('sep_conv_3x3', 1), # node1的前驱node以及相关operation
                    		('skip_connect', 0), # node2的前驱node以及相关operation
                    		('sep_conv_3x3', 1), # node2的前驱node以及相关operation
                    		('sep_conv_3x3', 0), # node3的前驱node以及相关operation
                    		('skip_connect', 2)],# node3的前驱node以及相关operation
                    		normal_concat=[2, 3, 4, 5], # 最终将结果concat在一起
                    
              		reduce=[('max_pool_3x3', 0), 
                    		('max_pool_3x3', 1), 
                    		('skip_connect', 2), 
                    		('max_pool_3x3', 0), 
                    		('max_pool_3x3', 0), 
                    		('skip_connect', 2), 
                    		('skip_connect', 2), 
                    		('avg_pool_3x3', 0)], reduce_concat=[2, 3, 4, 5])

[Figure: visualization of the genotype; states are numbered 0: Ck-2, 1: Ck-1, 2: node0, 3: node1, 4: node2, 5: node3]

