这个东西虽然计算起来真的方便的很多,但是对于人的理解难度是真的加大的,特别是高纬度的时候,例如:t.einsum(‘ijk,jkl->ijl’, [a,b])三维计算的时候。因此,最好的方法就是举个例子并且换一种方式来实现相同的功能(即循环),然后debug最容易理解。这也是在学习过程中一个好的解决问题的方式。
import torch as t
# 假设输入张量 x_rep 的形状为 [batch_size, n_rules, input_size]
# 假设参数张量 self.Cons 的形状为 [n_rules, input_size, output_size]
batch_size = 2
n_rules = 3
input_size = 4
output_size = 2
# 生成示例输入张量 x_rep 和参数张量 self.Cons
x_rep = t.tensor([[[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]],
[[13, 14, 15, 16],
[17, 18, 19, 20],
[21, 22, 23, 24]],
[[13, 14, 15, 16],
[17, 18, 19, 20],
[21, 22, 23, 24]]
], dtype=t.float32)
self_Cons = t.tensor([[[0.1, 0.2],
[0.3, 0.4],
[0.5, 0.6],
[0.7, 0.8]],