# --- Tensor.repeat --------------------------------------------------------
# `Tensor.repeat` materialises a real copy of the data: every repeated row
# owns its own storage, so rows can be modified in place independently.
# (Section heading is a comment: a bare `torch.repeat` expression here would
# execute before `import torch` and crash the script.)
import torch

# 2x3 grid; its zero entries select which (dim0, dim1) rows of q_expand
# get zeroed below.
v_sum = torch.tensor([[1, 3, 0], [1, 0, 0]])
print(v_sum)

# (row, col) coordinates of every zero in v_sum, shape (num_zeros, 2).
mask_index = torch.nonzero(v_sum == 0)
print(mask_index)

q = torch.rand([2, 1, 6])
print(q)

# (2, 1, 6) -> (2, 3, 6): copy dim 1 three times so it lines up with the
# three columns of v_sum. This is an independent copy, not a view of q.
q_expand = q.repeat(1, 3, 1)
print(q_expand)

# Zero the whole 6-element row at each (row, col) where v_sum == 0.
q_expand[mask_index[:, 0], mask_index[:, 1]] = 0
print(q_expand)
# --- Tensor.expand --------------------------------------------------------
# `Tensor.expand` returns a *view*: the new size-3 dim has stride 0, so all
# three rows of a given dim-0 index alias the same 6 elements of `q`.
# PyTorch rejects in-place writes into such self-overlapping memory
# ("more than one element of the written-to tensor refers to a single
# memory location"), so the view is clone()d before it is modified.
# (Section heading is a comment: `expand` is a Tensor method, not a
# `torch.*` function, so a bare `torch.expand` expression is not valid.)
import torch

# 2x3 grid; its zero entries select which (dim0, dim1) rows of q_expand
# get zeroed below.
v_sum = torch.tensor([[1, 3, 0], [1, 1, 1]])
print(v_sum)

# (row, col) coordinates of every zero in v_sum, shape (num_zeros, 2).
mask_index = torch.nonzero(v_sum == 0)
print(mask_index)

q = torch.rand([2, 1, 6])
print(q)

# (2, 1, 6) -> (2, 3, 6) as a zero-copy view, then clone() into real,
# non-aliased storage. Without the clone, the write below would either
# raise or, through the aliasing, zero every row of that index and
# modify `q` itself.
q_expand = q.expand(-1, 3, -1).clone()
print(q_expand)

# Zero the whole 6-element row at each (row, col) where v_sum == 0.
q_expand[mask_index[:, 0], mask_index[:, 1]] = 0
print(q_expand)