pytorch入门(五):什么是广播运算?

1、概念

  广播运算初看有点神秘,其实它类似于分数加减法的通分。通分用于解决分母不同的分数加减运算问题,它利用分数基本性质把分母调整为原来分母的最小公倍数。广播用于解决不同形状数组的运算问题,包括加法、减法、点乘(对应元素相乘)、点除(对应元素相除乘)、比较等,如对列向量a和行向量b做加法运算,
a + b = [ 1 2 ] + [ 3 4 ] a+b= \left[ \begin{matrix}1\\2 \end{matrix} \right]+\left[\begin{matrix}3&4 \end{matrix} \right] a+b=[12]+[34]
由于二者形状不同,a+b在代数上是行不通的。但是python通过广播操作,把a、b的形状统一为  2 × 2 2\times2 2×2 的数组之后,就能够执行加法运算了。
a + b = [ 1 1 2 2 ] + [ 3 4 3 4 ] = [ 4 5 5 6 ] a+b=\left[\begin{matrix}1&1\\2&2 \end{matrix} \right] + \left[\begin{matrix}3&4\\3&4 \end{matrix} \right]=\left[\begin{matrix}4&5\\5&6 \end{matrix} \right] a+b=[1212]+[3344]=[4556]

a = torch.tensor([[1],[2]])
print(a)
b = torch.tensor([3,4])
print(b)
print(a+b)

结果

[[4 5]
 [5 6]]

广播运算有什么用?如果要从a+b中选择大于5的元素,一种写法是不用广播的,
a + b > [ 5 5 5 5 ] a+b>\left[\begin{matrix} 5&5\\ 5&5\\ \end{matrix}\right] a+b>[5555]
另一种是使用广播的(广播把自动5变为  2 × 2 2\times2 2×2 的矩阵),
a + b > 5 a+b>5 a+b>5
您认为哪种写法更加自然?

2、规则

  广播统一两个数组的维度,具体做法有两步,分别起到统一维度和统一长度的作用。

(1)将低维变为高维

  如  a a a 是二维的,
a = [ 1 2 3 4 5 6 ] a=\left[\begin{matrix} 1&2&3\\ 4&5&6 \end{matrix} \right] a=[142536]
b b b 是一维的,
b = [ 1 2 3 ] b=\left[\begin{matrix}1&2&3\end{matrix}\right] b=[123]
广播首先将 b b b 变为二维数组
b = [ [ 1 2 3 ] ] b=\left[\left[\begin{matrix} 1&2&3 \end{matrix} \right]\right] b=[[123]]
使得 a 、 b a、b ab都是二维数组,

(2)将1维变为n维

现在, a a a 形状为 2 × 3 2\times3 2×3 b b b 形状为 1 × 3 1\times3 1×3,二者形状不同,需要为 b b b 添加新行(复制第0行),把 b b b 的行数也变为2,使得各维的长度也相等。
这里解释一下维度的含义。在计算机中提到维度时,是指描述一个分量需要几个下标,前面(1)中的维度就是如此。在数学中提到维度时,指的是分量的个数,实际上也是各维的长度,如(2)中的维度用法。下面构造 a 、 b a、b ab两个数组如下,
a = [ 0 1 ] ,     b = [ 0 1 2 ] a=\left[\begin{matrix}0\\1\end{matrix}\right], \ \ \ b=\left[\begin{matrix}0&1&2\end{matrix}\right] a=[01],   b=[012]
观察广播的效果

a = torch.tensor([[0],[1]])
print(a)
b = torch.tensor([0,1,2])
print(b)
c = torch.broadcast_tensors(a, b)
print(c)

结果

tensor([[0],
        [1]])
tensor([0, 1, 2])
(tensor([[0, 0, 0],
         [1, 1, 1]]), 
 tensor([[0, 1, 2],
         [0, 1, 2]]))

结果 c c c 有两个数组,分别是对 a 、 b a、b ab 广播的结果。
急转弯,请写出下列代码的执行结果:

a = torch.tensor([[1,2]])
b = torch.tensor([[3],[4]])
c = a*b

