很多开源工程经常会出现args,kwargs。本文将不定期更新博主解锁的kwargs用法
省略赋值
源工程地址
问题代码
class DC_and_CE_loss(nn.Module):
def __init__(self, soft_dice_kwargs, ce_kwargs, aggregate="sum"):
super(DC_and_CE_loss, self).__init__()
self.aggregate = aggregate
self.ce = CrossentropyND(**ce_kwargs)
self.dc = SoftDiceLoss(apply_nonlin=softmax_helper, **soft_dice_kwargs)
def forward(self, net_output, target):
dc_loss = self.dc(net_output, target)
ce_loss = self.ce(net_output, target)
if self.aggregate == "sum":
result = ce_loss + dc_loss
else:
raise NotImplementedError("nah son") # reserved for other stuff (later)
return result
def CE_DiceLoss(self, logit, target):
criterion = DC_and_CE_loss()
if self.cuda:
criterion = criterion.cuda()
loss = criterion(logit, target)
return loss
上面的代码会报错,原因是ce_kwargs
和soft_dice_kwargs
需要在定义DC_and_CE_loss()
时被赋值,但是点进CrossentropyND
和SoftDiceLoss
发现,这两个类的所有参数都有缺省值。我想让DC_and_CE_loss()
所有参数都设置为缺省值。
最简单的解决方法:
criterion = DC_and_CE_loss()
改为criterion = DC_and_CE_loss(ce_kwargs={}, soft_dice_kwargs={})
,这里改成None也是行不通的。根本原因在于kwargs的本质是个字典,要让它为默认值,则赋值一个空字典