# Example 1: torch.repeat
import torch

# Masking demo built with Tensor.repeat(): repeat() tiles real copies of the
# data into a new tensor, so the result can be written to in place without
# touching q.
v_sum = torch.tensor([[1, 3, 0], [1, 0, 0]])
print(v_sum)

# One (row, col) pair per zero entry of v_sum: shape (num_zeros, 2).
mask_index = (v_sum == 0).nonzero()
print(mask_index)

q = torch.rand([2, 1, 6])
print(q)

# Tile dim 1 three times: (2, 1, 6) -> (2, 3, 6), a fresh copy of q's data.
q_expand = q.repeat(1, 3, 1)
print(q_expand)

# Split the coordinate pairs into a row-index tensor and a column-index tensor,
# then zero the whole last-dim vector at each (row, col) position.
rows, cols = mask_index.unbind(dim=1)
q_expand[rows, cols] = 0
print(q_expand)
![Output of the torch.repeat example](https://i-blog.csdnimg.cn/blog_migrate/51328b575e6493ef16e08f456c9d5343.png)
# Example 2: torch.expand
import torch

# Masking demo built with Tensor.expand(). expand() does not copy: it returns
# a view of q in which the size-1 dim 1 is broadcast to size 3 with stride 0,
# so q_expand[b, 0], q_expand[b, 1] and q_expand[b, 2] all alias the single
# row q[b, 0] in memory.
v_sum = torch.tensor([[1, 3, 0], [1, 1, 1]])
print(v_sum)

# One (row, col) pair per zero entry of v_sum; here only (0, 2).
mask_index = torch.nonzero(v_sum == 0)
print(mask_index)

q = torch.rand([2, 1, 6])
print(q)

# (2, 1, 6) -> (2, 3, 6) as a view over q's storage (no data copied).
q_expand = q.expand(-1, 3, -1)
print(q_expand)

# Writing into the expanded view in place is unsafe. A single write would hit
# all three aliased rows and also mutate q. Current PyTorch rejects it with
# "RuntimeError: unsupported operation: more than one element of the
# written-to tensor refers to a single memory location". As the expand() docs
# advise, clone() first to get a real copy with its own storage.
q_expand = q_expand.clone()
q_expand[mask_index[:, 0], mask_index[:, 1]] = 0
print(q_expand)
![Output of the torch.expand example](https://i-blog.csdnimg.cn/blog_migrate/8ffb5a24f662f5f38925a4b8963b61b4.png)