torch常用函数

一、 torch.bmm

torch.bmm 是 PyTorch 中的一个函数,用于执行批矩阵乘法( matrix multiplication)操作。
它的输入是三维张量,形状为 (batch, n, m) 和 (batch, m, p):
其中 n 是第一个矩阵的列数,m 是两个矩阵共享的维度,p 是第二个矩阵的列数。
torch.bmm 将批中的每对矩阵相乘,返回一个新的三维张量,形状为 (batch, n, p)。

二、 torch.einsum

torch.einsum是pytorch上的一个强大的函数,用于矩阵相关的计算,注意,这里没有限定为矩阵乘法。torch.einsum基于爱因斯坦求和约定执行张量操作,能够用简洁的表达式实现复杂的多维数组操作,从而避免繁琐的张量操作组合(如reshape、permute、bmm等),减少错误率。需要说明的是,尽管einsum函数内部进行了大量计算优化,但其主要优势在于表达式简洁,如果与单步reshape等pytorch实现的矩阵运算操作相比,其运算速度与内存占用不一定占优势。

1.矩阵乘法:‘ij,jk->ik’ 表示形状为(i,j)与形状为(j,k)的矩阵进行矩阵乘法,得到新矩阵形状为(i,k)。这也是torch.einsum最常规的用法。

2.维度调换:'ij->ji’表示形状为(i,j)的矩阵维度调换成为形状为(j,i)的矩阵。
torch.einsum还有多种用法,遇到再来添加

三、python中变量前面有个*

在Python中,变量前面的星号(*)有多种用法,主要与函数参数或解包序列有关。

1、在函数参数中,星号(*)用来表示任意多个参数,这些参数会被当作元组传递。例如:

def fun(*args):
    for i in args:
        print(i)
 
fun(1, 2, 3, 4)

2、在函数参数中,星号(*)还可以用来解包序列。例如:

def fun(a, b, c, d):
    print(a, b, c, d)
 
args = (1, 2, 3, 4)
fun(*args)

3、在函数参数中,星号(*)还可以与命名参数,或者字典一起使用。例如:

def fun(*args, a=1):
    print(args, a)
 
fun(1, 2, 3, a=4)

def fun(*args, **kwargs):
    print(args, kwargs)
 
fun(1, 2, 3, a=4, b=5)

4、 在解包列表或元组时,星号(*)也可以用来解包选定项。例如:

lst = [1, 2, 3, 4, 5]
a, *b, c = lst
print(a, b, c)

四、numpy.prod

计算元素和

print(np.prod([[1., 2.], [3., 4.]], axis=0))按列计算元素和
print(np.prod([[1., 2.], [3., 4.]], axis=1))按行计算元素和
print(np.prod([[1., 2.], [3., 4.]], axis=0))计算所有元素和

五、torch.chunk

对于一个输入tensor,torch.chunk方法会按照dim指定的维度将输入tensor划分为若干个chunk,划分的数量为chunks。

torch.chunk(input, chunks, dim=0) 

temp=torch.randn((4,6))
print(torch.chunk(temp,2,0))行方向分块
print(torch.chunk(temp,2,1))列方向分块
  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值