Same approach as in: https://www.jianshu.com/p/15b1b809074c
It failed with the following error:
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/lib/python3/dist-packages/spyderlib/widgets/externalshell/sitecustomize.py", line 699, in runfile
    execfile(filename, namespace)
  File "/usr/lib/python3/dist-packages/spyderlib/widgets/externalshell/sitecustomize.py", line 88, in execfile
    exec(compile(open(filename, 'rb').read(), filename, 'exec'), namespace)
  File "/media/muya/32512c1a-a13e-4164-99e7-8dd2e82d0358/UDAimage/RES18/train.py", line 182, in <module>
    train(args)
  File "/media/muya/32512c1a-a13e-4164-99e7-8dd2e82d0358/UDAimage/RES18/train.py", line 98, in train
    OneHotLabel = torch.zeros(args.batch_size, args.classnum).scatter_(1, label, 1)
RuntimeError: invalid argument 3: Index tensor must either be empty or have same dimensions as output tensor at /pytorch/aten/src/TH/generic/THTensorEvenMoreMath.cpp:533
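Cause: my label tensor is one-dimensional (shape [batch_size]), but the output tensor torch.zeros(args.batch_size, args.classnum) is two-dimensional. As the error message says, the index tensor passed to scatter_ must have the same number of dimensions as the output tensor.
Fix: reshape label to two dimensions, i.e. shape [batch_size, 1]: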
label = label.unsqueeze(-1)
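Below is a minimal, self-contained sketch of the fixed one-hot conversion. The batch size, class count and label values are hypothetical and chosen only for illustration. In the real script they come from args.batch_size, args.classnum and the data loader.

```python
import torch

batch_size, classnum = 4, 3           # hypothetical sizes, for illustration only
label = torch.tensor([0, 2, 1, 2])    # 1-D int64 class indices, shape [batch_size]

# scatter_ needs the index tensor to have as many dimensions as the output,
# so reshape label from [batch_size] to [batch_size, 1]
label = label.unsqueeze(-1)

# Along dim 1, row i gets a 1 in column label[i][0]; every other entry stays 0
OneHotLabel = torch.zeros(batch_size, classnum).scatter_(1, label, 1)
print(OneHotLabel)
```

The index tensor must be an integer (int64) tensor. If the labels were loaded as floats, convert them with label.long() before calling scatter_.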
Tips:
For a complete walkthrough of converting labels to one-hot, see: https://blog.csdn.net/longshaonihaoa/article/details/105640239
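Another option is torch.nn.functional.one_hot. This is not taken from the linked post, and some older PyTorch releases do not include it. It takes the 1-D label tensor directly, so no unsqueeze is needed. The label values and class count below are hypothetical.

```python
import torch
import torch.nn.functional as F

label = torch.tensor([0, 2, 1, 2])   # 1-D int64 class indices; no unsqueeze needed
classnum = 3                         # hypothetical number of classes

# F.one_hot appends a new last dimension of size num_classes and returns int64,
# so cast to float if the result is combined with float predictions in a loss
OneHotLabel = F.one_hot(label, num_classes=classnum).float()
print(OneHotLabel.shape)
```

One difference from the scatter_ approach: F.one_hot raises an error if any label is greater than or equal to num_classes. With scatter_, such a label also fails, but as an index error on the output tensor.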