Pytorch学习-2

张量拼接:

在PyTorch中,张量拼接主要有两种方法:torch.cat()torch.stack()。这两种方法在功能上有所不同,但都用于将多个张量合并为一个张量。

1. torch.cat()

torch.cat(tensors, dim=0) 方法用于在给定维度上连接张量序列。这意味着所有被拼接的张量的形状(除了拼接的维度)必须相同。

  • 参数

    • tensors:一个张量序列。
    • dim:要拼接的维度。
  • 示例

    # 沿着行(维度0)拼接两个张量
    x = torch.tensor([[1, 2], [3, 4]])
    y = torch.tensor([[5, 6], [7, 8]])
    result = torch.cat([x, y], dim=0)
    print(result)
    # 输出:
    # tensor([[1, 2],
    #         [3, 4],
    #         [5, 6],
    #         [7, 8]])
    

    2. torch.stack()

    torch.stack(tensors, dim=0) 方法用于在新创建的维度上堆叠张量序列。这意味着所有被堆叠的张量必须具有完全相同的形状,torch.stack 添加的新维度将成为张量的一部分。

  • 参数

    • tensors:一个张量序列。
    • dim:新维度插入的位置。
  • 示例

    # 在新创建的维度上堆叠两个张量
    x = torch.tensor([1, 2])
    y = torch.tensor([3, 4])
    result = torch.stack([x, y], dim=0)
    print(result)
    # 输出:
    # tensor([[1, 2],
    #         [3, 4]])
    

    torch.cat 中,张量是沿着已存在的某个维度进行连接的,而 torch.stack 则是在张量序列之间添加一个新的维度进行堆叠,因此堆叠后的张量维度会增加。

    张量切分

在 PyTorch 中,张量切分主要通过两种方法实现:torch.chunk()torch.split()。这两种方法都将一个张量分割成多个较小的张量,但它们在如何定义分割方式上有所不同。

1. torch.chunk()

torch.chunk(input, chunks, dim=0) 方法将张量分割成指定数量的块。如果原始张量在指定维度上不能均等分割成目标数量的块,则最后一个块会小于其他块。

  • 参数

    • input:要分割的张量。
    • chunks:希望得到的块的数量。
    • dim:要分割的维度。
  • 示例

    x = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
    chunks = torch.chunk(x, 2, dim=0)
    for i, chunk in enumerate(chunks):
        print(f"Chunk {i}:\n {chunk}")
    # 输出:
    # Chunk 0:
    #  tensor([[1, 2],
    #         [3, 4]])
    # Chunk 1:
    #  tensor([[5, 6],
    #         [7, 8]])
    

    2. torch.split()

    torch.split(tensor, split_size_or_sections, dim=0) 方法根据给定的大小分割张量。split_size_or_sections 可以是单个整数,指定每个分割块的大小,或者是一个整数列表,指定每个块的具体大小。

  • 参数

    • tensor:要分割的张量。
    • split_size_or_sections:单个整数时,每个块的大小;整数列表时,每个块的具体大小。
    • dim:要分割的维度。
  • 示例(使用单个大小值):

    x = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
    splits = torch.split(x, 2, dim=0)
    for i, split in enumerate(splits):
        print(f"Split {i}:\n {split}")
    # 输出:
    # Split 0:
    #  tensor([[1, 2],
    #         [3, 4]])
    # Split 1:
    #  tensor([[5, 6],
    #         [7, 8]])
    

    示例(使用大小列表):

    x = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
    splits = torch.split(x, [2, 1, 2], dim=0)
    for i, split in enumerate(splits):
        print(f"Split {i}:\n {split}")
    # 输出:
    # Split 0:
    #  tensor([[1, 2],
    #         [3, 4]])
    # Split 1:
    #  tensor([[5, 6]])
    # Split 2:
    #  tensor([[7, 8],
    #         [9, 10]])
    

    选择 torch.chunk() 还是 torch.split() 取决于你的具体需求:如果你需要将张量分割成大小大致相等的几部分,使用 torch.chunk();如果你需要精确控制每个分割块的大小,使用 torch.split()

 张量索引

在PyTorch中,张量索引允许你选择、修改、提取数据子集的功能非常强大且灵活。下面详细介绍张量索引的主要方法:

1. 基本索引和切片

与Python列表和NumPy数组相似,PyTorch张量也支持基本的索引和切片操作。

  • 示例
    x = torch.arange(10)  # 创建一个包含0到9的一维张量
    print(x[2:5])  # 从位置2到4的元素
    print(x[:5])  # 开始到位置4的元素
    print(x[5:])  # 位置5到结束的元素
    print(x[-1])  # 最后一个元素
    

    2. 高级索引

    高级索引允许你使用张量来选择数据。

  • 整数数组索引:可以使用整数数组(列表或张量)进行索引,选择不连续的张量元素。

    y = torch.arange(12).view(3, 4)  # 创建一个形状为3x4的二维张量
    print(y[[0, 2], [2, 3]])  # 选择(0,2)和(2,3)位置上的元素
    

    布尔索引:使用布尔表达式可以选择符合特定条件的元素。

    print(y[y > 5])  # 选择所有大于5的元素

3. 选择操作(Fancy Indexing)

Fancy Indexing允许你使用整数列表进行索引。

  • 示例
    print(y[:, [1, 3]])  # 选择所有行的第1和第3列
    

4. 掩码操作

掩码操作允许你根据条件选择元素,这通常用于赋值或者提取子集。

mask = y > 5  # 创建一个布尔掩码
print(y[mask])  # 使用布尔掩码选择元素

5. 维度变换

PyTorch提供了许多操作来改变张量的形状,这虽然不严格是索引,但在处理张量时非常有用。

view:返回一个新的张量,具有相同的数据但大小不同。

z = torch.arange(10)
print(z.view(2, 5))  # 重塑为2x5的张量

unsqueeze:在指定位置添加一个维度。

print(z.unsqueeze(0))  # 在第0维添加一个维度

squeeze:移除所有长度为1的维度。

print(z.unsqueeze(0).squeeze(0))  # 移除第0维的单维度

6. 转置操作

转置可以改变张量的形状和维度顺序。

  • t:对二维张量进行转置。

    a = torch.arange(6).view(2, 3)
    print(a.t())
  • permute:对多维张量进行维度交换。

    b = torch.arange(24).view(2, 3, 4)
    print(b.permute(2, 0, 1))  # 交换维度

  • 24
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值