Pytorch中的广播机制(Broadcast)

1. 广播机制定义

如果一个PyTorch操作支持广播,则其Tensor参数可以自动扩展为相等大小(不需要复制数据)。通常情况下,小一点的数组会被 broadcast 到大一点的,这样才能保持大小一致。

2. 广播机制规则

2.1 如果遵守以下规则,则两个tensor是“可广播的”:

  • 每个tensor至少有一个维度
  • 遍历tensor所有维度时,从末尾开始遍历(从右往左开始遍历)(从后往前开始遍历),两个tensor存在下列情况:
    • tensor维度相等
    • tensor维度不等且其中一个维度为1
    • tensor维度不等且其中一个维度不存在

2.2 如果两个tensor是“可广播的”,则计算过程遵循下列规则:

  • 如果两个tensor的维度不同,则在维度较小的tensor的前面增加维度,使它们维度相等
  • 对于每个维度,计算结果的维度值取两个tensor中较大的那个值
  • 两个tensor扩展维度的过程是将数值进行复制

3.代码举例

3.1 相同维度,一定可以 broadcasting。

# 相同维度,一定可以 broadcasting
x=torch.ones(5,7,3)
y=torch.ones(5,7,3)
z = x+y
x.shape,y.shape,z.shape
输出结果如下:
(torch.Size([5, 7, 3]), torch.Size([5, 7, 3]), torch.Size([5, 7, 3]))

3.2 x和y不能被广播,因为x没有符合“至少有一个维度”,所以不可以broadcasting。

# x和y不能被广播,因为x没有符合“至少有一个维度”,所以不可以broadcasting
x=torch.ones((0,))
y=torch.ones(5,7,3)
z = x+y
x.shape,y.shape,z.shape

x,y不能进行广播
3.3 x 和 y 可以广播。

# x 和 y 可以广播
x=torch.ones(5,3,4,1)
y=torch.ones(  3,1,1)
z = x+y
x.shape,y.shape,z.shape
# 从尾部维度开始遍历
# 1st尾部维度: x和y相同,都为1。
# 2nd尾部维度: y为1,x为4,符合维度不等且其中一个维度为1,则广播为4。
# 3rd尾部维度: x和y相同,都为3。
# 4th尾部维度: y维度不存在,x为5,符合维度不等且其中一个维度不存在,则广播为5。
输出结果如下:
(torch.Size([5, 3, 4, 1]), torch.Size([3, 1, 1]), torch.Size([5, 3, 4, 1]))

3.4 x 和 y 不可以广播,因为倒数第三维度x为2,y为3,不符合维度不等且其中一个维度为1。

# x 和 y 不可以广播,因为倒数第三维度x为2,y为3,不符合维度不等且其中一个维度为1。
x=torch.ones(5,2,4,1)
y=torch.ones(  3,1,1)
z = x+y
x.shape,y.shape,z.shape

x,y不能进行广播
3.5 x 和 y 可以广播,在维度较小y前面增加维度,使它们维度相等,同时使他们维度大小相同。

# x 和 y 可以广播,在维度较小y前面增加维度,使它们维度相等。
x=torch.ones(5,2,4,1)
y=torch.ones(1,1)
z = x+y
x.shape,y.shape,z.shape
输出结果如下:
(torch.Size([5, 2, 4, 1]), torch.Size([1, 1]), torch.Size([5, 2, 4, 1]))

4. in - place 语义

in-place operation称为原地操作符,在pytorch中是指改变一个tensor的值的时候,不经过复制操作,而是直接在原来的内存上改变它的值。在pytorch中经常加后缀“”来代表原地操作符,例:.add _()、.scatter(),in-place操作不允许tensor使用广播机制那样来改变张量形状维度大小,如下例子所示。

# x 和 y 不可以广播
x=torch.empty(1,3,1)
y=torch.empty(3,1,7)
z = x.add_(y)
x.shape,y.shape,z.shape

使用in-place原地操作符

Python中,广播机制是指针对两个不同形状的数组进行对应项的加、减、乘、除运算时,首先将数组调整为统一的形状,然后再进行运算。这种机制在Numpy、TensorFlow和PyTorch等库中都有应用。\[1\] 举个例子来阐述Python广播机制。如果有一个形状为(3,4,5)的三维数组A和一个形状为(4,5)的二维数组B,由于A和B的后缘维度都为(4,5),所以可以进行广播机制。同理,如果A为(3,4)的二维数组,B为(4,)的一维数组,它们的后缘维度都是4,所以也可以进行广播。另外,如果A为(4,5)的三维数组,B为(4,1)的二维数组,两者维度相同,但其中一个维度的其中一方为1,也可以进行广播。\[2\] 下面是一个验证广播机制的小程序: ```python import numpy as np a = np.array(\[\[1,2,3\],\[4,5,6\]\]) # 2*3 b = np.array(\[\[1\],\[3\]\]) # 2*1 c = a + b print(c) a = np.array(\[\[\[1,2\],\[2,3\],\[3,4\]\],\[\[2,3\],\[4,5\],\[7,8\]\]\]) # 2*3*2 b = np.array(\[\[6,6\],\[7,7\],\[8,8\]\]) # 3*2 c = a + b print(c) print(c.shape) ``` 参考链接:\[https://www.cnblogs.com/jiaxin359/p/9021726.html\](https://www.cnblogs.com/jiaxin359/p/9021726.html) \[2\] 需要注意的是,当两个数组的形状无法满足广播机制的条件时,会抛出ValueError异常。例如,如果数组a的形状为(3,3),数组b的形状为(2,3),那么它们无法进行广播运算,会抛出异常。\[3\] #### 引用[.reference_title] - *1* *3* [【Python学习记录】Numpy广播机制(broadcast)](https://blog.csdn.net/xxm524/article/details/128210631)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insertT0,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* [python广播机制broadcasting)](https://blog.csdn.net/weixin_44319196/article/details/107871808)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insertT0,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值