paddle的广播
paddle里面的广播机制,判断两个tensor在运算时能不能进行广播的标准是:两个Tensor从后往前依次比较tensor的形状,如果出现以下三种情况,就都还能广播:
- 两个tensor当前比较维度维数相同
- 两个tensor有其中一个tensor的当前比较维度的维数为1
- 两个tensor有其中一个tensor的当前比较维度维数不存在
但是,但凡比较中出现的不是上面三种情况就g了,广播不了了,会报错# InvalidArgumentError: Broadcast dimension mismatch. 比如出现两个tensor当前比较维数一个是3,一个是4,就广播不了。
例子
举几个例子:
-
第一个例子
x = paddle.ones((2, 3, 6)) y = paddle.ones((2, 3, 6)) z = x + y print(z)
这个就明显可以广播,从后往前6和6比,3和3比……维数都一样,可以广播,结果为:
Tensor(shape=[2, 3, 6], dtype=float32, place=Place(gpu:0), stop_gradient=True, [[[2., 2., 2., 2., 2., 2.], [2., 2., 2., 2., 2., 2.], [2., 2., 2., 2., 2., 2.]], [[2., 2., 2., 2., 2., 2.], [2., 2., 2., 2., 2., 2.], [2., 2., 2., 2., 2., 2.]]])
-
第二个例子
x = paddle.ones((2, 4, 6)) y = paddle.ones((2, 3, 6)) z = x + y print(z)
此时x和y是不可广播的,比较过程:
- 第一次6和6比没问题
- 第二次4和3比,不相等且不属于上面三种情况的任意一种,不能广播
- ……
-
第三个例子
x = paddle.ones((1, 2, 1, 6)) y = paddle.ones((1, 3, 6)) z = x + y print(z)
此时x和y可以广播,比较过程如下:
- 第一次6和6比没问题
- 第二次1和3比,因为有个1,所以虽然维数大小不同,但是只要对1那边的那个维度复制成三个就能广播
- 第三次2和1比,与上一个比较同理,有1就能广播
- x为1,y的维度不存在,同样可以广播
结果如下:
Tensor(shape=[1, 2, 3, 6], dtype=float32, place=Place(gpu:0), stop_gradient=True, [[[[2., 2., 2., 2., 2., 2.], [2., 2., 2., 2., 2., 2.], [2., 2., 2., 2., 2., 2.]], [[2., 2., 2., 2., 2., 2.], [2., 2., 2., 2., 2., 2.], [2., 2., 2., 2., 2., 2.]]]])