torch版本1.8.0
cuda10.2
在训练过程报错
RuntimeError: CUDA error: device-side assert triggered
定位到问题是模型运行过程中所有的张量数据变成了如下形式
<torch.Tensor object at 0x7f687a1dc440>
debug找到问题出在nn.Embedding层
import torch.nn as nn
import torch
emb = nn.Embedding(10,768)
index = torch.tensor([10])
# emb(index) 在cpu上能够正常报错out of range
emb = emb.to('cuda')
index = index.to('cuda')
emb(index) # cuda上该行可以运行
print(index)# 所有张量,包括Embedding的权重变为<torch.Tensor object at 0x7f687a1dc440>