pytorch——unsqueeze与expand

expand()函数可以将张量广播到新的形状,但是切记以下两点:

  1. 只能对维度值为1的维度进行扩展,且扩展的Tensor不会分配新的内存,只是原来的基础上创建新的视图并返回;
  2. 无需扩展的维度请保持维度值不变。在这里插入图片描述

torch中的unsqueeze()函数来增加一个维度,expand()函数以行或列来广播。

# -*- encoding: utf-8 -*-
import torch

# 需求是对一个batch_size=2, seq_len=3的两个序列进行mask的扩展,
# 扩展为[batch_size, seq_len, 4, seq_len]
tokens = torch.tensor([[1,2, 3],[2,1,0]])
mask = tokens!=0
print(mask)
print(mask.shape)

print(mask.unsqueeze(2).shape)
print(mask.unsqueeze(2))
print(mask.unsqueeze(1).shape)
print(mask.unsqueeze(1))

multi = mask.unsqueeze(2)*mask.unsqueeze(1)
print('multi shape:',multi.shape) # [batch_size, seq, seq]
print(multi)

select = multi.unsqueeze(2)
print(select.shape) # batch, seq, 1, seq
print(select)
print(select.expand(-1,-1, 4, -1)) # expand的作用是把某个维度上为1的扩展为指定的个数
  • expand()在行或列上的扩展
b shape: torch.Size([3, 1])
bb shape: torch.Size([3, 3])
tensor([[1, 1, 1],
        [2, 2, 2],
        [3, 3, 3]])
c shape: torch.Size([1, 3])
cc shape: torch.Size([3, 3])
tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3]])

result:

b shape: torch.Size([3, 1])
bb shape: torch.Size([3, 3])
tensor([[1, 1, 1],
        [2, 2, 2],
        [3, 3, 3]])
c shape: torch.Size([1, 3])
cc shape: torch.Size([3, 3])
tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3]])
  • 2
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值