最近在初学pytorch,然后在运行代码的时候出现了这个错误:
import torch as t
a = t.arange(0,16).view(4,4)
index = t.LongTensor([[0,1,2,3],[3,2,1,0]]).t()
b = a.gather(1,index)
c = t.zeros(4,4)
c.scatter_(1,index,b)
print(c)
----------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-4-b907583c1f90> in <module>
4 b = a.gather(1,index)
5 c = t.zeros(4,4)
----> 6 c.scatter_(1,index,b)
7 print(c)
RuntimeError: scatter_cpu_(): Expected self.dtype to be equal to src.dtype
后来去了下pytorch forums(那里听说氛围不错,所以我就抱着尝试的心态发了个求助帖),经过指点发现,原来这个错误是指的scatter_()函数需要实参的dtype一致,而t.arange()方法生成的tensor是torch.int64的,t.LongTensor(),t.randn(),t.zeros()生成的tensor是torch.float32的。所以只需要在这一句末尾加上.float()就可以正常运行了。