numpy与tensorflow中的广播(broadcast)机制

项目github地址:bitcarmanlee easy-algorithm-interview-and-practice
欢迎大家star,留言,一起学习进步

1.numpy中的广播

广播(broadcast)是numpy中经常使用的一个技能点,他能够对不同形状的数组进行各种方式的计算。

举个简单的例子:

a = np.array([1, 2, 3])
b = np.array([4, 5, 6])
c = a + b

此时c的结果为[5, 7, 9]。一般情况下要进行a+b的操作,需要a.shape=b.shape。比如上面的例子中,a、b的shape均为(3,),(3,)表示一维数组,数组中有3个元素。

很多情况下,维度不同的两个数组也能进行类似的操作,这里头的原理就是会自动出发广播机制。

a = np.array([[0, 0, 0],
              [1, 1, 1],
              [2, 2, 2],
              [3, 3, 3]])
b = np.array([1, 2, 3])
print(a.shape)
print(b.shape)
print(a + b)

结果为

(4, 3)
(3,)
[[1 2 3]
 [2 3 4]
 [3 4 5]
 [4 5 6]]
[[1 2 3]
 [2 3 4]
 [3 4 5]
 [4 5 6]]

a是一个4*3的二维数组,b是一个(3,)的一维数组。相当于将b在行的维度上扩充四倍进行运算。
此时a+b等价于a + np.tile(b, [4, 1])

如果是a + np.tile(b, [5, 1]),此时代码会报错,提示维度不匹配。

ValueError: operands could not be broadcast together with shapes (4,3) (5,3) 

2.tensorflow中的广播

tensorflow中的广播机制与numpy类似。
同样看一个例子

def broadcast():
    a = tf.constant([[1, 2], [3, 4]])
    b = tf.constant([[1], [2]])
    c = a + b
    d = a + tf.tile(b, [1, 2])

    with tf.Session() as sess:
        cc, dd = sess.run([c, d])
        print(cc)
        print(dd)

输出的结果为

[[2 3]
 [5 6]]
[[2 3]
 [5 6]]

可见此时a + ba + tf.tile(b, [1, 2])的效果是一样的。

3.广播的优点

广播机制允许我们在隐式情况下进行填充(tile),而这可以使得我们的代码更加简洁,并且更有效率地利用内存,因为我们不需要另外储存填充操作的结果。一个可以表现这个优势的应用场景就是在结合具有不同长度的特征向量的时候。为了拼接具有不同长度的特征向量,我们一般都先填充输入向量,拼接这个结果然后进行之后的一系列非线性操作等。这是一大类神经网络架构的共同套路(common pattern)。

a = tf.random_uniform([5, 3, 5])
b = tf.random_uniform([5, 1, 6])

# concat a and b and apply nonlinearity
tiled_b = tf.tile(b, [1, 3, 1])
c = tf.concat([a, tiled_b], 2)
d = tf.layers.dense(c, 10, activation=tf.nn.relu)

但是这个可以通过广播机制更有效地完成。我们利用事实f(m(x+y))=f(mx+my),简化我们的填充操作。因此,我们可以分离地进行这个线性操作,利用广播机制隐式地完成拼接操作。

pa = tf.layers.dense(a, 10, activation=None)
pb = tf.layers.dense(b, 10, activation=None)
d = tf.nn.relu(pa + pb)
  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值