pt.darts源码分析

解释博客:

Intuitive Explanation of Differentiable Architecture Search (DARTS)

SearchCNN 构建

SearchCNNController

models.search_cnn.SearchCNNController

n_ops = 8, n_nodes = 4

之所有有个i + 2是因为include 2 input nodes, 一个是上一个输入,一个是跨层连接的上上个输入。

for i in range(n_nodes):
    self.alpha_normal.append(nn.Parameter(1e-3*torch.randn
    (i+2, n_ops)))
    self.alpha_reduce.append(nn.Parameter(1e-3*torch.randn
    (i+2, n_ops)))
Variable Value
C_in 3
C 16
n_classes 10
n_layers 8
n_nodes 4
stem_multiplier 3

SearchCNNController构造函数的最后:

self.net = SearchCNN(C_in, C, n_classes, n_layers, n_nodes, stem_multiplier)

SearchCNN

每个cell有两个输入,一个是上一个输入,一个是跨层连接的上上个输入。

刚开始c_pp = c_p =48c_cur = 16

通过跨层连接将cells组织起来:

在这里插入图片描述

SearchCell中nodes之间DAG的构建

models.search_cells.SearchCell#__init__

op = ops.MixedOp(C, stride)
self.dag[i].append(op)

看到MixedOp内部

self._ops = nn.ModuleList()
for primitive in gt.PRIMITIVES:
    op = OPS[primitive](C, stride, affine=False)
    self._ops.append(op)

# genotypes.py
PRIMITIVES = [
    'max_pool_3x3',
    'avg_pool_3x3',
    'skip_connect', # identity
    'sep_conv_3x3',
    'sep_conv_5x5',
    'dil_conv_3x3',
    'dil_conv_5x5',
    'none'
]

models.search_cells.SearchCell#forward

在这里插入图片描述

models/search_cells.py:53

s.shape
Out[2]: torch.Size([
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值