【torch.einsum】

本文介绍了如何使用PyTorch的torch.einsum进行矩阵转置、点乘和乘法操作,通过字符串表达式实现高效简洁的张量运算,包括矩阵转置的'ij->ji'、哈达玛积的'ij,ij->ij'以及矩阵乘法的'ik,kj->ij'示例。
摘要由CSDN通过智能技术生成

参考:https://www.cnblogs.com/mengnan/p/10319701.html

爱因斯坦简记法,能简洁表示各种矩阵向量的操作,例如矩阵转置、乘法、求和等等,pytorch中调用API为torch.einsum,第一个参数是字符串,表示对张量的操作,定义不同字符串表示不同操作

例子

矩阵转置

字符串为ij->ji,表示将张量中的元素ij(第i行第j列)变成元素ji

a = torch.tensor([
    [1, 2],
    [2, 3],
    [4, 5],
    [2, 6]
])
b = torch.einsum('ij->ji', a)
print(b)
>>> tensor([[1, 2, 4, 2],
            [2, 3, 5, 6]])
矩阵点乘(哈达玛积)

将矩阵对应位置的元素相乘,字符串ij,ij->ij表示将两个矩阵的ij元素相乘得到元素ij,貌似只能是相乘不能是相加?

a = torch.tensor([
    [1, 2],
    [2, 3],
    [4, 5],
    [2, 6]
])
b = torch.tensor([
    [2, 6],
    [2, 4],
    [3, 1],
    [2, 2]
])
c = torch.einsum('ij,ij->ij', a, b)
print(c)
>>> tensor([[ 2, 12],
	        [ 4, 12],
	        [12,  5],
	        [ 4, 12]])
矩阵乘法

将矩阵 a ∈ R M × N a\in R^{M\times N} aRM×N和矩阵 b ∈ R N × S b\in R^{N\times S} bRN×S相乘得到矩阵 c ∈ R M × S c\in R^{M\times S} cRM×S,字符串ik,kj->ij表示将矩阵a的ik元素和矩阵b的kj元素相乘,k是遍历变量,遍历所有k,累加ik*kj累加得到ij,即矩阵乘法

a = torch.tensor([
    [1, 2, 4],
    [2, 3, 5],
])
b = torch.tensor([
    [2, 6],
    [2, 4],
    [3, 1],
])
c = torch.einsum('ik,kj->ij', a, b)
print(c)
>>> tensor([[18, 18],
        	[25, 29]])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值