PDARTS Code Review
I have been reading up on DARTS recently and came across the PDARTS paper. This post records my review of its code; I hope it is useful to others, and corrections are welcome.
PDARTS builds on the earlier DARTS work: it fixes several of DARTS's problems and achieves better experimental results.
Problem: DARTS searches its architecture on simple datasets (e.g. CIFAR10), so the found network generalizes poorly to more complex datasets.
PDARTS solution: Progressively (the "P") deepen the network during search, from 5 cells to 11 cells and finally 17 cells, which alleviates the depth gap to a good extent.
PDARTS Paper: Progressive Differentiable Architecture Search: Bridging the Depth Gap between Search and Evaluation
Code: https://github.com/chenxin061/pdarts
Huawei certainly moved fast: the same year DARTS was published (ICLR 2019), PDARTS was accepted at ICCV 2019. With that said, on to the code review.
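Before diving into the code, here is the progressive schedule at a glance. This is only a sketch based on the paper's 5/11/17-cell description; in the repo the per-stage numbers are passed to train_search.py as command-line flags rather than hard-coded like this.

```python
# Three search stages: the network gets deeper while the candidate
# operation set on each edge shrinks (search-space approximation).
stages = [
    {"cells": 5,  "ops_per_edge": 8},  # stage 0: all 8 candidates per edge
    {"cells": 11, "ops_per_edge": 5},  # stage 1: the 3 weakest ops were dropped
    {"cells": 17, "ops_per_edge": 3},  # stage 2: 2 more were dropped
]
```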
coding tips | |
---|---|
copy.copy / copy.deepcopy | shallow copy vs. deep copy |
next() | returns the next item of an iterator |
sorted(range(n), key=lambda x: tbsn[x]) | returns the indices that sort tbsn in ascending order (demonstrated below) |
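A quick demonstration of these idioms (tbsn here is just a made-up score list):

```python
import copy

tbsn = [0.30, 0.05, 0.20]
order = sorted(range(len(tbsn)), key=lambda x: tbsn[x])
print(order)       # [1, 2, 0] -- indices of tbsn in ascending order of value
print(order[-2:])  # [2, 0]   -- indices of the two largest scores

# Shallow vs. deep copy: copy.copy shares nested objects, deepcopy clones them.
a = [[1, 2], [3, 4]]
shallow, deep = copy.copy(a), copy.deepcopy(a)
a[0][0] = 99
print(shallow[0][0], deep[0][0])  # 99 1
```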
1. train_search.py
1.1 Prepare dataset
Variable | Value |
---|---|
dataset | CIFAR10 or CIFAR100 |
args.train_portion | fraction of the data used as the training split (the rest becomes the validation split; see the sketch below) |
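A minimal sketch of how such a split is typically built, following the DARTS-style loader; the literal values here stand in for the script's args.*, and the real code uses much richer augmentation transforms:

```python
import numpy as np
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms

train_portion = 0.5   # args.train_portion in the script
batch_size = 64       # args.batch_size in the script
train_transform = transforms.ToTensor()  # placeholder; the repo augments more

train_data = dset.CIFAR10(root='./data', train=True,
                          download=True, transform=train_transform)
num_train = len(train_data)  # 50000 for CIFAR10
split = int(np.floor(train_portion * num_train))
indices = list(range(num_train))

# The first portion updates the network weights w ...
train_queue = torch.utils.data.DataLoader(
    train_data, batch_size=batch_size,
    sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
    pin_memory=True)
# ... the remainder acts as the validation split that updates alpha.
valid_queue = torch.utils.data.DataLoader(
    train_data, batch_size=batch_size,
    sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]),
    pin_memory=True)
```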
1.2 Preparation
Variable | Value |
---|---|
14 | the total number of edges in one cell is 14 |
loss | nn.CrossEntropyLoss() |
args.train_portion | fraction of the data used as the training split |
switches | which operations on each edge take part in the search; True/False flags, 14 edges × 8 operations (see the sketch below) |
PRIMITIVES | the 8 basic candidate operations |
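A minimal sketch of how PRIMITIVES and the switch matrices are set up; the operation list below matches the repo's genotypes.py:

```python
# The 8 candidate operations searched on every edge.
PRIMITIVES = [
    'none', 'max_pool_3x3', 'avg_pool_3x3', 'skip_connect',
    'sep_conv_3x3', 'sep_conv_5x5', 'dil_conv_3x3', 'dil_conv_5x5',
]

# One row of booleans per edge (14 edges per cell); True means the
# operation is still a search candidate on that edge at the current stage.
switches_normal = [[True] * len(PRIMITIVES) for _ in range(14)]
switches_reduce = [[True] * len(PRIMITIVES) for _ in range(14)]
```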
1.3 Train
1.3.1 Model
Class Network | |
---|---|
C | channels of each cell at the current stage, default: 16 |
layers | number of cells: 5 -> 11 -> 17 across the stages |
steps | number of intermediate nodes per cell, default: 4 |
stem_multiplier | channel multiplier of the stem convolution that first processes the input image |
multiplier | number of node outputs concatenated at a cell's output, i.e. the factor by which the output channels grow |
p | dropout rate, adjusted as the epochs progress |
weight | the value of alpha after softmax |
self.stem | the network in front of C0, i.e. the stem |
reduction | if True, the current cell halves the spatial size of its output |
Cell | builds the network inside one cell |
self._arch_parameters | self.alphas_normal & self.alphas_reduce, each of shape 14 × switch_on (see the sketch below) |
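A sketch of how these architecture parameters are initialized, written standalone here; switch_on (the number of operations still active per edge) is 8 in the first stage, and the repo starts alphas from small Gaussian noise in this manner:

```python
import numpy as np
import torch
import torch.nn as nn

k = 14          # number of edges per cell
switch_on = 8   # operations still active per edge (8 in stage 0)

# Small random init so that no operation dominates the first softmax.
alphas_normal = nn.Parameter(torch.FloatTensor(1e-3 * np.random.randn(k, switch_on)))
alphas_reduce = nn.Parameter(torch.FloatTensor(1e-3 * np.random.randn(k, switch_on)))
arch_parameters = [alphas_normal, alphas_reduce]  # optimized by the alpha optimizer
```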
class Network(nn.Module):
    def forward(self, input):
        s0 = s1 = self.stem(input)  # s0 plays the role of Ck-2, s1 of Ck-1
        for i, cell in enumerate(self.cells):
            if cell.reduction:
                if self.alphas_reduce.size(1) == 1:
                    weights = F.softmax(self.alphas_reduce, dim=0)
                else:
                    weights = F.softmax(self.alphas_reduce, dim=-1)
            else:
                if self.alphas_normal.size(1) == 1:
                    weights = F.softmax(self.alphas_normal, dim=0)
                else:
                    weights = F.softmax(self.alphas_normal, dim=-1)
            s0, s1 = s1, cell(s0, s1, weights)
        out = self.global_pooling(s1)
        logits = self.classifier(out.view(out.size(0), -1))
        return logits

Note the size(1) == 1 special case: once only one candidate operation remains per edge, a softmax over the last dimension would always return 1, so the softmax is taken over dim 0 instead.
Class Cell | |
---|---|
self.preprocess0 | FactorizedReduce: applied to the Ck-2 input when the preceding cell was a reduction cell; two stride-2 convolutions each produce half of the output channels and their results are concatenated. ReLUConvBN: handles the non-reduction case |
MixedOp | the weighted mixture of all candidate operations on the edge between two nodes (see the sketch below) |
switch_count | 0-13, the numbering over all edges of a cell |
OPS | key-value map from operation name to the concrete operation |
PRIMITIVES | the operation names |
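A simplified sketch of MixedOp under the switches mechanism. It assumes the OPS and PRIMITIVES tables from the repo's operations.py/genotypes.py; the real implementation additionally wraps pooling ops with BatchNorm and applies dropout with rate p to parameter-free ops, which is omitted here:

```python
import torch.nn as nn

class MixedOp(nn.Module):
    """Weighted sum of the candidate operations still active on one edge."""
    def __init__(self, C, stride, switch):
        super(MixedOp, self).__init__()
        self.m_ops = nn.ModuleList()
        for i, active in enumerate(switch):
            if active:  # instantiate only operations whose switch is True
                self.m_ops.append(OPS[PRIMITIVES[i]](C, stride, False))

    def forward(self, x, weights):
        # weights is one softmaxed row of alpha, aligned with self.m_ops
        return sum(w * op(x) for w, op in zip(weights, self.m_ops))
```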
class FactorizedReduce(nn.Module):
    def __init__(self, C_in, C_out, affine=True):
        super(FactorizedReduce, self).__init__()
        assert C_out % 2 == 0
        self.relu = nn.ReLU(inplace=False)
        # two stride-2 1x1 convolutions, each producing half the output channels
        self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x):
        x = self.relu(x)
        # the second conv sees the input shifted by one pixel; the two halves
        # are then concatenated along the channel dimension
        out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1)
        return self.bn(out)
class Cell(nn.Module):
    def forward(self, s0, s1, weights):
        s0 = self.preprocess0(s0)  # preprocess the outputs of the two preceding cells
        s1 = self.preprocess1(s1)
        states = [s0, s1]  # grows to [s0, s1, n0, n1, n2, n3]
        offset = 0
        for i in range(self._steps):
            # node i sums the mixed-op outputs of all earlier states;
            # weights control each edge's share of the sum
            s = sum(self.cell_ops[offset + j](h, weights[offset + j])
                    for j, h in enumerate(states))
            offset += len(states)
            states.append(s)
        return torch.cat(states[-self._multiplier:], dim=1)
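The offset arithmetic above realizes the 0-13 edge numbering from the table; a tiny standalone loop makes the mapping explicit:

```python
# Node i has (2 + i) incoming edges: from Ck-2, Ck-1, node0 .. node(i-1).
edge_id = 0
for i in range(4):           # 4 intermediate nodes
    for j in range(2 + i):   # predecessor states of node i
        print(f"edge {edge_id}: state {j} -> node {i}")
        edge_id += 1
# 2 + 3 + 4 + 5 = 14 edges in total
```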
for i in range(layers):
    # the cells at 1/3 and 2/3 of the depth are reduction cells
    if i in [layers//3, 2*layers//3]:
        C_curr *= 2
        reduction = True
        cell = Cell(steps, multiplier, C_prev_prev, C_prev,
                    C_curr, reduction, reduction_prev, switches_reduce, self.p)
    else:
        reduction = False
        cell = Cell(steps, multiplier, C_prev_prev, C_prev,
                    C_curr, reduction, reduction_prev, switches_normal, self.p)
    reduction_prev = reduction
    self.cells += [cell]
    # a cell concatenates `multiplier` node outputs of C_curr channels each
    C_prev_prev, C_prev = C_prev, multiplier * C_curr
1.3.2 Train process
Pipeline | |
---|---|
sp | which stage the search is in, one of 0, 1, 2 |
Train Model | |
Epoch | default: 25; epochs 1-10: train only, alpha frozen, only w updated on the training split; epochs 11-25: w updated as before and alpha additionally updated on the validation split; epochs 21-25: validation is also run |
Drop Operations | performed once all epochs of the stage have finished |
get_min_k | returns the operations with the smallest alpha on the current edge (see the sketch below) |
get_min_k_no_zero | in the last stage, first discard the zero operation, then drop the operations with the smallest alpha |
Last Stage | |
normal_final & reduce_final | store, for every edge of the cell, the index of the operation that should be selected |
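A pure-Python sketch of the pruning helper get_min_k; the repo works on a copied score array in the same spirit, and get_min_k_no_zero additionally pre-removes the 'none' operation before ranking:

```python
import copy

def get_min_k(scores_in, k):
    """Return the indices of the k smallest entries of scores_in."""
    scores = copy.deepcopy(scores_in)
    index = []
    for _ in range(k):
        idx = scores.index(min(scores))
        index.append(idx)
        scores[idx] = float('inf')  # exclude it from the next pass
    return index

# Example: drop the 3 weakest of 8 candidate ops on one edge.
print(get_min_k([0.2, 0.05, 0.4, 0.1, 0.3, 0.15, 0.25, 0.35], 3))  # [1, 3, 5]
```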
"""
Generate Architecture
选择权重最大的两个前驱节点,再从这两个节点里确定应该选哪个edge
"""
for i in range(3):
end = start + n # 每个node的前驱edge,n0:0,1; n2:2,3,4;...
tbsn = normal_final[start:end] # n*num_ops,normal_final和reduce_final都是排过序的
tbsr = reduce_final[start:end]
edge_n = sorted(range(n), key=lambda x: tbsn[x])
keep_normal.append(edge_n[-1] + start)
keep_normal.append(edge_n[-2] + start)
edge_r = sorted(range(n), key=lambda x: tbsr[x])
keep_reduce.append(edge_r[-1] + start)
keep_reduce.append(edge_r[-2] + start)
start = end
n = n + 1
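A toy run for a single node (hypothetical scores) shows the selection:

```python
# Hypothetical scores for node 1's incoming edges 2, 3, 4:
tbsn = [0.10, 0.70, 0.30]
edge_n = sorted(range(3), key=lambda x: tbsn[x])   # [0, 2, 1]
keep = [edge_n[-1] + 2, edge_n[-2] + 2]            # offset by start = 2
print(keep)  # [3, 4] -> edges 3 and 4 survive
```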
"""
translate switches into genotype
"""
def parse_network(switches_normal, switches_reduce):
def _parse_switches(switches):
n = 2
start = 0
gene = []
step = 4
for i in range(step):
end = start + n
for j in range(start, end):
for k in range(len(switches[j])):
if switches[j][k]:
gene.append((PRIMITIVES[k], j - start))
start = end
n = n + 1
return gene
gene_normal = _parse_switches(switches_normal)
gene_reduce = _parse_switches(switches_reduce)
concat = range(2, 6)
genotype = Genotype(
normal=gene_normal, normal_concat=concat,
reduce=gene_reduce, reduce_concat=concat
)
return genotype # 把(操作,前驱节点序号)放到list gene中,[('sep_conv_3x3', 1),...,]
"""
E.G.
"""
DARTS_V1 = Genotype(normal=[('sep_conv_3x3', 1), # node0的前驱node以及相关operation
('sep_conv_3x3', 0), # node0的前驱node以及相关operation
('skip_connect', 0), # node1的前驱node以及相关operation
('sep_conv_3x3', 1), # node1的前驱node以及相关operation
('skip_connect', 0), # node2的前驱node以及相关operation
('sep_conv_3x3', 1), # node2的前驱node以及相关operation
('sep_conv_3x3', 0), # node3的前驱node以及相关operation
('skip_connect', 2)],# node3的前驱node以及相关operation
normal_concat=[2, 3, 4, 5], # 最终将结果concat在一起
reduce=[('max_pool_3x3', 0),
('max_pool_3x3', 1),
('skip_connect', 2),
('max_pool_3x3', 0),
('max_pool_3x3', 0),
('skip_connect', 2),
('skip_connect', 2),
('avg_pool_3x3', 0)], reduce_concat=[2, 3, 4, 5])
Predecessor index legend: 0 = Ck-2; 1 = Ck-1; 2 = node0; 3 = node1; 4 = node2; 5 = node3.