从零开始学习pytorch之张量的操作:拼接、切分、索引和变换

张量的拼接

torch.cat(tensors, dim=0, out=None)

功能:将张量按维度dim进行拼接
·tensors:张量序列
·dim:要拼接的维度

import torch
t = torch.ones((2,3))
t_0 = torch.cat([t,t], dim=0)
t_1 = torch.cat([t,t], dim=1)
print('t_0:{} shape:{}\nt_1:{} shape:{}'.format(t_0,t_0.shape,t_1,t_1.shape))

在这里插入图片描述

torch.stack(tensors, dim=0, out=None

功能:在新创建的维度dim上进行拼接
·tensors:张量序列
·dim:要拼接的维度

t = torch.ones((2,3))
t_stack1 = torch.stack([t,t], dim=0)
t_stack2 = torch.stack([t,t], dim=2)
print('t_0:{} shape:{}\nt_1:{} shape:{}'.format(t_stack1,t_stack1.shape,t_stack2,t_stack2.shape))

在这里插入图片描述

张量的切分

torch.chunk(input, chunks, dim=0)

功能:将张量按维度dim进行平均切分
返回值:张量列表
注意事项:若不能整除,最后一份张量小于其他张量
·input:要切分的张量
·chunks:要切分的份数
·dim:要切分的维度

t = torch.ones((2,5))
list_of_tensors = torch.chunk(t, dim=1, chunks=2)
for idx, mat in enumerate(list_of_tensors):
    print('第{}个张量:{}, 维度为{}'.format(idx+1,mat,mat.shape))

在这里插入图片描述

torch.split(tensor, split_size_or_sections, dim=0)

功能:将张量安慰度dim进行切分
返回值:张量列表
·tensor:要切分的张量
·split_size_or_sections:为int时,表示每一份的长度;为list时,按list元素切分
·dim:要切分的维度

t = torch.ones((2,5))
list_of_tensors = torch.split(t, [2,1,2], dim=1)
for idx, mat in enumerate(list_of_tensors):
    print('第{}个张量:{}, 维度为{}'.format(idx+1,mat,mat.shape))

在这里插入图片描述

张量索引

torch.index_select(input,dim=0,index=None)

功能:在维度dim上,按index索引数据
返回值:依index索引数据拼接的张量
·input:要索引的张量
·dim:要索引的维度
·index:要索引数据的序号

t = torch.randint(0,9,size=(3,3))
idx = torch.tensor([0,2], dtype=torch.long) #float
t_select = torch.index_select(t, dim=0, index=idx)
print('{}\n{}'.format(t, t_select))

在这里插入图片描述

torch.masked_select(input, mask, out=None)

功能:按mask中的True进行索引
返回值:一维张量
·input:要索引的张量
·mask:与input同形状的布尔类型张量

t = torch.randint(0,9,size=(3,3))
#返回大小为t的矩阵,其中大于等于5的元素为True,小于5的为False
mask = t.ge(5) 
t_select = torch.masked_select(t, mask)
print('t:\n{}\nmask:\n{}\nt_select:\n{}'.format(t,mask,t_select))

在这里插入图片描述

张量变换

torch.reshape(input, shape)

功能:变换张量形状
注意事项:当张量在内存中是连续时,新张量与input共享数据内存
·input:要变换的张量
·shape:新张量的形状

t = torch.randperm(8)
t_reshape = torch.reshape(t, (-1,2,2))
print('t:\n{}\nt_reshape:\n{}'.format(t, t_reshape))
print('t内存地址{}'.format(id(t.data)))
print('t_reshape内存地址{}'.format(id(t_reshape.data)))

在这里插入图片描述

torch.transpose(input, dim0, dim1)

功能:交换张量的两个维度
·input:要变换的张量
·dim0:要变换的维度
·dim1:要变换的维度

t = torch.rand((2,3,4))
t_transpose = torch.transpose(t, dim0=1, dim1=2)
print('t shape:{} t_transpose shape:{}'.format(t.shape, t_transpose.shape))

在这里插入图片描述

torch.t(input)

功能:2维张量转置,对矩阵而言,等价于torch.transpsoe(input,0,1)

torch.squeeze(input, dim=None, out=None)

功能:压缩长度为1的维度(轴)
·dim:若为None,移除所有长度为1的轴;若指定维度,当且仅当该轴长度为1时,可以被移除;

t = torch.rand((1,2,3,1))
t_sq = torch.squeeze(t)
t_0 = torch.squeeze(t, dim=0)
t_1 = torch.squeeze(t, dim=1)
print(t.shape)
print(t_sq.shape)
print(t_0.shape)
print(t_1.shape)#第二个维度是2故无法压缩掉

在这里插入图片描述

torch.usqueeze(input, dim, out=None)

功能:依据dim扩展维度
·dim:扩展的维度

t = torch.rand((1,2,3))
t_sq = torch.unsqueeze(t,dim=3)
print(t.shape)
print(t_sq.shape)

在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值