torch.matmul(): standard matrix multiplication
- Vector-vector (dot product)
    import torch

    a = torch.randn(3)        # [3]
    b = torch.randn(3)        # [3]
    c = torch.matmul(a, b)    # dot product, scalar (0-dim) output
- Matrix-vector
    A = torch.randn(3, 4)     # [3, 4]
    x = torch.randn(4)        # [4]
    y = torch.matmul(A, x)    # [3]
- Matrix-matrix
    A = torch.randn(3, 4)     # [3, 4]
    B = torch.randn(4, 5)     # [4, 5]
    C = torch.matmul(A, B)    # [3, 5]
- Batched matrix multiplication (higher-dimensional tensors)
    A = torch.randn(2, 3, 4)  # [B, M, K]
    B = torch.randn(2, 4, 5)  # [B, K, N]
    C = torch.matmul(A, B)    # [B, M, N]
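
The cases above all pair operands of matching rank. As a minimal sketch of what happens when the ranks differ, the example below multiplies a batched tensor by an unbatched matrix and by a vector. The names `W`, `v`, and `same` are illustrative and do not come from the original notes.

```python
import torch

torch.manual_seed(0)

A = torch.randn(2, 3, 4)   # [B, M, K]
W = torch.randn(4, 5)      # [K, N], no batch dimension
v = torch.randn(4)         # [K]

# A 2-D second operand is broadcast across the batch dimension of A.
C = torch.matmul(A, W)     # [B, M, N] = [2, 3, 5]

# A 1-D second operand is treated as a column vector; the temporary
# trailing dimension is removed from the result.
y = torch.matmul(A, v)     # [B, M] = [2, 3]

# The @ operator performs the same operation as torch.matmul.
same = torch.allclose(C, A @ W)

print(C.shape, y.shape, same)
```

This is why the same `torch.matmul` call can serve for a shared weight matrix applied to every item in a batch without an explicit loop or `expand`.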
torch.einsum: Einstein summation convention (a more general tool for tensor operations)
- Matrix multiplication
    A = torch.randn(3, 4)
    B = torch.randn(4, 5)
    C = torch.einsum("ik,kj->ij", A, B)      # equivalent to A @ B

    A = torch.randn(2, 3, 4)                 # [B, M, K]
    B = torch.randn(2, 4, 5)                 # [B, K, N]
    C = torch.einsum("bik,bkj->bij", A, B)   # [B, M, N]

    a = torch.randn(3)
    b = torch.randn(3)
    c = torch.einsum("i,i->", a, b)          # dot product, scalar output
- Transpose
    A = torch.randn(3, 4)
    B = torch.einsum("ij->ji", A)            # equivalent to A.T
- Diagonal extraction
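- Tensor contraction (higher-order tensor multiplication)
    A = torch.randn(2, 3, 4, 5)              # [a, b, c, d]
    B = torch.randn(2, 3, 5, 6)              # [a, b, d, e]; a and b must match A
    C = torch.einsum("abcd,abde->abce", A, B)  # contracts over d, result [2, 3, 4, 6]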
- Broadcasting
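
The diagonal-extraction and broadcasting items in the list above can be sketched as follows. This is one possible illustration rather than the original author's code: the tensor names `M`, `W`, `diag`, and `trace` are assumptions, and the broadcasting example uses the `...` (ellipsis) notation that `torch.einsum` accepts for leading dimensions.

```python
import torch

torch.manual_seed(0)

# Diagonal extraction: repeating an index on one operand selects
# the elements where both positions are equal.
M = torch.randn(4, 4)
diag = torch.einsum("ii->i", M)    # [4], the main diagonal
trace = torch.einsum("ii->", M)    # 0-dim tensor, sum of the diagonal

# Broadcasting: "..." stands for any leading dimensions, so an
# unbatched matrix W is applied to every matrix in the batch A.
A = torch.randn(2, 3, 4)                     # [B, M, K]
W = torch.randn(4, 5)                        # [K, N]
C = torch.einsum("...ik,kj->...ij", A, W)    # [B, M, N]

print(diag.shape, trace.shape, C.shape)
```

The ellipsis form keeps the subscript string independent of how many batch dimensions the input has, at the cost of being less explicit than naming each index.
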
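Aspect | torch.matmul | torch.einsum |
---|---|---|
Flexibility | Matrix-multiplication-style operations only | Any operation expressible in index notation (transpose, contraction, etc.) |
Readability | Intuitive (A @ B) | Requires familiarity with the Einstein summation convention |
Performance | Highly optimized (recommended for standard matrix multiplication) | Flexible, but may be slightly slower |
Broadcasting | Yes | Yes (via "...") |
Batching | Automatic | Batch dimensions must be written out explicitly or covered by "..." |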
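
To check the comparison on a concrete case, the sketch below computes the same batched product three ways and times two of them. The tensor sizes and iteration count are arbitrary choices for illustration, and the timings depend on hardware and PyTorch build, so they should be read as a rough indication only.

```python
import timeit

import torch

torch.manual_seed(0)

A = torch.randn(8, 64, 32)   # [B, M, K]
B = torch.randn(8, 32, 16)   # [B, K, N]

via_matmul = torch.matmul(A, B)
# Batch index b written out explicitly.
via_einsum = torch.einsum("bik,bkj->bij", A, B)
# Batch dimensions covered by "..." instead of a named index.
via_ellipsis = torch.einsum("...ik,...kj->...ij", A, B)

# All three should agree up to floating-point rounding.
results_match = (
    torch.allclose(via_matmul, via_einsum, atol=1e-4)
    and torch.allclose(via_matmul, via_ellipsis, atol=1e-4)
)

# Rough wall-clock timing over 100 calls each.
t_matmul = timeit.timeit(lambda: torch.matmul(A, B), number=100)
t_einsum = timeit.timeit(
    lambda: torch.einsum("bik,bkj->bij", A, B), number=100
)

print(f"match={results_match} matmul={t_matmul:.4f}s einsum={t_einsum:.4f}s")
```

For a plain (batched) matrix product, `torch.matmul` or `@` is usually the simpler choice; `torch.einsum` earns its place when the operation involves transposes, diagonals, or contractions over several indices at once.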