jnp.einsum()
该命令用于在jax中进行张量运算.
假设a是一个形状为(2,3,4)的三维张量. b是一个形状为(3,4)的二维张量. 那么**jnp.einsum(‘imn,mj->ijn’, a, b)**用来计算结果张量.
其中imn表示a张量的三个维度2,3,4;mj表示b张量的两个维度3,4;由于前面有imn,mj,后面是ijn,没有了m,所以是对m求和,所以遍历维度m来计算输出张量的每个元素.从而输出张量的形状是(2,4,4).
例如:
import jax.numpy as jnp
a = jnp.array([
[
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]
],
[
[13, 14, 15, 16],
[17, 18, 19, 20],
[21, 22, 23, 24]
]
])
b = jnp.array([
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]
])
output = jnp.einsum('imn,mj->ijn', a, b)
print(output) #output= [[[ 80 92 104 116]
#[ 92 107 122 137]
#[104 122 140 158]
#[116 137 158 179]]
# [[224 236 248 260]
# [272 287 302 317]
# [320 338 356 374]
# [368 389 410 431]]]