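Version issue: in PyTorch 0.4 and later the loss is returned as a 0-dim tensor, so indexing it with .data[0] is deprecated and fails on newer releases.
Original code: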
test_loss += F.nll_loss(out_tgt.log(), target_label, size_average=False).data[0] # sum up batch loss
Modified code:
test_loss += F.nll_loss(out_tgt.log(), target_label, size_average=False).item() # sum up batch loss
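The fix can be checked in isolation with a minimal sketch. The tensor values below are made up for illustration (they are not from the original script), and reduction="sum" is used here as the replacement for the deprecated size_average=False, which still emits a warning on recent PyTorch versions.

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-ins for the model output and labels in the snippet above.
out_tgt = torch.tensor([[0.7, 0.2, 0.1],
                        [0.1, 0.8, 0.1]])  # predicted class probabilities
target_label = torch.tensor([0, 1])        # ground-truth class indices

# reduction="sum" sums the per-sample losses, like size_average=False did.
loss = F.nll_loss(out_tgt.log(), target_label, reduction="sum")

test_loss = 0.0
test_loss += loss.item()  # .item() converts the 0-dim tensor to a Python float
print(test_loss)
```

Accumulating with .item() also keeps test_loss a plain Python float, so no tensor is held across batches.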