Numpy 和 PyTorch 的广播机制

Numpy 广播规则:两个数组的形状即 shape,从后往前查看,要么严格相等,要么其中一个数组的在当前查看的维度上的长度为1,或者维度缺失,这样就能匹配,(并在相应维度上做拷贝扩充,但只是概念上假想的拷贝,而不是真正意义上内存上的拷贝),满足广播条件,否则不满足广播条件,程序报错。

广播(Broadcast)是 numpy 对不同形状(shape)的数组进行数值计算的方式, 对数组的算术运算通常在相应的元素上进行。

import numpy as np
a = np.array([[1,2,3,4],[5,6,7,8]])
b = np.array([[10,20,30,40],[50,60,70,80]])
print(a,a.shape)
print(b,b.shape)
print(a+b,(a+b).shape)

在这里插入图片描述
广播实列

import numpy as np 
 
a = np.array([[ 0, 0, 0],
           [10,10,10],
           [20,20,20],
           [30,30,30]])
b = np.array([1,2,3])
print(a,a.shape)
print(b,b.shape)
print(a+b,(a+b).shape)

在这里插入图片描述
在这里插入图片描述

广播的规则:

  • 让所有输入数组都向其中形状最长的数组看齐,形状中不足的部分都通过在前面加 1 补齐。
  • 输出数组的形状是输入数组形状的各个维度上的最大值。
  • 如果输入数组的某个维度和输出数组的对应维度的长度相同或者其长度为 1 时,这个数组能够用来计算,否则出错。
  • 当输入数组的某个维度的长度为 1 时,沿着此维度运算时都用此维度上的第一组值。

简单理解:对两个数组,分别比较他们的每一个维度(若其中一个数组没有当前维度则忽略),满足:

  • 数组拥有相同形状。
  • 当前维度的值相等。
  • 当前维度的值有一个是 1。

若条件不满足,抛出 “ValueError: frames are not aligned” 异常。

PyTorch 广播规则:两个张量的 shape 从后往前逐个检查,要么对应维度相等,要么其中一个的维度上长度是1,或者其中一个的维度缺失,则在相应维度上扩充,使得和另一个张量的对应维度上的长度相等。否则广播失败,程序报错。

在这里插入图片描述

import torch
x=torch.empty(5,1,4,1)
y=torch.empty(3,1,1)
print(x.shape)
print(y.shape)
print((x+y).shape)

在这里插入图片描述

import torch
x=torch.empty(5,1,4,1)
y=torch.empty(3,2,1)
print(x.shape)
print(y.shape)
print((x+y).shape)

在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值