深度学习张量拼接,索引操作

torch.cat 函数的使用

torch.cat 函数可以将两个张量根据指定的维度拼接起来.

import torch


def test():

    data1 = torch.randint(0, 10, [3, 5, 4])
    data2 = torch.randint(0, 10, [3, 5, 4])

    print(data1)
    print(data2)
    print('-' * 50)

    # 1. 按0维度拼接
    new_data = torch.cat([data1, data2], dim=0)
    print(new_data.shape)
    print('-' * 50)

    # 2. 按1维度拼接
    new_data = torch.cat([data1, data2], dim=1)
    print(new_data.shape)

    # 3. 按2维度拼接
    new_data = torch.cat([data1, data2], dim=2)
    print(new_data)


if __name__ == '__main__':
    test()

输出结果:

tensor([[[6, 8, 3, 5],
         [1, 1, 3, 8],
         [9, 0, 4, 4],
         [1, 4, 7, 0],
         [5, 1, 4, 8]],

        [[0, 1, 4, 4],
         [4, 1, 8, 7],
         [5, 2, 6, 6],
         [2, 6, 1, 6],
         [0, 7, 8, 9]],

        [[0, 6, 8, 8],
         [5, 4, 5, 8],
         [3, 5, 5, 9],
         [3, 5, 2, 4],
         [3, 8, 1, 1]]])
tensor([[[4, 6, 8, 1],
         [0, 1, 8, 2],
         [4, 9, 9, 8],
         [5, 1, 5, 9],
         [9, 4, 3, 0]],

        [[7, 6, 3, 3],
         [4, 3, 3, 2],
         [2, 1, 1, 1],
         [3, 0, 8, 2],
         [8, 6, 6, 5]],

        [[0, 7, 2, 4],
         [4, 3, 8, 3],
         [4, 2, 1, 9],
         [4, 2, 8, 9],
         [3, 7, 0, 8]]])
--------------------------------------------------
torch.Size([6, 5, 4])
--------------------------------------------------
torch.Size([3, 10, 4])
tensor([[[6, 8, 3, 5, 4, 6, 8, 1],
         [1, 1, 3, 8, 0, 1, 8, 2],
         [9, 0, 4, 4, 4, 9, 9, 8],
         [1, 4, 7, 0, 5, 1, 5, 9],
         [5, 1, 4, 8, 9, 4, 3, 0]],

        [[0, 1, 4, 4, 7, 6, 3, 3],
         [4, 1, 8, 7, 4, 3, 3, 2],
         [5, 2, 6, 6, 2, 1, 1, 1],
         [2, 6, 1, 6, 3, 0, 8, 2],
         [0, 7, 8, 9, 8, 6, 6, 5]],

        [[0, 6, 8, 8, 0, 7, 2, 4],
         [5, 4, 5, 8, 4, 3, 8, 3],
         [3, 5, 5, 9, 4, 2, 1, 9],
         [3, 5, 2, 4, 4, 2, 8, 9],
         [3, 8, 1, 1, 3, 7, 0, 8]]])

 torch.stack 函数的使用

torch.stack 函数可以将两个张量根据指定的维度叠加起来.

import torch


def test():

    data1= torch.randint(0, 10, [2, 3])
    data2= torch.randint(0, 10, [2, 3])
    print(data1)
    print(data2)

    new_data = torch.stack([data1, data2], dim=0)
    print(new_data.shape)

    new_data = torch.stack([data1, data2], dim=1)
    print(new_data.shape)

    new_data = torch.stack([data1, data2], dim=2)
    print(new_data)


if __name__ == '__main__':
    test()

输出结果:

tensor([[5, 8, 7],
        [6, 0, 6]])
tensor([[5, 8, 0],
        [9, 0, 1]])
torch.Size([2, 2, 3])
torch.Size([2, 2, 3])
tensor([[[5, 5],
         [8, 8],
         [7, 0]],

        [[6, 9],
         [0, 0],
         [6, 1]]])

简单行、列索引

准备数据

import torch

data = torch.randint(0, 10, [4, 5])
print(data)
print('-' * 50)

输出结果:

tensor([[0, 7, 6, 5, 9],
        [6, 8, 3, 1, 0],
        [6, 3, 8, 7, 3],
        [4, 9, 5, 3, 1]])
# 1. 简单行、列索引
def test01():

    print(data[0])
    print(data[:, 0])
    print('-' * 50)

if __name__ == '__main__':
    test01()

输出结果:

tensor([0, 7, 6, 5, 9])
tensor([0, 6, 6, 4])

列表索引

# 2. 列表索引
def test02():

    # 返回 (0, 1)、(1, 2) 两个位置的元素
    print(data[[0, 1], [1, 2]])
    print('-' * 50)

    # 返回 0、1 行的 1、2 列共4个元素
    print(data[[[0], [1]], [1, 2]])
if __name__ == '__main__':
    test02()

输出结果:

tensor([7, 3])
--------------------------------------------------
tensor([[7, 6],
        [8, 3]])

范围索引

# 3. 范围索引
def test03():
    # 前3行的前2列数据
    print(data[:3, :2])
    # 第2行到最后的前2列数据
    print(data[2:, :2])
if __name__ == '__main__':
    test03()

输出结果:

tensor([[0, 7],
        [6, 8],
        [6, 3]])
tensor([[6, 3],
        [4, 9]])

布尔索引

# 布尔索引
def test():

    # 第三列大于5的行数据
    print(data[data[:, 2] > 5])
    # 第二行大于5的列数据
    print(data[:, data[1] > 5])
if __name__ == '__main__':
    test04()

输出结果:

tensor([[0, 7, 6, 5, 9],
        [6, 3, 8, 7, 3]])
tensor([[0, 7],
        [6, 8],
        [6, 3],
        [4, 9]])

多维索引

# 多维索引
def test05():

    data = torch.randint(0, 10, [3, 4, 5])
    print(data)
    print('-' * 50)

    print(data[0, :, :])
    print(data[:, 0, :])
    print(data[:, :, 0])


if __name__ == '__main__':
    test05()

输出结果:

tensor([[[2, 4, 1, 2, 3],
         [5, 5, 1, 5, 0],
         [1, 4, 5, 3, 8],
         [7, 1, 1, 9, 9]],

        [[9, 7, 5, 3, 1],
         [8, 8, 6, 0, 1],
         [6, 9, 0, 2, 1],
         [9, 7, 0, 4, 0]],

        [[0, 7, 3, 5, 6],
         [2, 4, 6, 4, 3],
         [2, 0, 3, 7, 9],
         [9, 6, 4, 4, 4]]])
--------------------------------------------------
tensor([[2, 4, 1, 2, 3],
        [5, 5, 1, 5, 0],
        [1, 4, 5, 3, 8],
        [7, 1, 1, 9, 9]])
tensor([[2, 4, 1, 2, 3],
        [9, 7, 5, 3, 1],
        [0, 7, 3, 5, 6]])
tensor([[2, 5, 1, 7],
        [9, 8, 6, 9],
        [0, 2, 2, 9]])

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值