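While reading a paper recently, I came across an operation that puzzled me:
a 2-D matrix times a 3-D tensor times a 2-D matrix gives a 2-D matrix.
In terms of shapes:
(n × c1) · (m × c1 × c2) · (n × c2) → (n × m)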
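Write A, B, C for the three operands (a, b, c in the TensorFlow code further down). Reading the shapes off that code, each entry of the n × m result appears to be a bilinear form: row i of A, times slice j of B, times row i of C. This is my interpretation of the code, not a formula quoted from the paper:

```latex
E_{ij} = \sum_{k=1}^{c_1} \sum_{l=1}^{c_2} A_{ik}\, B_{jkl}\, C_{il}
       = \mathbf{a}_i^{\top} B_j\, \mathbf{c}_i,
\qquad i = 1,\dots,n,\quad j = 1,\dots,m
```

Here $\mathbf{a}_i$ is row $i$ of $A$ (length $c_1$), $B_j$ is the $c_1 \times c_2$ slice $j$ of $B$, and $\mathbf{c}_i$ is row $i$ of $C$ (length $c_2$).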
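The implementation relies on the fact that TensorFlow's matmul supports batched multiplication for higher-dimensional tensors: the last two dimensions are multiplied as ordinary matrices, and the operation works as long as all dimensions before the last two are the same. An example makes this clearer.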
```python
import tensorflow as tf

g = tf.Graph()
with g.as_default():
    # Batch dims (2,) match; the last two dims multiply as (3, 1) x (1, 4)
    x = tf.ones([2, 3, 1], dtype=tf.float32)
    y = tf.ones([2, 1, 4], dtype=tf.float32)
    z = tf.matmul(x, y)

    # Batch dims (2, 3) match; the last two dims multiply as (1, 5) x (5, 6)
    p = tf.ones([2, 3, 1, 5], dtype=tf.float32)
    q = tf.ones([2, 3, 5, 6], dtype=tf.float32)
    r = tf.matmul(p, q)

with tf.compat.v1.Session(graph=g) as sess:
    print(sess.run(z).shape)  # (2, 3, 4)
    print(sess.run(r).shape)  # (2, 3, 1, 6)
```
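To make the batch rule concrete, here is a minimal sketch in NumPy rather than TensorFlow, so it runs without building a graph. It assumes `np.matmul` follows the same rule as `tf.matmul` when the batch shapes are identical, and checks that a batched matmul equals looping over the leading dimensions and doing an ordinary 2-D product per slice. `batched_matmul_loop` is a hypothetical helper written for illustration only.

```python
import numpy as np


def batched_matmul_loop(x, y):
    """Multiply x and y slice by slice over every leading (batch) dimension.

    Hypothetical helper: assumes x.shape[:-2] == y.shape[:-2], the
    equal-batch-shape case used in the TensorFlow example.
    """
    assert x.shape[:-2] == y.shape[:-2], "batch dimensions must match"
    batch_shape = x.shape[:-2]
    out = np.empty(batch_shape + (x.shape[-2], y.shape[-1]),
                   dtype=np.result_type(x, y))
    # np.ndindex walks every index tuple of the batch dimensions
    for idx in np.ndindex(*batch_shape):
        out[idx] = x[idx] @ y[idx]  # ordinary 2-D matrix product
    return out


rng = np.random.default_rng(0)
p = rng.standard_normal((2, 3, 1, 5))
q = rng.standard_normal((2, 3, 5, 6))

r_loop = batched_matmul_loop(p, q)
r_batched = np.matmul(p, q)
print(r_loop.shape)                    # (2, 3, 1, 6)
print(np.allclose(r_loop, r_batched))  # True
```

The loop is only there to show what "batch" means; in practice the single `matmul` call is what you would use.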
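What really surprised me is that TensorFlow 2.0 accepts the computation below. The result is what one would ideally want: in the example, each of the 300 2-D matrices in b is multiplied by the same 2-D matrix a.
NumPy and PyTorch also support this kind of computation, but the shape of the result I got from NumPy was different.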
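The post does not say which NumPy function produced the different shape, so the following is a hedged guess. `np.matmul` broadcasts a over the stack of matrices in b and gives (300, 2, 6), whereas `np.dot` sums over the last axis of a and the second-to-last axis of b and returns (2, 300, 6). The two hold the same numbers in a different axis order; (2, 300, 6) is the layout the TensorFlow code below obtains only after its `tf.transpose` step.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((2, 3))       # (n, c1)
b = rng.standard_normal((300, 3, 6))  # (m, c1, c2)

# np.matmul treats b as a stack of 300 (3, 6) matrices and broadcasts a over it
d_matmul = np.matmul(a, b)
print(d_matmul.shape)  # (300, 2, 6)

# np.dot contracts a's last axis with b's second-to-last axis and keeps
# the remaining axes in order: a's (2,) followed by b's (300, 6)
d_dot = np.dot(a, b)
print(d_dot.shape)  # (2, 300, 6)

# Same values, with the first two axes swapped
print(np.allclose(d_dot, d_matmul.transpose(1, 0, 2)))  # True
```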
```python
import tensorflow as tf

g = tf.Graph()
with g.as_default():
    a = tf.ones([2, 3], dtype=tf.float32)       # (n, c1)
    b = tf.ones([300, 3, 6], dtype=tf.float32)  # (m, c1, c2)
    # (300, 2, 6): a is broadcast across the 300 matrices in b.
    # It is puzzling that this step runs in 2.0.
    d = tf.matmul(a, b)
    d = tf.transpose(d, [1, 0, 2])              # d: (2, 300, 6)
    # c is originally (2, 6); the trailing axis of 1 turns each row into a column vector
    c = tf.ones([2, 6, 1], dtype=tf.float32)
    e = tf.matmul(d, c)                         # e: (2, 300, 1)
    e = tf.reshape(e, [2, 300])                 # e: (2, 300), i.e. (n, m)

with tf.compat.v1.Session(graph=g) as sess:
    print(sess.run(e).shape)  # (2, 300)
    # print(sess.run(d).shape)
```
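To check that these steps really produce the (n × m) result from the start of the post, here is a NumPy re-run of the same pipeline on random inputs, compared with a direct `np.einsum` contraction. The einsum subscripts encode my reading of the operation (row i of A, times slice j of B, times row i of C); that reading is an assumption, not something taken from the paper. NumPy is used instead of TensorFlow only so the check runs without a graph or session.

```python
import numpy as np

n, m, c1, c2 = 2, 300, 3, 6
rng = np.random.default_rng(0)
A = rng.standard_normal((n, c1))      # plays the role of a
B = rng.standard_normal((m, c1, c2))  # plays the role of b
C = rng.standard_normal((n, c2))      # c before the trailing axis of 1 is added

# Same steps as the TensorFlow code: broadcast matmul, transpose, batched matmul, reshape
d = np.matmul(A, B)               # (m, n, c2)
d = d.transpose(1, 0, 2)          # (n, m, c2)
e = np.matmul(d, C[:, :, None])   # (n, m, c2) @ (n, c2, 1) -> (n, m, 1)
e = e.reshape(n, m)               # (n, m)

# Direct contraction: E[i, j] = sum_k sum_l A[i, k] * B[j, k, l] * C[i, l]
E = np.einsum('ik,jkl,il->ij', A, B, C)

print(e.shape)            # (2, 300)
print(np.allclose(e, E))  # True
```

TensorFlow also provides `tf.einsum`, which takes the same style of subscript string, so the whole operation could likely be written as one call; I have not checked that against the paper's code.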
TensorFlow 1.0 cannot do the computation above; b has to be reshaped before the first matmul. For the details, see:
https://blog.csdn.net/weixin_41024483/article/details/88536662
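I have not reproduced the linked post here, but one common way to "reshape b" so that only a plain 2-D matmul is needed is sketched below in NumPy: move the c1 axis to the front, flatten the rest, multiply, and reshape back. The same steps map onto `tf.transpose`, `tf.reshape` and a 2-D `tf.matmul`, though I have not run that version under TensorFlow 1.0.

```python
import numpy as np

n, m, c1, c2 = 2, 300, 3, 6
rng = np.random.default_rng(1)
a = rng.standard_normal((n, c1))
b = rng.standard_normal((m, c1, c2))

# Move c1 to the front and flatten the rest so b becomes an ordinary 2-D matrix
b2 = b.transpose(1, 0, 2).reshape(c1, m * c2)  # (c1, m * c2)
d2 = a @ b2                                    # plain 2-D matmul: (n, m * c2)
d2 = d2.reshape(n, m, c2)                      # (n, m, c2), same layout as d after the transpose

# Compare with the broadcasting version
d_ref = np.matmul(a, b).transpose(1, 0, 2)
print(d2.shape)                # (2, 300, 6)
print(np.allclose(d2, d_ref))  # True
```

From here the remaining steps (batched matmul with c, then reshape to (n, m)) only involve tensors with matching batch dimensions.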