这个是三维的:
import numpy as np
import torch
x = torch.linspace(1,27,steps=27).view(3,3,3)
y= torch.ones(2,2)#.numpy()#.astype(np.uint8)#.type(torch.uint8)
#[第一维][第二维][第三维]
y=np.asarray([[0,1,2],[2,0,1],[1,0,2]])
# y[0,1]=0
x[y]=0
print(x)
import torch
a=torch.linspace(1,8,steps=8).view(2,2,2)
#这个是正确的:
b=torch.linspace(1,4,steps=4).view(2,2).view(4,1)
c=torch.linspace(0,1,steps=2).repeat(2).view(4,1)
d=torch.cat((b,c),1).view(2,2,2).type(torch.uint8)
# print(a)
print(d)
# d[0,0,0]=1
a[d]=0
print(a)
d维度相同,则相同位置,如果>0则会选中,否则,不会选中
tensor([[[1, 0],
[2, 1]],
[[3, 0],
[4, 1]]], dtype=torch.uint8)
tensor([[[0., 2.],
[0., 0.]],
[[0., 6.],