PC-DARTS代码笔记

PC-DARTS代码笔记

参考代码:https://github.com/yuhuixu1993/PC-DARTS

PC-DARTS网络的构造:train_search_imagenet.py

先看一些网络的参数,后在回来看这些参数的作用:

在这里插入图片描述

网络的开始先通过三个卷积层把通道扩充到C_curr = stem_multiplier*C=48

在这里插入图片描述

然后是cells堆叠构成多层网络。

在这里插入图片描述
Cell的代码:

在这里插入图片描述

DARTS每个cell有两个输入节点和一个输出节点,对于卷积单元来说,输入节点被定义为前两层的单元输出;当reduction_prev=True时,对特征图C_prev_prev进行下采样,可能作者是为了不影响感受野,用1x1卷积进行下采样,实现很微妙:

在这里插入图片描述

如果reduction_prev=False,则连接:

在这里插入图片描述

在这里插入图片描述

根据DARTS论文描述,**每个中间节点都是基于所有它之前的节点进行计算的,**对于代码为:

在这里插入图片描述

MixedOp的实现(最开始先连接了一个nn.MaxPool2d(2,2),这是看论文时没有注意到的):

在这里插入图片描述
PC-DARTS里的候选操作:

在这里插入图片描述

OPS = {
  'none' : lambda C, stride, affine: Zero(stride),
  'avg_pool_3x3' : lambda C, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
  'max_pool_3x3' : lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1),
  'skip_connect' : lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
  'sep_conv_3x3' : lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
  'sep_conv_5x5' : lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
  'sep_conv_7x7' : lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
  'dil_conv_3x3' : lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine),
  'dil_conv_5x5' : lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine),
  'conv_7x1_1x7' : lambda C, stride, affine: nn.Sequential(
    nn.ReLU(inplace=False),
    nn.Conv2d(C, C, (1,7), stride=(1, stride), padding=(0, 3), bias=False),
    nn.Conv2d(C, C, (7,1), stride=(stride, 1), padding=(3, 0), bias=False),
    nn.BatchNorm2d(C, affine=affine)
    ),
}

最后分类层以及初始化参数 α \alpha α

在这里插入图片描述

在这里插入图片描述

在回过头来看看网络参数的作用:

在这里插入图片描述
self._C是PC_DARTS的初始通道数,self._layers为堆叠的cells的数量,self._criterion为损失函数,self._steps为每个cells里blocks的数量,self._multiplier在forward()函数里才用到。

再看forward()函数:

先是通过self.stem0和self.stem1把通道扩大到48:

在这里插入图片描述

然后是cells里的运算顺序:

在这里插入图片描述

关于PC-DARTS里超参数 β \beta β 的作用,见(PC-DARTS) 3.3节。

cells里的forward()函数:

在这里插入图片描述
##############

补充:

torch.device

torch.device代表将torch.Tensor分配到的设备的对象。

torch.device包含一个设备类型('cpu''cuda'设备类型)和可选的设备的序号。如果设备序号不存在,则为当前设备; 例如,torch.Tensor用设备构建'cuda'的结果等同于'cuda:X',其中Xtorch.cuda.current_device()的结果。

##############

94行self.ops在 class Cell的__init__()函数里定义,即MixedOp的集合

MixedOp的forward()函数:

在这里插入图片描述

这部分代码也体现了论文中的Partial Channel Connections的思想。59行还进行了通道洗牌。

最后经过全局池化层和分类层:

在这里插入图片描述

再看训练代码:

如论文描述,SGD训练权重参数,Adam训练架构参数,SGD采用余弦退火学习率至0,详细看论文:

在这里插入图片描述

前5个epoch采用学习率warm-up

在这里插入图片描述

epoch>=args.begin (35)才开始训练架构参数:

在这里插入图片描述

训练模型权重:

在这里插入图片描述

  • 0
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值