分析:* 表示点乘运算,即相同形状的两个矩阵的对应元素相乘的运算。 a a a 的形状为 (1,2), b b b 的形状为 (2,1),
需要用广播统一二者的形状,使得二者的形状都是 (2,2),
a a a 要多一行, b b b 要多一列。通过拷贝, a a a 成为 [[1,2], [1,2]], b b b 成为 [[3,3], [4,4]],完成广播后,现在可以做对应相乘的操作了,结果为

tensor([[3, 6],
        [4, 8]])
3、应用

下面来作一道习题(选自numpy习题100道中的第93题)
93. Consider two arrays A and B of shape (8,3)
and (2,2). How to find rows of A that contain
elements of each row of B regardless of the
order of the elements in B? (★★★)
例如,A、B分别是如下数组

A = torch.Tensor([[4, 9, 3],
	 	          [0, 3, 9],
     		      [7, 3, 7],
        		  [3, 1, 6],
	        	  [6, 9, 8],
	    	      [6, 6, 8],
    	    	  [4, 3, 6],
	    	      [9, 1, 4]])
	        
B = torch.Tensor([[4, 1],
	       		  [9, 9]])

A数组中的第0行[4,9,3]含有B数组两行中的元素、第7行[9,1,4]也含有B数组两行中的元素,所以结果为[0, 7]。

torch.manual_seed(0)
A = torch.randint(10,size=(8,3)).reshape(8,3,1,1)
B = torch.randint(10,size=(2,2))
C = A==B
rows = torch.where(C.any(3).any(1).all(1))[0]
print(rows)

代码运行结果

tensor([0, 7])

代码中,
torch.manual_seed(0) 指定随机数的种子,用于再次产生相同的随机数。A、B是题目要求形状的数组,且A的形状被调整为四维数组。A==B要求按位置比较A、B的对应元素是否相等,由于A、B的形状不同,所以系统用广播将A、B的形状调整为(8,3,2,2)。A中的每个元素被重复4遍,形成(2,2)的数组,分别与B的四个元素相比较,比较结果是逻辑值存于C数组中。C.any(3).any(1).all(1)的理解稍有难度。
先考虑更直观的问题。问题1,如何计算B数组中有没有1?

print(B==1)
tensor([[False,  True],
        [False, False]])

只需要用any来检测所有元素中是否有True即可。

print((B==1).any())
tensor(True)

继续,问题2,如何计算B数组的两行中都有1吗?any(1)的参数1表示方向,与sum是一样的。

print((B==1).any(1))
tensor([ True, False])

Ture表示第0行有1,False表示第1行没有1,两行都有1要求两个结果都是True,用all做判断如下,

print((B==1).any(1).all())
tensor(False)

结果与事实相符。
问题3,如何计算三个数中既有B的0行中的元素,也有B的1行中的元素?考虑A[0]中的三个数4、9、3。

print((A[0]==B))
tensor([[[ True, False],			#B中有4吗?
         [False, False]],

        [[False, False],			#B中有9吗?
         [ True,  True]],

        [[False, False],			#B中有3吗?
         [False, False]]])
print((A[0]==B).any(2))

any(2)的参数是哪一维?这涉及到如何想象四维数组,后面会总结出来,这倒未必是紧要的知识点。any(2)就是上面的6横排,所以得到下面的6个判断结果(每排中有True,结果就为True)分别表示4是B的第0行,但不是第1行元素,9不是B的第0行,但是第1行元素,3不是B的第0行、也不是第1行元素。

tensor([[ True, False],
        [False,  True],
        [False, False]])

这个结果中,左边一列表示有没有B的0列的元素,右边一列表示有没有B的1列的元素,用ayn(0)判断如下

print((A[0]==B).any(2).any(0))
tensor([True, True])

上面的结果表明,4、9、3中有B的0行元素,也有1行元素,再加个all就判断出是否有两行元素(两个都是True)。

print((A[0]==B).any(2).any(0).all())
tensor(True)

问题4,8个元素又如何?

(A==B).any(3).any(1).all(1)

结果中的True代表A的某行满足要求,False则表示不满足要求。

tensor([ True, False, False, False, False, False, False,  True])

加个where是得出满足条件的行的下标,

torch.where(C.any(3).any(1).all(1))
(tensor([0, 7]),)

可见结果是个元组,再放上最后面的[0]就从元组中析出了结果。

  • 9
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值