张量的操作:拼接、切分、索引和变换

张量的操作:拼接、切分、索引和变换

1 张量的操作:拼接、切分、索引和变换
一 张量的拼接和切分
在这里插入图片描述
1.1 torch.cat()
功能:将张量按维度dim进行拼接
tensor:张量序列
dim:拼接维度

t=torch.ones((2,3))
torch.cat([t,t],dim=0)
torch.cat([t,t],dim=1)
torch.cat([t,t,t],dim=1)

在这里插入图片描述
1.2 torch.stack()
功能:在新创建的维度dim上进行拼接
tensor:张量序列
dim:要拼接的维度

在这里插入图片描述
与cat相比,stack创建在了一个新维度

1.3 torch.chunk()
功能:将张量按维度dim进行平均切分
返回值:张量列表
注意:若不能整除,最后一份张量小于其他张量
在这里插入图片描述
input:要切分的张量
chunks:要切分的份数
dim:要切分的维度

a = torch.ones((2, 7))  # 7
list_of_tensors = torch.chunk(a, dim=1, chunks=3)   # 3

for idx, t in enumerate(list_of_tensors):
        print("第{}个张量:{}, shape is {}".format(idx+1, t, t.shape))

在这里插入图片描述
1.4 torch.split()
在这里插入图片描述
功能:将张量按维度dim进行切分
返回值:张量列表
split_size_or_sections:为int时,表示每一份的长度;为list时,按list元素切分

    t = torch.ones((2, 5))

    list_of_tensors = torch.split(t, 2], dim=1)  # [2 , 1, 2]
    for idx, t in enumerate(list_of_tensors):
        print("第{}个张量:{}, shape is {}".format(idx+1, t, t.shape))

在这里插入图片描述

    t = torch.ones((2, 5))

    list_of_tensors = torch.split(t, [2, 1, 1], dim=1)  # [2 , 1, 2]
    for idx, t in enumerate(list_of_tensors):
        print("第{}个张量:{}, shape is {}".format(idx+1, t, t.shape))

在这里插入图片描述
list内元素之和等于维度上的长度

二张量的索引
2.1 torch.index_select()
功能:在维度dim上,按index索引数据
返回值:依index索引数据拼接的张量
在这里插入图片描述
input:要索引的张量
dim:要索引的维度
index:要索引数据的序号

    t = torch.randint(0, 9, size=(3, 3))
    idx = torch.tensor([0, 2], dtype=torch.long)    # float
    t_select = torch.index_select(t, dim=0, index=idx)
    print("t:\n{}\nt_select:\n{}".format(t, t_select))

在这里插入图片描述
2.2 torch.masked_select()
功能:按mmask中的True进行索引
返回值:一维张量
input:要索引的张量
mask:与input同形状的布尔类型张量
在这里插入图片描述

    t = torch.randint(0, 9, size=(3, 3))
    mask = t.le(5)  # ge is mean greater than or equal/   gt: greater than  le  lt
    t_select = torch.masked_select(t, mask)
    print("t:\n{}\nmask:\n{}\nt_select:\n{} ".format(t, mask, t_select))

在这里插入图片描述
三 张量变换
torch.reshape()
功能:变换张量的形状
注意:当张量在内存中是连续时,新张量与input共享内存
input:要变换的张量
shape:新张量的形状
在这里插入图片描述

    t = torch.randperm(8)
    t_reshape = torch.reshape(t, (-1, 2, 2))    # -1
    print("t:{}\nt_reshape:\n{}".format(t, t_reshape))

在这里插入图片描述
共享内存
在这里插入图片描述
3.2 torch.transpose().
功能:交换维度
在这里插入图片描述
3.3 torch.t()
功能:2维张量转置
在这里插入图片描述
3.4 torch.squeeze()
功能:压缩长度为1的维度(轴)
dim:若为None,移除所有长度为1的轴;若指定维度,当且仅当该轴长度为1时,课被移除

3.5 torch.unsqueeze()
功能:依据dim扩展维度
在这里插入图片描述

    t = torch.rand((1, 2, 3, 1))
    t_sq = torch.squeeze(t)
    t_0 = torch.squeeze(t, dim=0)
    t_1 = torch.squeeze(t, dim=1)
    print(t.shape)
    print(t_sq.shape)
    print(t_0.shape)
    print(t_1.shape)

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值