tensorflow 二维矩阵乘以三维矩阵,高维矩阵相乘

最近看论文时看到了一个让我费解的操作。
二维矩阵 * 三维矩阵 * 二维矩阵 得到了一个二维矩阵。
即:

(n * c1) x (m * c1 * c2) x (n * c2) 得 n * m

实现主要参考的是tensorflow的matmul运算对于高维矩阵的乘法支持batch的操作,只要保证高维矩阵最后两维之前的维度一样就可以。直接上例子比较直观。

import tensorflow as tf
g = tf.Graph()
with g.as_default():
    x = tf.ones([2, 3, 1], dtype=tf.float32)
    y = tf.ones([2, 1, 4], dtype=tf.float32)
    z = tf.matmul(x, y)
p <span class="token operator">=</span> tf<span class="token punctuation">.</span>ones<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">5</span><span class="token punctuation">]</span><span class="token punctuation">,</span> dtype<span class="token operator">=</span>tf<span class="token punctuation">.</span>float32<span class="token punctuation">)</span>
q <span class="token operator">=</span> tf<span class="token punctuation">.</span>ones<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">,</span> <span class="token number">5</span><span class="token punctuation">,</span> <span class="token number">6</span><span class="token punctuation">]</span><span class="token punctuation">,</span> dtype<span class="token operator">=</span>tf<span class="token punctuation">.</span>float32<span class="token punctuation">)</span>
r <span class="token operator">=</span> tf<span class="token punctuation">.</span>matmul<span class="token punctuation">(</span>p<span class="token punctuation">,</span> q<span class="token punctuation">)</span>

with tf.compat.v1.Session(graph=g) as sess:
print(sess.run(z).shape) # (2, 3, 4)
print(sess.run(r).shape) # (2, 3, 1, 6)

比较让我震惊的是在tensorflow2.0版本可以按下面计算,当然这样计算比较符合理想化结果,例子中就是300个二维矩阵分别跟一个二维矩阵去乘。
numpy和torch也是支持这样计算的,但是numpy的结果的维度有所不同。

import tensorflow as tf
g = tf.Graph()
with g.as_default():
    a = tf.ones([2, 3], dtype=tf.float32)
    b = tf.ones([300, 3, 6], dtype=tf.float32)
    d = tf.matmul(a, b)  # (300,2,6),这一步2.0版本能够运行令人费解
d <span class="token operator">=</span> tf<span class="token punctuation">.</span>transpose<span class="token punctuation">(</span>d<span class="token punctuation">,</span> <span class="token punctuation">[</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token number">0</span><span class="token punctuation">,</span> <span class="token number">2</span><span class="token punctuation">]</span><span class="token punctuation">)</span>  <span class="token comment"># d:(2,300,6)</span>
c <span class="token operator">=</span> tf<span class="token punctuation">.</span>ones<span class="token punctuation">(</span><span class="token punctuation">[</span><span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">6</span><span class="token punctuation">,</span> <span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">,</span> dtype<span class="token operator">=</span>tf<span class="token punctuation">.</span>float32<span class="token punctuation">)</span>  <span class="token comment"># 原本c应该是(2,6)</span>

e <span class="token operator">=</span> tf<span class="token punctuation">.</span>matmul<span class="token punctuation">(</span>d<span class="token punctuation">,</span> c<span class="token punctuation">)</span>  <span class="token comment"># e:(2,300,1)</span>
e <span class="token operator">=</span> tf<span class="token punctuation">.</span>reshape<span class="token punctuation">(</span>e<span class="token punctuation">,</span> <span class="token punctuation">[</span><span class="token number">2</span><span class="token punctuation">,</span> <span class="token number">300</span><span class="token punctuation">]</span><span class="token punctuation">)</span>

with tf.compat.v1.Session(graph=g) as sess:
print(sess.run(e).shape)
# print(sess.run(d).shape)

tensorflow1.0版本不可以按上述计算,在第一个matmul的时候必须要将b reshape一下,具体计算可以参考:
https://blog.csdn.net/weixin_41024483/article/details/88536662

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
在Matlab中,二维矩阵三维图通常使用的函数是“surf”和“mesh”。这两个函数都可以用来绘制三维曲面,但是它们有一些细微的区别。 “surf”函数将二维矩阵的值映射到三维曲面上。这个函数绘制出一个平滑的曲面,它的灰度值表示了该点的高度。这样,我们就可以通过调整曲面的颜色和高度来可视化一个三维数据集。 “mesh”函数也可以绘制三维曲面,但是它更注重于显示曲面的线框结构。它的输出结果是由曲面上的线条和网络点组成的一个三维网格,每个点的灰度值代表了这个点在二维矩阵中的值。 如果想要绘制二维数据的三维表面,首先需要创建一个二维数组,程序会将这个数组定义为一个矩阵。这样的一个矩阵可以是一个函数的输出,也可以是从一个文本文件或者Excel表格中导入的数据。一旦创建了这个矩阵,就可以使用“surf”或者“mesh”函数来创建三维图。 例如,我们想要创建一张三维图,其Z轴表示函数f(x,y)的值,我们可以将函数f定义为一个矩阵。然后,我们可以使用“surf”函数来画出这个矩阵三维图像。这个函数可以将矩阵的值映射到一个三维表面上,其中每个点的高度表示该点的值。这个操作可以用以下代码来实现: x = 0:0.1:10; y = 0:0.1:10; [X,Y] = meshgrid(x,y); Z = sin(X).*cos(Y); surf(X,Y,Z); 这段代码首先创建了两个从0到10的数组x和y,其间隔为0.1,用来构建一个网格。然后利用matlab中的meshgrid函数将这两个数组转成X、Y两个二维矩阵。接着,我们定义了一个Z矩阵,用来表示sin函数和cos函数的运算结果,并传递这个矩阵到surf函数作为参数。最终我们可以在画布上看到一个三维表面的图形。 总之,通过使用Matlab中强大的绘图功能,我们可以实现从二维矩阵三维图形的转化。这使我们能够更好地展示高维数据,从而更深刻地理解大型数据集的结构和关联。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值