有趣的torch.einsum

import torch
import numpy as np
a = torch.arange(9).reshape(3, 3)

提取矩阵对角线元素

out = torch.einsum('ii->i', a)	# tensor([0, 4, 8])

矩阵转置

out = torch.einsum('ij->ji', a)
out = torch.einsum('...ij->...ji', a) # 高维矩阵最后两维转置

reduce sum

out = torch.einsum('ij->', a)	# tensor(36)

矩阵按列求和

out = torch.einsum('ki->i', a)

矩阵向量乘法

a = torch.arange(6).reshape(2, 3)
b = torch.arange(3)
out = torch.einsum('ik,k->i', a, b)
out = torch.einsum('ik,k', a, b)	# 箭头右侧符号可以不写,按规则默认推理。

矩阵乘法

a = torch.arange(6).reshape(2, 3)
b = torch.arange(15).reshape(3, 5)
out = torch.einsum('ik,kj->ij', a, b)
out = torch.einsum('ik,kj', a, b)

向量内积

a = torch.arange(3)
b = torch.arange(3, 6)
out = torch.einsum('i,i->', a, b)
out = torch.einsum('i,i', a, b)

矩阵元素对应相乘并求reduce sum

a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
out = torch.einsum('ij,ij->', a, b)

向量外积

a = torch.arange(3)
b = torch.arange(3,7)
out = torch.einsum('i,j->ij', a, b)

batch矩阵乘法

a = torch.randn(2,3,5)
b = torch.randn(2,5,4)
out = torch.einsum('bik,bkj->bij', a, b)

张量收缩

tensor contraction, 用不上,暂时看不懂。

a = torch.randn(2,3,5,7)
b = torch.randn(11,13,3,17,5)
out = torch.einsum('pqrs,tuqvr->pstuv', a, b)

双线性变换

bilinear transformation. Applies a bilinear transformation to the incoming data.

a = torch.randn(2,3)
b = torch.randn(5,3,7)
c = torch.randn(2,7)
out = torch.einsum('ik,jkl,il->ij', a, b, c)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值