问题描述
—> 37 t_label = torch.zeros(32, 2).scatter_(1, t_label.data, 1)
38 t_label = Variable(t_label).cuda()
39 #t_label = Variable(t_label)
RuntimeError: index 4 is out of bounds for dimension 1 with size 2
原因分析:
数组越界。
解决方案:
可以改变数组大小。报错代码改为:
t_label = torch.zeros(32, 5).scatter_(1, t_label.data, 1)
#改zeros(32,2)为(32,5)