项目场景:
手搓CNN
问题描述:
提示:这里描述项目中遇到的问题:
出错代码:
batch_loss = loss(train_pred, data[1].cuda())
data[1] 是数据的标签
出错提示:
RuntimeError: multi-target not supported at C:/cb/pytorch_1000000000000/work/aten/src\THCUNN/generic/ClassNLLCriterion.cu:15
原因分析:
提示:这里填写问题的分析:
维度不一致,
这里需要标签的一维数组
即[n]
解决方案:
提示:
改成一维数组
改正代码:
target = data[1].cuda()
target = target.squeeze()
batch_loss = loss(train_pred, target.cuda())