这个问题解决了一天,最后都要退学了,问了问师兄,原来是这个原因啊~
首先最初的问题是在cuda模式下运行代码出现的,报错如下:
然后这样是看不出来什么问题的,我们转换到cpu设备上运行,报错如下:
IndexError:index out of range in self。
请记住这个报错, IndexError:index out of range in self。再精确的定位到哪句话报错的,一般是torch的embedding出问题了。张量的输入超出了embedding的合法范围,应该在[0, num_embeddings -1 ]之内。否则会报错。
那么如何解决这个问题呢,首先你要找到什么时候定义嵌入层的,一般是model内部,我采用的方法是在定义嵌入层的时候将num_embeddings开大一点。就比如:
我是这里报错:输入张量src_quadkey里边的值有的大于num_embeddings,所以会报错。
#这里报错!!!!
src_quadkey_emb = self.emb_quadkey(src_quadkey)
这是我定义嵌入层的地方,n_quadkey就是上边的num_embeddings,这里我们调大一些,应该不会有啥影响吧,反正我就这么干的,你可以试试,运行一下就好了!
self.emb_loc = Embedding(n_loc, features, True, True)#地点id
self.emb_quadkey = Embedding(n_quadkey+1000, features, True, True)#quadkey的id
按理论来说,这里的只需要扩大一些num_embeddings就好了,但是我觉得根本原因在于你输入的张量为什么会比num_embeddings大,你可以试着debug一下。