# Tensor indexing only supports index tensors of dtype uint8 or int64,
# i.e. byte or long tensors.
import torch
import numpy as np
# Index a (4, 3, 3) tensor with a (4, 3) uint8 tensor of ones. The index
# covers the first two dimensions, so each selected entry is a length-3 row.
a = torch.zeros(4, 3, 3)
b = torch.ones(4, 3, dtype=torch.uint8)
print(a[b])
# The result has shape (12, 3).
# import numpy as np
#
# a = np.zeros((80,10, 3))
# # a=np.array([1,2,3])
# b = np.zeros((80,10),dtype=np.int32)
#
# print(a[b])
import torch
import numpy as np
# Index a 2-D tensor with a same-shaped uint8 tensor of ones, then overwrite
# entries in place through a comparison mask.
a = torch.zeros(4, 3)
# This NumPy index array is replaced by the uint8 tensor on the next statement.
b = np.asarray([[1, 2], [1, 2], [0, 1]])
b = torch.ones(4, 3, dtype=torch.uint8)
# b+=1
c = [1, 2]
# The index has the same shape as `a`, so the selection is a flat 1-D tensor.
print(a[b])
# Every entry that is <= 0 is set to 1, modifying `a` in place.
a.masked_fill_(a <= 0, 1)
print(a)