Paper title: Fair DARTS: Eliminating Unfair Advantages in Differentiable Architecture Search
Paper: https://arxiv.org/pdf/1911.12126.pdf
Code: https://github.com/xiaomi-automl/FairDARTS
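For an analysis of the Fair DARTS paper itself, see [NAS]Fair darts.
Fair DARTS is an improvement on DARTS, and its code is likewise a modification of the DARTS code. For a walkthrough of the DARTS code, see [NAS]Darts代码解析 (DARTS code analysis).
Only the key parts that differ from DARTS are covered here.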
Search space
'''
The none operation is removed from the candidate operations
'''
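A minimal sketch of what dropping none means for the candidate set, assuming Fair DARTS starts from the standard DARTS operation list (`PRIMITIVES` in DARTS' `genotypes.py`) and the standard DARTS cell with 4 intermediate nodes. The names `DARTS_PRIMITIVES`, `FAIR_PRIMITIVES`, `NUM_EDGES` and `ALPHA_SHAPE` are hypothetical and not necessarily what the FairDARTS repo uses.

```python
# Candidate operations of the standard DARTS search space (DARTS genotypes.PRIMITIVES).
DARTS_PRIMITIVES = [
    'none',
    'max_pool_3x3',
    'avg_pool_3x3',
    'skip_connect',
    'sep_conv_3x3',
    'sep_conv_5x5',
    'dil_conv_3x3',
    'dil_conv_5x5',
]

# Hypothetical name: the same operations with 'none' removed.
FAIR_PRIMITIVES = [op for op in DARTS_PRIMITIVES if op != 'none']

# Assumed DARTS cell: 4 intermediate nodes; node i takes edges from the
# 2 cell inputs plus the i earlier intermediate nodes -> 2 + 3 + 4 + 5 = 14 edges.
NUM_EDGES = sum(2 + i for i in range(4))

# Architecture-parameter matrix for one cell type: edges x candidate operations.
ALPHA_SHAPE = (NUM_EDGES, len(FAIR_PRIMITIVES))

if __name__ == "__main__":
    print(FAIR_PRIMITIVES)
    print(ALPHA_SHAPE)  # (14, 7)
```

Under these assumptions only the operation axis of the architecture weights shrinks, from 8 to 7 columns per edge.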
Loss function
'''
train_search.py
Optional regularization term (args.sep_loss): l1 or l2
aux_loss_weight: 10.0 by default, following the paper
'''
criterion_train = ConvSeparateLoss(weight=args.aux_loss_weight) if args.sep_loss == 'l2' else TriSeparateLoss(weight=args.aux_loss_weight)
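To make the role of `aux_loss_weight` and the l1/l2 choice concrete, here is a minimal standalone sketch of the combined objective: cross-entropy plus `aux_loss_weight` times the zero-one term -(1/N) * sum_i (sigmoid(alpha_i) - 0.5)^2, which is what the `ConvSeparateLoss` code computes with a negative MSE to 0.5. The function names are hypothetical, the (14, 7) architecture-weight shape assumes the standard DARTS cell without none, and the l1 variant is an assumption based only on the `TriSeparateLoss` docstring ("using L1"), since its body is not shown here.

```python
import torch
import torch.nn.functional as F


def zero_one_loss_l2(sig_weights: torch.Tensor) -> torch.Tensor:
    """L2 zero-one term: -mean((sigmoid(alpha) - 0.5)^2), mirroring ConvSeparateLoss."""
    target = torch.full_like(sig_weights, 0.5)  # same shape/device, no grad
    return -F.mse_loss(sig_weights, target)


def zero_one_loss_l1(sig_weights: torch.Tensor) -> torch.Tensor:
    """ASSUMED L1 analogue: -mean(|sigmoid(alpha) - 0.5|); not taken from the repo."""
    target = torch.full_like(sig_weights, 0.5)
    return -F.l1_loss(sig_weights, target)


def total_loss(logits, labels, alphas, weight=10.0, sep_loss="l2"):
    """Cross-entropy plus weight * zero-one term on sigmoid(alphas) (hypothetical helper)."""
    ce = F.cross_entropy(logits, labels)
    sig = torch.sigmoid(alphas)
    aux = zero_one_loss_l2(sig) if sep_loss == "l2" else zero_one_loss_l1(sig)
    return ce + weight * aux, ce.item(), aux.item()


if __name__ == "__main__":
    torch.manual_seed(0)
    # Assumed shape: 14 edges x 7 candidate operations (no 'none').
    alphas = (1e-3 * torch.randn(14, 7)).requires_grad_()
    logits = torch.randn(4, 10)
    labels = torch.tensor([0, 1, 2, 3])
    loss, ce, aux = total_loss(logits, labels, alphas)
    loss.backward()
    print(loss.item(), ce, aux)
```

The zero-one term is 0 when every sigmoid weight sits at 0.5 and reaches its minimum of -0.25 (l2) when all weights are exactly 0 or 1; its gradient pushes any weight above 0.5 upward and any weight below 0.5 downward.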
'''
separate_loss.py
'''
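import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['ConvSeparateLoss', 'TriSeparateLoss']

# l2
class ConvSeparateLoss(nn.modules.loss._Loss):
    """Separate the weight value between each operations using L2"""
    def __init__(self, weight=0.1, size_average=None, ignore_index=-100, reduce=None, reduction='mean'):
        super(ConvSeparateLoss, self).__init__(size_average, reduce, reduction)
        self.ignore_index = ignore_index
        self.weight = weight  # weight of the zero-one term (aux_loss_weight)

    def forward(self, input1, target1, input2):
        # input1: logits, target1: labels, input2: sigmoid-activated architecture weights
        loss1 = F.cross_entropy(input1, target1)
        # Zero-one loss: negative MSE to 0.5 pushes each sigmoid weight toward 0 or 1.
        # full_like builds the 0.5 target with input2's shape and device, so this also runs on CPU.
        loss2 = -F.mse_loss(input2, torch.full_like(input2, 0.5))
        return loss1 + self.weight * loss2, loss1.item(), loss2.item()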
# l1
class TriSeparateLoss(nn.modules.loss._Loss):
    """Separate the weight value between each operations using L1"""
    def __init__(self, weight=0.1, size_average=None, ignore_index=-100