expand()函数可以将张量广播到新的形状,但是切记以下两点:
- 只能对维度值为1的维度进行扩展,且扩展的Tensor不会分配新的内存,只是原来的基础上创建新的视图并返回;
- 无需扩展的维度请保持维度值不变。
torch中的unsqueeze()
函数来增加一个维度,expand()
函数以行或列来广播。
# -*- encoding: utf-8 -*-
import torch
# 需求是对一个batch_size=2, seq_len=3的两个序列进行mask的扩展,
# 扩展为[batch_size, seq_len, 4, seq_len]
tokens = torch.tensor([[1,2, 3],[2,1,0]])
mask = tokens!=0
print(mask)
print(mask.shape)
print(mask.unsqueeze(2).shape)
print(mask.unsqueeze(2))
print(mask.unsqueeze(1).shape)
print(mask.unsqueeze(1))
multi = mask.unsqueeze(2)*mask.unsqueeze(1)
print('multi shape:',multi.shape) # [batch_size, seq, seq]
print(multi)
select = multi.unsqueeze(2)
print(select.shape) # batch, seq, 1, seq
print(select)
print(select.expand(-1,-1, 4, -1)) # expand的作用是把某个维度上为1的扩展为指定的个数
expand()
在行或列上的扩展
b shape: torch.Size([3, 1])
bb shape: torch.Size([3, 3])
tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3]])
c shape: torch.Size([1, 3])
cc shape: torch.Size([3, 3])
tensor([[1, 2, 3],
[1, 2, 3],
[1, 2, 3]])
result:
b shape: torch.Size([3, 1])
bb shape: torch.Size([3, 3])
tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3]])
c shape: torch.Size([1, 3])
cc shape: torch.Size([3, 3])
tensor([[1, 2, 3],
[1, 2, 3],
[1, 2, 3]])