CrossEntropyLoss没有backward属性
错误代码如下,调用封装函数直接会返回CrossEntropyLoss函数相当于直接调用了CrossEntropyLoss的backward,所以报错,应该对CrossEntropyLoss的返回值使用backward
def CrossEntropyLoss_func(output, target):
return nn.CrossEntropyLoss()
loss = CrossEntropyLoss_func(output, target)
loss.backward()
正确代码如下,
def CrossEntropyLoss_func():
return nn.CrossEntropyLoss()
my_loss = CrossEntropyLoss_func()
loss = my_loss(output,target)
loss.backward()
#也就是
my_loss = nn.CrossEntropyLoss()
loss = my_loss(output,target)
loss.backward()