pred = torch.max(a,1,keepdim=True)[1]
TypeError: torch.max received an invalid combination of arguments - got (torch.LongTensor, int, keepdim=bool), but expected one of:
* (torch.LongTensor source)
* (torch.LongTensor source, torch.LongTensor other)
* (torch.LongTensor source, int dim)
didn't match because some of the keywords were incorrect: keepdim
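Solution to the error above
The message means that the installed PyTorch version's torch.max does not accept the keepdim keyword. keepdim was introduced in PyTorch 0.2. Earlier releases always keep the reduced dimension as size 1, so the keyword is unnecessary there.
Change the original code
pred = output.data.max(1, keepdim=True)[1]
to
pred = torch.max(output.data, 1)[1]
On those older releases the result already has shape (N, 1), the same as keepdim=True on newer ones.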
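If the same script has to run on both old and new PyTorch releases, one option is to try the keepdim form first and fall back to the positional form when the old argument parser raises TypeError. This is a minimal sketch under that assumption. The helper name predicted_classes is hypothetical. On old releases you would pass output.data, as in the original code.

```python
import torch

def predicted_classes(output):
    """Hypothetical helper: index of the max value in each row, shape (N, 1)."""
    try:
        # Newer PyTorch (0.2+): keepdim is accepted and keeps dim 1 as size 1
        return torch.max(output, 1, keepdim=True)[1]
    except TypeError:
        # Older PyTorch rejects the keepdim keyword (the error shown above),
        # but its reductions keep the reduced dimension by default
        return torch.max(output, 1)[1]

scores = torch.tensor([[0.1, 0.9, 0.0], [2.0, 1.0, 3.0]])
pred = predicted_classes(scores)
print(pred.shape)  # torch.Size([2, 1])
```

On current releases you can also write output.argmax(dim=1, keepdim=True), which returns the same (N, 1) index tensor.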