PC-DARTS代码笔记
参考代码:https://github.com/yuhuixu1993/PC-DARTS
PC-DARTS网络的构造:train_search_imagenet.py
先看一些网络的参数,后在回来看这些参数的作用:
网络的开始先通过三个卷积层把通道扩充到C_curr = stem_multiplier*C=48
然后是cells堆叠构成多层网络。
Cell的代码:
DARTS每个cell有两个输入节点和一个输出节点,对于卷积单元来说,输入节点被定义为前两层的单元输出;当reduction_prev=True时,对特征图C_prev_prev进行下采样,可能作者是为了不影响感受野,用1x1卷积进行下采样,实现很微妙:
如果reduction_prev=False,则连接:
根据DARTS论文描述,**每个中间节点都是基于所有它之前的节点进行计算的,**对于代码为:
MixedOp的实现(最开始先连接了一个nn.MaxPool2d(2,2),这是看论文时没有注意到的):
PC-DARTS里的候选操作:
OPS = {
'none' : lambda C, stride, affine: Zero(stride),
'avg_pool_3x3' : lambda C, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
'max_pool_3x3' : lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1),
'skip_connect' : lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
'sep_conv_3x3' : lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
'sep_conv_5x5' : lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
'sep_conv_7x7' : lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
'dil_conv_3x3' : lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine),
'dil_conv_5x5' : lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine),
'conv_7x1_1x7' : lambda C, stride, affine: nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, (1,7), stride=(1, stride), padding=(0, 3), bias=False),
nn.Conv2d(C, C, (7,1), stride=(stride, 1), padding=(3, 0), bias=False),
nn.BatchNorm2d(C, affine=affine)
),
}
最后分类层以及初始化参数 α \alpha α:
在回过头来看看网络参数的作用:
self._C是PC_DARTS的初始通道数,self._layers为堆叠的cells的数量,self._criterion为损失函数,self._steps为每个cells里blocks的数量,self._multiplier在forward()函数里才用到。
再看forward()函数:
先是通过self.stem0和self.stem1把通道扩大到48:
然后是cells里的运算顺序:
关于PC-DARTS里超参数 β \beta β 的作用,见(PC-DARTS) 3.3节。
cells里的forward()函数:
##############
补充:
torch.device
torch.device
代表将torch.Tensor
分配到的设备的对象。
torch.device
包含一个设备类型('cpu'
或'cuda'
设备类型)和可选的设备的序号。如果设备序号不存在,则为当前设备; 例如,torch.Tensor
用设备构建'cuda'
的结果等同于'cuda:X'
,其中X
是torch.cuda.current_device()
的结果。
##############
94行self.ops在 class Cell的__init__()函数里定义,即MixedOp的集合
MixedOp的forward()函数:
这部分代码也体现了论文中的Partial Channel Connections的思想。59行还进行了通道洗牌。
最后经过全局池化层和分类层:
再看训练代码:
如论文描述,SGD训练权重参数,Adam训练架构参数,SGD采用余弦退火学习率至0,详细看论文:
前5个epoch采用学习率warm-up
epoch>=args.begin (35)才开始训练架构参数:
训练模型权重: