问题描述:
my_loss_cell = WithLossCell(net, loss_fn) model = Model(my_loss_cell, optimizer = optim)
如果写成上面的形式,则不会报错。
withlosscell在代码中重写了一次。
class WithLossCell(nn.Cell): def __init__(self, backbone, loss_fn): super(WithLossCell, self).__init__(auto_prefix=False) self._backbone = backbone self._loss_fn = loss_fn def construct(self, data, label, classes): out = self._backbone(data) return self._loss_fn(out, label, classes)
但是我想要自定义一个metric,于是当我改成
model = Model(my_loss_cell, optimizer = optim, metrics={"PrototyAccLoss": PrototyAccLoss()})
PrototyAccLoss是我自己定义的一个metrics类。
就会报下面的错误
这个应该时说我在定义Model的时候没有加上loss_fn的参数,
但是当我把net和loss分开而不使用withlosscell的时候,则会报下面的错误:
这个应该是图模式才会出现的问题。
解决方案:
计算评估指标需要logits和label,在您自定义WithLossCell的情况下,Model不能分辨出哪个输出是logits,哪个是label。因此需要您指定eval_network,并输出用于计算metrics的logits和label。