pytorch常用方法详解

本文详细介绍了PyTorch库中几个关键的张量操作,包括bmm(批量矩阵乘法)、gather(按索引选取元素)、masked_fill(替换特定值)、matmul(多维矩阵乘法)、triu和tril(创建上/下三角矩阵),以及view(张量形状转换)。这些操作对于理解深度学习和数值计算至关重要。
摘要由CSDN通过智能技术生成

bmm

bmm 函数(批量矩阵乘法)用于 3 维张量的矩阵乘法。它的输入是两个形状为 (batch_size, n, m) 和 (batch_size, m, p) 的张量,输出是一个形状为 (batch_size, n, p) 的张量。

import torch
#A.shape=(2,2,3)
A = torch.tensor([[[1, 2, 3],
                  [4, 5, 6]],

                 [[7, 8, 9],
                  [10, 11, 12]]])
#B.shape=(2,3,2)                  
B = torch.tensor([[[1, 2],
                   [3, 4],
                   [5, 6],],

                  [[7, 8],
                   [9, 10],
                   [11, 12]]])
#C.shape=(2,2,2) 
C = torch.bmm(A, B)
'''
C:
tensor([[[ 22,  28],
         [ 49,  64]],

        [[220, 244],
         [301, 334]]])
'''

gather

详见:PyTorch中torch.gather()函数直观理解及结果速算

masked_fill

import torch
attn = torch.tensor([[0.1, 0.2, 0.3],
                     [0.4, 0.5, 0.6],
                     [0.7, 0.8, 0.9]])
mask = torch.tensor([[1, 1, 0],
                     [1, 1, 1],
                     [0, 1, 1]])
'''
找到张量mask中值等于0的位置,将张量attn中同样位置的值替换成 -1e9   
tips: attn 和 mask  的形状必须一致 
'''            
attn = attn.masked_fill(mask == 0, -1e9)
print(attn)

'''
结果为:
torch.tensor([[0.1, 0.2, -1e9],
              [0.4, 0.5, 0.6],
              [-1e9, 0.8, 0.9]])
'''

matmul

matmul(矩阵乘法)用于多维(≥2的任意维度)张量的乘法,而mm 仅适用于二维张量,bmm仅适用于三维张量。在运算时mamul会将参与计算的两个张量的后两个维度当做矩阵的形状,其余维度当做批量维度,进行具体的计算。

mm

mm 函数(矩阵乘法)用于 2 维张量的乘法。它的输入是两个形状为 (m, n) 和 (n, p) 的张量,输出是一个形状为 (m, p) 的张量。

triu、tril

使用 torch.triu 和 torch.tril 函数来创建上三角和下三角矩阵。这两个函数都接受一个二维张量作为输入,并返回一个保留原始张量中指定部分的上三角或下三角矩阵。
torch.triu 和 torch.tril 函数都接受一个可选参数 diagonal,用于指定要保留的三角矩阵的 diagonal 位置。默认情况下,diagonal 为 0,这意味着矩阵的主对角线上的元素将被保留。如果要将矩阵的对角线元素也设置为 0,可以将 diagonal 设置为相应的值。例如,torch.triu(t, diagonal=1) 将返回一个上三角矩阵,其中主对角线上的元素也被设置为 0。

import torch

# 创建一个 3x4 的全一的张量
t = torch.ones(3, 4,dtype=torch.int)

# 创建一个上三角矩阵
upper_triangular = torch.triu(t)

print(upper_triangular)
'''
upper_triangular:
tensor([[1, 1, 1, 1],
        [0, 1, 1, 1],
        [0, 0, 1, 1]], dtype=torch.int32)
'''

view

view 方法用于改变张量的形状。view 方法不会修改原始张量,而是返回一个新的张量,但新的张量与原始张量共享相同的基础数据。–PyTorch 允许张量是现有张量的视图。视图张量与其基本张量共享相同的基础数据。支持视图避免显式数据复制,因此可以快速且高效地进行重塑、切片和逐元素操作。具体可查看官方文档:PyTorch Docs > Tensor Views

torch.view 要求输入张量在内存中是连续的(contiguous),即存储顺序上连续无间断。如果张量不是连续的,torch.view 会失败并抛出错误。因此在调用 torch.view 之前,通常需要调用 .contiguous() 方法来使张量变为连续张量。

import torch
t = torch.rand(4, 4)
#判断张量是否连续
print(t.is_contiguous())
b = t.view(2, 8)

'''
storage() 方法返回的是底层存储数据的 Storage 对象, data_ptr() 方法返回当前对象的指针。
因t、b共享底层数据,所以结果为True
'''
print(t.storage().data_ptr() == b.storage().data_ptr())  
b[0][0] = 3.14
#t[0][0]的值也会变为3.14
t[0][0]
  • 6
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值