[NAS]Fair darts代码解析

论文题目:Fair DARTS: Eliminating Unfair Advantages in Differentiable Architecture Search
论文链接:https://arxiv.org/pdf/1911.12126.pdf
论文代码:https://github.com/xiaomi-automl/FairDARTS


关于fair-darts论文解析见[NAS]Fair darts
fair-darts是基于darts的改进工作,代码也是基于darts进行的改进。关于darts代码的解析见[NAS]Darts代码解析

在这里只介绍与darts不同的几个重要part


搜索空间
'''
去掉了none操作
'''

损失函数
'''
train_search.py
可选正则项 [l1,l2]
aux_loss_weight 论文默认是10.0
'''
criterion_train = ConvSeparateLoss(weight=args.aux_loss_weight) if args.sep_loss == 'l2' else TriSeparateLoss(weight=args.aux_loss_weight)
  
'''
separate_loss.py

'''
__all__ = ['ConvSeparateLoss', 'TriSeparateLoss']
# l2
class ConvSeparateLoss(nn.modules.loss._Loss):
	"""Separate the weight value between each operations using L2"""
    def __init__(self, weight=0.1, size_average=None, ignore_index=-100,reduce=None, reduction='mean'):
        super(ConvSeparateLoss, self).__init__(size_average, reduce, reduction)
        self.ignore_index = ignore_index
        self.weight = weight

    def forward(self, input1, target1, input2):
        loss1 = F.cross_entropy(input1, target1)
        loss2 = -F.mse_loss(input2, torch.tensor(0.5, requires_grad=False).cuda())
        return loss1 + self.weight*loss2, loss1.item(), loss2.item()

# l1
class TriSeparateLoss(nn.modules.loss._Loss):
    """Separate the weight value between each operations using L1"""
    def __init__(self, weight=0.1, size_average=None, ignore_index=-100
  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值