ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: expected scalar type Long but found Float
The fix is to convert label to a LongTensor (dtype torch.int64), since the loss expects integer class indices as the target.
label = label.long()  # like label.type(torch.LongTensor), but stays on label's current device
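
A minimal sketch of the failure and the fix, assuming a small classification setup. The tensor shapes, the values, and the name `log_probs` are made up for illustration and are not from the original code; the exact error text may also differ between PyTorch versions.

```python
import torch
import torch.nn.functional as F

# Hypothetical data: 4 samples, 3 classes.
log_probs = F.log_softmax(torch.randn(4, 3), dim=1)
label = torch.tensor([0.0, 2.0, 1.0, 2.0])  # float dtype, e.g. loaded from a float array

# nll_loss wants integer class indices, so a float target is rejected.
try:
    F.nll_loss(log_probs, label)
except RuntimeError as err:
    print("nll_loss failed:", err)

# Cast the target to int64 (torch.long); the tensor stays on its current device.
label = label.long()
loss = F.nll_loss(log_probs, label)
print(label.dtype, loss.item())
```

The same cast applies when the loss is `nn.CrossEntropyLoss` with class-index targets, since it also expects the target to be of dtype torch.long.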