tensors used as indices must be long or byte tensors
import torch
a = torch.zeros((4, 3, 3))
b = torch.ones((4, 3))  # float32 by default, which is not a valid index dtype
print(a[b])             # raises the error above
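To reproduce the failure without crashing the script, the indexing can be wrapped in a small helper. `try_index` is a hypothetical name used only for this sketch. It catches both `IndexError` and `RuntimeError` because the exact exception type has varied between PyTorch versions, and the message wording has too (newer releases list more accepted dtypes).

```python
import torch

def try_index(t, idx):
    """Hypothetical helper: return (result, None) on success, (None, error) on failure."""
    try:
        return t[idx], None
    except (IndexError, RuntimeError) as err:
        return None, err

a = torch.zeros((4, 3, 3))
b = torch.ones((4, 3))          # float32: rejected as an index
result, err = try_index(a, b)
print(b.dtype, "->", err)       # the error message mentions "tensors used as indices"
```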
Solution:
b=torch.ones((4,3)).type(torch.uint8)
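Index tensors must be of type uint8 or int64, i.e. byte or long tensors. Newer PyTorch versions also accept torch.bool, which is the preferred dtype for masks (uint8 masks trigger a deprecation warning). Note that the two dtypes mean different things: a uint8/bool tensor is treated as a mask, while an int64 tensor is treated as a list of indices, so the result shape differs.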
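A minimal sketch of the difference between mask indexing and integer indexing on the same tensor. It uses `torch.bool` rather than `torch.uint8` for the mask, on the assumption that a recent PyTorch is installed; a uint8 mask selects the same elements.

```python
import torch

a = torch.zeros((4, 3, 3))

# Mask indexing: the (4, 3) mask covers a's first two dims. Every True
# entry selects one length-3 row, so all 12 rows are returned.
mask = torch.ones((4, 3), dtype=torch.bool)
masked = a[mask]
print(masked.shape)   # torch.Size([12, 3])

# Integer indexing: each value 1 picks a[1] (shape (3, 3)), so the
# result shape is idx.shape + a.shape[1:].
idx = torch.ones((4, 3), dtype=torch.long)
picked = a[idx]
print(picked.shape)   # torch.Size([4, 3, 3, 3])
```

So casting to uint8 (or bool) and casting to long both silence the error, but only the mask version keeps the "select where b is nonzero" meaning of the original snippet.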