import torch mask1 = torch.tensor([[0, 1, 0], [1, 1, 0], [0, 0, 1]]) mask2 = torch.tensor([[1, 0, 0], [0, 1, 1], [1, 0, 1]]) mask3 = torch.tensor([[0, 0, 1], [1, 1, 0], [0, 1, 0]]) # 将掩码堆叠到一个新的张量中,dim=0 表示在第一个维度上堆叠 stacked_masks = torch.stack([mask1, mask2, mask3], dim=1) print(stacked_masks) print(stacked_masks.size())
dim=0时:
简单的把三个tensor叠加
dim =1时:
取每个tensor的第一行组成第一个tensor,以此类推
dim=2时:
每个tensor进行转置