【论文阅读笔记】darts代码和论文结合阅读


参考:https://zhuanlan.zhihu.com/p/73037439
注意:本篇都是分析的CNN部分,没有对RNN部分解读。
本篇文章主要通过代码对论文进行解读,darts就是对构成网络的cell的结构进行自动搜索,然后再将搜索到的cell 连接成一个网络。

introduction

differentiable architecture search

2.1 search space

darts如何对一个cell进行搜索的呢,我们通过下图figure 1了解darts的基本思想:
(a)这些灰色的小方块都是一个cell内的nodes,我们需要通过一些操作(如池化、卷积)把这些nodes连起来
(b)原本一个个操作都是离散的,我们为了实现可微分的搜索,也就是为了使搜索空间连续,我们将特定操作的确定的选择放宽到所有可能操作上的softmax,也就是我们给两个block之间的全部操作都赋予权重。假设我们有三个操作,我们把每个节点都通过上述方法和它所有的前驱节点相连,就得到了下图(b)
©然后我们就通过梯度下降对权重进行优化,最后对每个节点取argmax也就是哪个操作的α值最大,就选这个操作。
(d)选了最大的α后的操作后,我们就得到了(d)的路径
在这里插入图片描述
具体的在CIFAR-10定义网络结构我们可以看下图:
在这里插入图片描述
一个Network是由8个cell组成的,cell分为reduction cell 和normal cell两种,在网络的三分之一处和三分之二处是reduction cell,其它是normal cell。reduction cell共享权重 α r e d u t i o n \alpha_{redution} αredution,normal cell共享权重 α n o r m a l \alpha_{normal} αnormal
一个cell由7个nodes组成,分别是2个input nodes,4个intermediate nodes和1个output nodes。

  • input nodes:是前两层cell的输出,input node 0是cell k-2的输出,input node 1 是cell k-1的输出
  • intermediate nodes:和它所有的前驱节点相连,具体看下面的公式
    对于节点 x ( j ) x^{(j)} x(j),通过操作o和它所有的前驱节点i相连,那么如何对操作o进行continuous relaxation,具体看2.2节
    在这里插入图片描述
  • output nodes:四个中间节点intermediate nodes concat,这个concat是对通道concat的,也就是原来输入的通道是C,输出以后变成了4C

model_search.py Class Cell

具体Cell是怎么定义的我们通过代码来看

class Cell(nn.Module):

  def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev):
    super(Cell, self).__init__()
    self.reduction = reduction
    #input nodes的结构固定不变,不参与搜索
    #决定第一个input nodes的结构,取决于前一个cell是否是reduction
    if reduction_prev:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)#第一个input_nodes是cell k-2的输出,cell k-2的输出通道数为C_prev_prev,所以这里操作的输入通道数为C_prev_prev
    #第二个input nodes的结构
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)#第二个input_nodes是cell k-1的输出
    self._steps = steps # 每个cell中有4个节点的连接状态待确定
    self._multiplier = multiplier

    self._ops = nn.ModuleList() # 构建operation的modulelist
    self._bns = nn.ModuleList()
    #遍历4个intermediate nodes构建混合操作
    for i in range(self._steps):
      #遍历当前结点i的所有前驱节点
      for j in range(2+i): #对第i个节点来说,他有j个前驱节点(每个节点的input都由前两个cell的输出和当前cell的前面的节点组成)
        stride = 2 if reduction and j < 2 else 1
        op = MixedOp(C, stride) #op是构建两个节点之间的混合
        self._ops.append(op)#所有边的混合操作添加到ops,list的len为2+3+4+5=14[[],[],...,[]]


  #cell中的计算过程,前向传播时自动调用
  def forward(self, s0, s1, weights):
    s0 = self.preprocess0(s0)
    s1 = self.preprocess1(s1)

    states = [s0, s1] #当前节点的前驱节点
    offset = 0
    #遍历每个intermediate nodes,得到每个节点的output
    for i in range(self._steps):
      s = sum(self._ops[offset+j](h, weights[offset+j]) for j, h in enumerate(states))  #s为当前节点i的output,在ops找到i对应的操作,然后对i的所有前驱节点做相应的操作(调用了MixedOp的forward),然后把结果相加
      offset += len(states)
      states.append(s)#把当前节点i的output作为下一个节点的输入
      #states中为[s0,s1,b1,b2,b3,b4] b1,b2,b3,b4分别是四个intermediate output的输出
    return torch.cat(states[-self._multiplier:], dim=1)#对intermediate的output进行concat作为当前cell的输出
                                                       #dim=1是指对通道这个维度concat,所以输出的通道数变成原来的4倍

2.2 continuous relaxation and optimization

为了使搜索空间连续,我们为每个操作都赋予一个权重 α \alpha α,然后做softmax。这样搜索任务就简化为学习权重 α \alpha α
在这里插入图片描述
搜索完成后,我们通过argmax选权重最大的操作,这样就又得到了离散的结构,具体如下:
o ( i , j ) = a r g m a x o ∈ O α 0 ( i , j ) o^{(i,j)=argmax_{o∈O}\alpha_0^{(i,j)}} o(i,j)=argmaxoOα0(i,j)
argmax(f(x))是使得 f(x)取得最大值所对应的变量点x(或x的集合),
也就是哪个操作对应的alpha取值最大,就取哪个操作.

model_search.py Class MixedOp

具体操作是如何混合的我们通过代码来看

class MixedOp(nn.Module):

  def __init__(self, C, stride):
    super(MixedOp, self).__init__()
    self._ops = nn.ModuleList()
    for primitive in PRIMITIVES:  #PRIMITIVES中就是8个操作
      op = OPS[primitive](C, stride, False)#OPS中存储了各种操作的函数
      if 'pool' in primitive:
        op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False)) #给池化操作后面加一个batchnormalization
      self._ops.append(op)#把这些op都放在预先定义好的modulelist里

  def forward(self, x, weights):
    return sum(w * op(x) for w, op in zip(weights, self._ops))  #op(x)就是对输入x做一个相应的操作 w1*op1(x)+w2*op2(x)+...+w8*op8(x)
                                                                #也就是对输入x做8个操作并乘以相应的权重,把结果加起来

After relaxation, our goal is to jointly learn the architecture α and the weights w within all the mixed operations (e.g. weights of the convolution filters). Analogous to architecture search using RL or evolution where the validation set performance is treated as the reward or fitness, DARTS aims to optimize the validation loss, but using gradient descent.

在对操作relaxation之后,我们就要对 α \alpha α和w进行学习,Darts是通过梯度下降优化validation loss来学习权重的。

Denote by L t r a i n L_{train} Ltrain and L v a l L_{val} Lval the training and the validation loss, respectively. Both losses are determined not only by the architecture α, but also the weights w in the network.
The goal for architecture search is to find α ∗ α^∗ α that minimizes the validation loss L v a l ( w ∗ , α ∗ ) L_{val}(w^∗ , α^∗ ) Lval(w,α), where the weights w ∗ w^∗ w associated with the architecture are obtained by minimizing the training loss w ∗ w^∗ w = a r g m i n w L t r a i n ( w , α ∗ ) argmin_wL_{train}(w, α^∗ ) argmin

  • 94
    点赞
  • 203
    收藏
    觉得还不错? 一键收藏
  • 53
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 53
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值