torch 删除tensor全0列
在pythorch中并没有直接的包可以删除元素,所以我们需要自己写
def del_tensor_0_cloumn(Cs):
idx = torch.where(torch.all(Cs[..., :] == 0, axis=0))[0]
all = torch.arange(Cs.shape[1])
for i in range(len(idx)):
all = all[torch.arange(all.size(0))!=idx[i]-i]
Cs = torch.index_select(Cs, 1, all)
return Cs
上述代码就是删除tensor的全0列