Explanatory blog post:
Intuitive Explanation of Differentiable Architecture Search (DARTS)
Building SearchCNN
SearchCNNController
models.search_cnn.SearchCNNController
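n_ops = 8, n_nodes = 4.
The i + 2 is there because intermediate node i receives edges from the cell's 2 input nodes plus the i intermediate nodes before it. The two input nodes are the output of the previous cell and the output of the cell before that (a cross-layer skip connection).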
for i in range(n_nodes):
    self.alpha_normal.append(nn.Parameter(1e-3*torch.randn(i+2, n_ops)))
    self.alpha_reduce.append(nn.Parameter(1e-3*torch.randn(i+2, n_ops)))
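A quick way to see the resulting shapes is to build the same parameter list outside the class. This is a minimal standalone sketch: it assumes `alpha_normal` is an `nn.ParameterList` (the snippet above only shows `.append` being called on it) and drops `self` so it runs on its own.

```python
import torch
import torch.nn as nn

n_ops, n_nodes = 8, 4

# Assumption: alpha_normal is an nn.ParameterList, as in the controller.
# Node i gets an (i + 2) x n_ops matrix: one row per incoming edge,
# one column per candidate operation on that edge.
alpha_normal = nn.ParameterList()
for i in range(n_nodes):
    alpha_normal.append(nn.Parameter(1e-3 * torch.randn(i + 2, n_ops)))

shapes = [tuple(a.shape) for a in alpha_normal]
n_edges = sum(a.shape[0] for a in alpha_normal)
print(shapes)   # [(2, 8), (3, 8), (4, 8), (5, 8)]
print(n_edges)  # 14
```

So each cell type has 14 edges and 14 * 8 = 112 architecture weights.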
Arguments used in this walkthrough:

| Variable | Value |
|---|---|
| C_in | 3 |
| C | 16 |
| n_classes | 10 |
| n_layers | 8 |
| n_nodes | 4 |
| stem_multiplier | 3 |
At the end of the SearchCNNController constructor:
self.net = SearchCNN(C_in, C, n_classes, n_layers, n_nodes, stem_multiplier)
SearchCNN
Each cell has two inputs: the output of the previous cell and the output of the cell before that (a cross-layer skip connection).
Initially C_pp = C_p = 48 (the stem's output channels, stem_multiplier * C = 3 * 16) and C_cur = 16.
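The cells are chained together through these cross-layer connections: each cell's output becomes the s1 input of the next cell, and the old s1 becomes the next cell's s0.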
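The channel bookkeeping can be traced with plain Python. This sketch rests on assumptions not stated above: `trace_channels` is a hypothetical helper, reduction cells are assumed to sit at layers `n_layers // 3` and `2 * n_layers // 3` (the usual DARTS placement), and each cell is assumed to output the concatenation of its `n_nodes` intermediate nodes, so its output has `n_nodes * C_cur` channels.

```python
def trace_channels(C=16, n_layers=8, n_nodes=4, stem_multiplier=3):
    """Hypothetical helper: (layer, is_reduction, C_pp, C_p, C_cur) per cell."""
    C_cur = stem_multiplier * C            # stem output: 3 * 16 = 48 channels
    C_pp, C_p, C_cur = C_cur, C_cur, C     # cell 0: both inputs come from the stem
    trace = []
    for i in range(n_layers):
        # Assumption: reductions at 1/3 and 2/3 of the depth.
        reduction = i in (n_layers // 3, 2 * n_layers // 3)
        if reduction:
            C_cur *= 2                     # double channels when halving resolution
        trace.append((i, reduction, C_pp, C_p, C_cur))
        # Shift the window: old s1 becomes s0, this cell's output becomes s1.
        # Assumption: output = concat of n_nodes intermediate nodes.
        C_pp, C_p = C_p, C_cur * n_nodes
    return trace


for row in trace_channels():
    print(row)
```

With the table's values this places reductions at layers 2 and 5, and cell 0 sees C_pp = C_p = 48 and C_cur = 16.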
Building the DAG between nodes in SearchCell
models.search_cells.SearchCell#__init__
op = ops.MixedOp(C, stride)
self.dag[i].append(op)
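The loop around these two lines can be sketched without PyTorch by recording only which edges exist. `dag_edges` is hypothetical, and the rule that only edges leaving the two input nodes use stride 2 in a reduction cell is an assumption (it is how DARTS downsamples inside a reduction cell), not something shown in the lines above.

```python
def dag_edges(n_nodes=4, reduction=False):
    """Hypothetical sketch: list (node i, source j, stride) for every edge."""
    edges = []
    for i in range(n_nodes):
        # Sources 0 and 1 are the two cell inputs; j >= 2 are earlier nodes.
        for j in range(i + 2):
            # Assumption: only edges leaving the input nodes downsample.
            stride = 2 if reduction and j < 2 else 1
            edges.append((i, j, stride))
    return edges


normal = dag_edges()
reduce_cell = dag_edges(reduction=True)
print(len(normal))                                      # 14
print(sum(1 for _, _, s in reduce_cell if s == 2))      # 8
```

Each of those 14 positions holds one MixedOp.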
Inside MixedOp:
self._ops = nn.ModuleList()
for primitive in gt.PRIMITIVES:
    op = OPS[primitive](C, stride, affine=False)
    self._ops.append(op)
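What MixedOp does with these candidates in its forward pass is the continuous relaxation used by DARTS: the edge output is a softmax-weighted sum of every candidate op applied to the same input. The sketch below is a simplified stand-in, assuming three toy ops (identity, 3x3 average pooling, and a zero op) instead of the full `OPS` table, and assuming the weights are a softmax over one row of an alpha matrix. `ToyMixedOp` and `Zero` are hypothetical names.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Zero(nn.Module):
    """Toy 'none' op: zeros with the input's shape (stride 1 only)."""
    def forward(self, x):
        return x * 0.0


class ToyMixedOp(nn.Module):
    """Hypothetical, simplified stand-in for ops.MixedOp."""
    def __init__(self):
        super().__init__()
        self._ops = nn.ModuleList([
            nn.Identity(),                                                  # like 'skip_connect'
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),  # like 'avg_pool_3x3'
            Zero(),                                                         # like 'none'
        ])

    def forward(self, x, weights):
        # Weighted sum of every candidate op's output on this edge.
        return sum(w * op(x) for w, op in zip(weights, self._ops))


torch.manual_seed(0)
mixed = ToyMixedOp()
alpha = 1e-3 * torch.randn(3)        # one edge's row of architecture weights
weights = F.softmax(alpha, dim=-1)   # relax the discrete choice into a distribution
x = torch.randn(1, 16, 8, 8)
y = mixed(x, weights)
print(y.shape)  # torch.Size([1, 16, 8, 8])
```

Because alpha starts near zero (1e-3 * randn), the softmax weights start almost uniform, so every candidate initially contributes roughly equally.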
# genotypes.py
PRIMITIVES = [
'max_pool_3x3',
'avg_pool_3x3',
'skip_connect', # identity
'sep_conv_3x3',
'sep_conv_5x5',
'dil_conv_3x3',
'dil_conv_5x5',
'none'
]
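The list has eight entries, which matches n_ops = 8 above. When the final discrete cell is derived, DARTS keeps the strongest operation on each edge while ignoring 'none'. The helper `strongest_op` below is a hypothetical sketch of that per-edge choice only; it does not implement the full selection of the top-2 incoming edges per node, and the alpha values in the usage line are made up.

```python
PRIMITIVES = [
    'max_pool_3x3', 'avg_pool_3x3', 'skip_connect', 'sep_conv_3x3',
    'sep_conv_5x5', 'dil_conv_3x3', 'dil_conv_5x5', 'none',
]
n_ops = len(PRIMITIVES)  # 8


def strongest_op(alpha_row):
    """Hypothetical helper: highest-weighted op name on one edge, skipping 'none'."""
    if len(alpha_row) != len(PRIMITIVES):
        raise ValueError("alpha_row needs one weight per primitive")
    candidates = [(w, name) for w, name in zip(alpha_row, PRIMITIVES) if name != 'none']
    return max(candidates)[1]


row = [0.1, 0.0, 0.3, 0.9, 0.2, 0.0, 0.1, 5.0]  # made-up weights; 'none' is largest
print(n_ops)              # 8
print(strongest_op(row))  # sep_conv_3x3
```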
models.search_cells.SearchCell#forward
models/search_cells.py:53
s.shape
Out[2]: torch.Size([