以前面我们定义过的Vector向量类为例,讨论重载运算符。
运算符重载基础
运算符重载本质上是函数调用。许多面向对象语言都支持运算符重载,如C++,如果使用得当,API会更好用,代码也会更加易于阅读。python也支持运算符重载,只是增加了一些限制:
* 不能重载内置类型的运算符
* 不能新建运算符,只能重载已有的
* 某些运算符不能重载 --is、and、or和not(位运算符&、^、|可以)
我们首先在Vector重载四个算术运算符:一元运算符-和+,中缀运算符+和*。
一元运算符
- (__neg__)
取负运算符,如果x是-2,那么-x是2。
+ (__pos__)
取正运算符,大部分情况都是 x == +x。
~ (__invert__)
对整数按位取反,定位为 ~x == -(x+1),如果x是-2,那么~x是1。
python把abs()函数也列为一元运算符,它对应的特殊方法是__abs__。实现运算符的重载也就是实现这些特殊方法即可,这些方法直接输self一个参数,代码如下:
1 from array import array 2 class Vector: 3 """ 4 自定义一个向量类 5 """ 6 typecode = 'i' 7 def __init__(self, items): 8 self._components = array(self.typecode, items) 9 10 def __iter__(self): 11 return (i for i in (self._components)) 12 13 def __str__(self): 14 return 'Vector' + str(tuple(self)) #已经实现了__iter__方法,self实例是可迭代对象 15 16 def __abs__(self): 17 return math.sqrt(sum(x*x for x in self)) 18 19 def __neg__(self): 20 return Vector(-x for x in self) 21 22 def __pos__(self): 23 return Vector(self) 24 25 def __invert__(self): 26 return Vector(~x for x in self) 27 28 v1 = Vector(range(10)) 29 print(v1) 30 print(abs(v1)) 31 32 v2 = -v1 33 v3 = ~v1 34 print(v2) 35 print(v3)
重载向量加法运算符+
两个向量相加得到一个新向量,我们不取严格意义的数学上的向量相加,我们运行两个维度不同的向量相加,得到的新向量维度取较大的维度。
重载加法运算符就要实现__add__特殊方法:
1 from array import array 2 from itertools import zip_longest 3 class Vector: 4 """ 5 自定义一个向量类 6 """ 7 typecode = 'i' 8 def __init__(self, items): 9 self._components = array(self.typecode, items) 10 11 def __iter__(self): 12 return (i for i in (self._components)) 13 14 def __str__(self): 15 return 'Vector' + str(tuple(self)) #已经实现了__iter__方法,self实例是可迭代对象 16 17 def __abs__(self): 18 return math.sqrt(sum(x*x for x in self)) 19 20 def __neg__(self): 21 return Vector(-x for x in self) 22 23 def __pos__(self): 24 return Vector(self) 25 26 def __invert__(self): 27 return Vector(~x for x in self) 28 29 def __add__(self, other): 30 """ 31 两个向量相加,向量维度取最长的一个 32 """ 33 return Vector(x+y for x, y in zip_longest(self, other, fillvalue=0)) 34 v1 = Vector(range(10)) 35 print(v1) 36 print(abs(v1)) 37 38 v2 = -v1 39 v3 = ~v1 40 print(v2) 41 print(v3) 42 43 v4 = v1 + v3 44 print(v4) 45 46 v5 = v1 + Vector([1,3,5,8,-6]) 47 print(v5)
zip会把两个序列类型的元素放在一起组成元组,但是回去序列个数少的为准,多余的元素就被丢弃;二itertools.zip_longest以元素个数多的为准,fillvalue可以指定缺少元素时使用的默认值。
ps:无论是一元运算符还是中缀运算符都不能修改原对象,都是返回一个新的对象。只有增量表达式可能会例外。
v1+v2等价于v1.__add__(v2),所以要求v1的类必须实现__add__方法,v2却不必。
1 v6 = v1 + (10,20,30) # OK v1.__add__((10,20,30)) 2 print(v6) 3 4 v7 = (10, 20, 30) + v1 # ERROR, tuple没有__add__方法 5 print(v7)
可以为Vector增加__radd__方法即可。
1 def __radd__(self, other): 2 return self+other
单就表达式a+b来说,解释器会执行如下操作:
1)、如果a有__add__方法,且返回值不是NotImplemented,那么调用a.__add__(b)返回结果
2)、如果a没有__add__方法,或者调用__add__方法返回NotImplemented,检查b有没有__radd__方法,如果有,且返回值不是NotImplemented,那么调用b.__radd__(a),返回结果
3)、如果b没有__radd__方法,或者__radd__方法返回NotImplemented,那么抛出TypeError异常
__radd__是__add__的反向版本,又叫__add__的右向特殊方法,因为他在右操作数上调用。
__rsub__和__sub__的关系也是一样。
不要把NotImplemented和NotImplementedError混淆。NotImplemented是个值,和None一样,可以作为返回值返回,而NotImplementedError是异常,使用raise抛出。
__radd__直接借助__add__实现是可以的,任何满足交换律的类型都可以这么做,但是向幂运算、除法运算、取模运算等不可以。
Vector类对象可以加可迭代对象,如果是不可迭代对象,就会抛出TypeError错误:
1 v1 + 3 2 """ 3 运行结果 4 File "E:\test.py", line 1168, in <module> 5 v1 + 3 6 File "E:\test.py", line 1142, in __add__ 7 return Vector(x+y for x, y in zip_longest(self, other, fillvalue=0)) 8 TypeError: zip_longest argument #2 must support iteration 9 """
如果是可迭代对象,但是类型是浮点型也会抛出TypeError错误:
1 v1 + (1.0, 2.0, 3.0) 2 """ 3 运行结果 4 File "E:\test.py", line 1118, in __init__ 5 self._components = array(self.typecode, items) 6 TypeError: integer argument expected, got float 7 """
但是这样的错误消息是由于类型问题导致的,为了遵守"鸭子类型"的精髓,我们不能测试other的类型或者其包含的元素的类型,而是要捕获TypeError异常,然后返回NotImplemented, 如果左操作数的__add__方法不能存在或者返回NotImplemented且右操作数的__radd__方法也不存在或者返回NotImplemented,那么python才会抛出TypError异常,返回一个错误提示如“unsupported operand type(s) for +: Vector and str”。
正确的实现如下:
1 def __add__(self, other): 2 """ 3 两个向量相加,向量维度取最长的一个 4 """ 5 try: 6 return Vector(x+y for x, y in zip_longest(self, other, fillvalue=0)) 7 except TypeError: 8 return NotImplemented 9 10 def __radd__(self, other): 11 return self+other
重载标量乘法运算符*
Vector([1,2,3])*x表示什么呢?如果x是数字,那么结果是个新的Vector实例,实例的每个分量都会乘上x:
1 v1 = Vector([1, 2, 3]) 2 v2 = v1*2 3 print(v2) #Vector(2, 4, 6)
如果x也是一个向量,即两个向量相乘,叫做两个向量的点积。如果把一个向量看成行向量,另一个当成列向量,那么就是向量乘法。NumPy等库目前的做法是,不重载这两种意义的*,只用*计算标量积,计算点积用numpy.dot()函数计算。
如果我们自己实现标量积就要实现__mul__和__rmul__,代码如下:
1 from array import array 2 from itertools import zip_longest 3 import numbers, fractions 4 class Vector: 5 """ 6 自定义一个向量类 7 """ 8 typecode = 'd' 9 def __init__(self, items): 10 self._components = array(self.typecode, items) 11 12 def __iter__(self): 13 return (i for i in (self._components)) 14 15 def __str__(self): 16 return 'Vector' + str(tuple(self)) #已经实现了__iter__方法,self实例是可迭代对象 17 18 def __abs__(self): 19 return math.sqrt(sum(x*x for x in self)) 20 21 def __neg__(self): 22 return Vector(-x for x in self) 23 24 def __pos__(self): 25 return Vector(self) 26 27 def __invert__(self): 28 return Vector(~x for x in self) 29 30 def __add__(self, other): 31 """ 32 两个向量相加,向量维度取最长的一个 33 """ 34 try: 35 return Vector(x+y for x, y in zip_longest(self, other, fillvalue=0)) 36 except TypeError: 37 return NotImplemented 38 def __radd__(self, other): 39 return self+other 40 41 def __mul__(self, scalar): 42 """ 43 计算标量积 44 """ 45 if not isinstance(scalar, numbers.Real): 46 return NotImplemented 47 return Vector(scalar*x for x in self) 48 49 def __rmul__(self, scalar): 50 return self*scalar 51 v1 = Vector(range(10)) 52 print(v1) 53 print('------------') 54 v2 = v1*3 55 print(v2) 56 57 v3 = 2*v2 58 print(v3) 59 60 v4 = False*v2 61 print(v4) 62 63 v5 = v2*fractions.Fraction(3, 4) 64 print(v5)
在这里使用了跟重载加法运算符时不一样的办法,使用“白鹅类型”,即用instance检查scalar的类型。
scalar允许是整型、浮点型、bool型、甚至是fractions.Fraction实例,但不能是复数。
除此之外,还有其他的运算符:
还有没有列出的比较运算符。(小于<, 小于等于<=等等)
ps:点积运算符是python3.5新增加的。在旧版本中并不支持。
比较运算符
python解释器对众多比较运算符(==、!=、>、<、>= 、<=)的处理跟上面的类似,不过在两个方面有重大区别。
* 正向和方向调用使用的是同一系列方法。例如,对==来说,正向和反向调用都是调用__eq__方法,只是把参数对调了;而正向的__gt__方法调用的是反向的__lt__方法,并把参数对调。
* 对==和!=来说,如果反向调用失败,python会比较对象的id,而不是抛出TypeError。
如果正向方法返回NotImplemented的话,调用反向方法。
现在我们来实现一下Vector类的==操作。
1 from array import array 2 from itertools import zip_longest 3 import numbers, fractions 4 class Vector: 5 """ 6 自定义一个向量类 7 """ 8 typecode = 'd' 9 def __init__(self, items): 10 self._components = array(self.typecode, items) 11 12 def __iter__(self): 13 return (i for i in (self._components)) 14 15 def __len__(self): 16 return len(self._components) 17 18 def __str__(self): 19 return 'Vector' + str(tuple(self)) #已经实现了__iter__方法,self实例是可迭代对象 20 21 def __abs__(self): 22 return math.sqrt(sum(x*x for x in self)) 23 24 def __neg__(self): 25 return Vector(-x for x in self) 26 27 def __pos__(self): 28 return Vector(self) 29 30 def __invert__(self): 31 return Vector(~x for x in self) 32 33 def __add__(self, other): 34 """ 35 两个向量相加,向量维度取最长的一个 36 """ 37 try: 38 return Vector(x+y for x, y in zip_longest(self, other, fillvalue=0)) 39 except TypeError: 40 return NotImplemented 41 def __radd__(self, other): 42 return self+other 43 44 def __mul__(self, scalar): 45 """ 46 计算标量积 47 """ 48 if not isinstance(scalar, numbers.Real): 49 return NotImplemented 50 return Vector(scalar*x for x in self) 51 52 def __rmul__(self, scalar): 53 return self*scalar 54 55 def __eq__(self, other): 56 return len(self) == len(other) and \ 57 all(x==y for x, y in zip(self, other)) 58 59 v1 = Vector(range(10)) 60 print(v1) 61 print(v1==[i for i in range(10)]) 62 print(v1==[i for i in range(8)]) 63 print(v1==Vector((1,2,3,4,5,6,7,8))) 64 print(v1==Vector((1,2,3,4,5,6,7,8,8))) 65 print((0,1,2,3,4,5,6,7,8,9) == v1)
对于向量类对象和元组的比较可能并不尽如人意,我们可以加点类型检查:
1 def __eq__(self, other): 2 if not isinstance(other, Vector): 3 return NotImplemented 4 return len(self) == len(other) and \ 5 all(x==y for x, y in zip(self, other))
!=运算符我们无需定义,python会自动把对==运算符运行的结果取反。
增量赋值运算符
Vector已经支持增量运算符了。
1 va = Vector([1.0, 2.0, 3.0]) 2 vb = va 3 print(id(vb), id(va)) #vb, vb引用同一个对象 4 va+=[7,8,9,10,11] 5 print(va) 6 print(id(vb), id(va)) #vb还是原对象,va是新创建的对象
如果一个类没有实现就地运算符,增量运算符只是语法糖:a += b的作用跟a = a+b是一样的,会创建一个新对象。如果实现了就地运算符方法,例如__iadd__,那么 a += b不会创建新对象。
1 def __iadd__(self, other): 2 self._components = array(self.typecode,map(lambda pair:sum(pair),zip_longest(self, other, fillvalue=0))) 3 return self 4 va = Vector([1.0, 2.0, 3.0]) 5 vb = va 6 print(id(vb), id(va)) #vb, vb引用同一个对象 7 va+=[7,8,9,10,11] 8 print(va) 9 print(id(vb), id(va)) #vb,va的id一样,证明并没有创建新对象
__iadd__方法最后一定要把对象自身返回,如果没有返回self的话,va += vb,那么va就变成了None,因为如果函数没有返回值,python默认返回None。
1 def __iadd__(self, other): 2 self._components = array(self.typecode, map(lambda pair:sum(pair), zip_longest(self, other, fillvalue=0))) 3 #return self 4 va = Vector([1.0, 2.0, 3.0]) 5 vb = va 6 print(id(vb), id(va)) #vb, vb引用同一个对象 7 va+=[7,8,9,10,11] 8 print(va) 9 print(id(vb), id(va)) 10 print(va) #None
总结:
* python支持对运算符重载,但不包括内置的类型的运算符以及is、and、or、not。
* 一元运算符和中缀运算符的重载、正向方法和反向方法
* 使用鸭子类型捕获TypeError异常,或者使用isinstance检查类型,isinstance不能用具体类,而应该使用抽象基类,因为后续用户自定义的类可以生命为抽象基类的子类或是注册为虚拟子类
* 比较运算符和增量运算符的重载