流畅的Python(十三)-正确重载运算符

一、核心要义

1. 一元运算符重载

2.加法和乘法运算符重载

3.比较运算符重载

4.增量赋值运算符重载

二、代码示例

1、一元运算符重载

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2024/2/25 10:35
# @Author  : Maple
# @File    : 01-一元运算符.py
# @Software: PyCharm

"""
为第10章的Vector类重载运算符 +(__pos__,实现效果: x = +x) 和 - (__neg__,实现效果 x = -x)

"""
import functools
import math
import operator
import reprlib
from array import array


class Vector:

    typecode = 'd'
    shortcut_names = 'xyzt'

    def __init__(self, components):
        self._components = array(self.typecode, components)

    def __repr__(self):
        # 返回的components是str类型
        components = reprlib.repr(self._components)
        components = components[components.find('['):-1]
        return 'Vector({})'.format(components)

    def __str__(self):
        return str(tuple(self))

    def __iter__(self):
        return iter(self._components)

    def __bytes__(self):
        return (bytes([ord(self.typecode)]) + bytes(self._components))

    def __bool__(self):
        return bool(abs(self))

    def __abs__(self):
        return math.sqrt(sum(x * x for x in self._components))

    @classmethod
    def frombytes(cls, octets):
        typecode = chr(octets[0])
        memv = memoryview(octets[1:]).cast(typecode)
        return cls(memv)

    def __len__(self):
        return len(self._components)

    def __getitem__(self, index):
        cls = type(self)
        if isinstance(index, slice):
            return cls(self._components[index])
        elif isinstance(index, numbers.Integral):
            return self._components[index]
        else:
            msg = '{cls.__name__} indices must be integers'
            return TypeError(msg.format(cls=cls))

    def __getattr__(self, item):
        # 只有当v.x实例不存在x属性时,才会调用getattr
        cls = type(self)
        if len(item) == 1:
            position = cls.shortcut_names.find(item)
            if 0 <= position < len(self._components):
                return self._components[position]

        msg = '{.__name__!r} object has not attribute {!r}'
        raise AttributeError(msg.format(cls, item))

    def __setattr__(self, name, value):
        cls = type(self)
        if len(name) == 1:
            # 限制修改'xyzt'单字母属性值
            if name in cls.shortcut_names:
                error = 'readonly attribute {attr_name!r}'
            elif name.islower():
                # 限制修改单字母(a-z)的属性值
                error = "can't set attributes 'a' to 'z' in {cls_name}!r"
            else:
                error = ''

            if error:
                msg = error.format(cls_name=cls, attr_name=name)
                raise AttributeError(msg)

        # 允许修改名字为其它值的属性
        super().__setattr__(name, value)

    def __eq__(self, other):
        # 如果分量太多,下面这种方式效率太低
        # return tuple(self) == tuple(other)
        if len(self) != len(other):
            return False

        for x, y in zip(self, other):
            if x != y:
                return False
        return True

    def __hash__(self):
        # 生成一个迭代器
        hashes = (hash(x) for x in self._components)
        return functools.reduce(operator.xor, hashes, 0)
        # 等价于下面的写法
        # return functools.reduce(lambda x,y : x *y ,hashes,0)

    # - 运算符重载
    def __neg__(self):
        return Vector(-x for x in self)

    # + 运算符重载
    def __pos__(self):
        return  Vector(x for x in self)


if __name__ == '__main__':

    # 1. - 运算符重载测试
    v = Vector([1,2,3,4])
    print(-v) # (-1.0, -2.0, -3.0, -4.0)

    # 2. + 运算符重载测试
    print(v == +v ) # True

2、向量加法运算符重载

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2024/2/25 10:44
# @Author  : Maple
# @File    : 02-重载向量加法运算符.py
# @Software: PyCharm

"""
1.对于序列类型,默认的加法运算行为是,比如 [1,2,3] + [1,1,1] = [1,2,3,1,1,1]
2.但实际我们预期的结果是 [1,2,3] + [1,1,1] = [2,3,4]
"""

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2024/2/25 10:35
# @Author  : Maple
# @File    : 01-一元运算符.py
# @Software: PyCharm
import itertools

"""
为第10章的Vector类重载运算符 +(__pos__,实现效果: x = +x) 和 - (__neg__,实现效果 x = -x)

"""
import functools
import math
import operator
import reprlib
from array import array


class Vector:

    typecode = 'd'
    shortcut_names = 'xyzt'

    def __init__(self, components):
        self._components = array(self.typecode, components)

    def __repr__(self):
        # 返回的components是str类型
        components = reprlib.repr(self._components)
        components = components[components.find('['):-1]
        return 'Vector({})'.format(components)

    def __str__(self):
        return str(tuple(self))

    def __iter__(self):
        return iter(self._components)

    def __bytes__(self):
        return (bytes([ord(self.typecode)]) + bytes(self._components))

    def __bool__(self):
        return bool(abs(self))

    def __abs__(self):
        return math.sqrt(sum(x * x for x in self._components))

    @classmethod
    def frombytes(cls, octets):
        typecode = chr(octets[0])
        memv = memoryview(octets[1:]).cast(typecode)
        return cls(memv)

    def __len__(self):
        return len(self._components)

    def __getitem__(self, index):
        cls = type(self)
        if isinstance(index, slice):
            return cls(self._components[index])
        elif isinstance(index, numbers.Integral):
            return self._components[index]
        else:
            msg = '{cls.__name__} indices must be integers'
            return TypeError(msg.format(cls=cls))

    def __getattr__(self, item):
        # 只有当v.x实例不存在x属性时,才会调用getattr
        cls = type(self)
        if len(item) == 1:
            position = cls.shortcut_names.find(item)
            if 0 <= position < len(self._components):
                return self._components[position]

        msg = '{.__name__!r} object has not attribute {!r}'
        raise AttributeError(msg.format(cls, item))

    def __setattr__(self, name, value):
        cls = type(self)
        if len(name) == 1:
            # 限制修改'xyzt'单字母属性值
            if name in cls.shortcut_names:
                error = 'readonly attribute {attr_name!r}'
            elif name.islower():
                # 限制修改单字母(a-z)的属性值
                error = "can't set attributes 'a' to 'z' in {cls_name}!r"
            else:
                error = ''

            if error:
                msg = error.format(cls_name=cls, attr_name=name)
                raise AttributeError(msg)

        # 允许修改名字为其它值的属性
        super().__setattr__(name, value)

    def __eq__(self, other):
        # 如果分量太多,下面这种方式效率太低
        # return tuple(self) == tuple(other)
        if len(self) != len(other):
            return False

        for x, y in zip(self, other):
            if x != y:
                return False
        return True

    def __hash__(self):
        # 生成一个迭代器
        hashes = (hash(x) for x in self._components)
        return functools.reduce(operator.xor, hashes, 0)
        # 等价于下面的写法
        # return functools.reduce(lambda x,y : x *y ,hashes,0)

    # - 运算符重载
    def __neg__(self):
        return Vector(-x for x in self)

    # + 运算符重载
    def __pos__(self):
        return  Vector(x for x in self)

    def __add__(self,other):
        # 返回两个对象的配对元组迭代器
        try:
            pairs = itertools.zip_longest(self,other,fillvalue=0.0)
            return Vector(a + b for a,b in pairs)
        except TypeError:
            raise NotImplemented

    def __radd__(self, other):
        # 调用方式other.__add__(self)
        # 注意不要漏return
        return self + other


if __name__ == '__main__':

    # 1. - 运算符重载测试
    v = Vector([1,2,3,4])
    print(-v) # (-1.0, -2.0, -3.0, -4.0)

    # 2. + 运算符重载测试
    print(v == +v ) # True

    # 3. 加法运算重载测试
    """基本流程
    1. Vector对象 + 另外一个可迭代对象,正常返回结果
    2. 非Vector类型的可迭代对象  +  Vector对象,首先会调用add, 因为add方法要求第一个对象是Vector类型,所以会报错:can only concatenate tuple (not "Vector") to tuple
       抛出NotImplemented错误
    3. 程序会继续调用__radd__方法,之后就会调用Vector对象.__add__(other),然后正常返回结果
    """
    print(v + (1,2,3,4)) #(2.0, 4.0, 6.0, 8.0)

    print((1,2,3,4) + v) #(2.0, 4.0, 6.0, 8.0)

3、向量乘法运算符重载

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2024/2/25 11:05
# @Author  : Maple
# @File    : 03-重载标量乘法运算符.py
# @Software: PyCharm
import functools
import itertools
import numbers
from _ast import operator
from array import array


class Vector:

    typecode = 'd'
    shortcut_names = 'xyzt'

    def __init__(self, components):
        self._components = array(self.typecode, components)

    def __repr__(self):
        # 返回的components是str类型
        components = reprlib.repr(self._components)
        components = components[components.find('['):-1]
        return 'Vector({})'.format(components)

    def __str__(self):
        return str(tuple(self))

    def __iter__(self):
        return iter(self._components)

    def __bytes__(self):
        return (bytes([ord(self.typecode)]) + bytes(self._components))

    def __bool__(self):
        return bool(abs(self))

    def __abs__(self):
        return math.sqrt(sum(x * x for x in self._components))

    @classmethod
    def frombytes(cls, octets):
        typecode = chr(octets[0])
        memv = memoryview(octets[1:]).cast(typecode)
        return cls(memv)

    def __len__(self):
        return len(self._components)

    def __getitem__(self, index):
        cls = type(self)
        if isinstance(index, slice):
            return cls(self._components[index])
        elif isinstance(index, numbers.Integral):
            return self._components[index]
        else:
            msg = '{cls.__name__} indices must be integers'
            return TypeError(msg.format(cls=cls))

    def __getattr__(self, item):
        # 只有当v.x实例不存在x属性时,才会调用getattr
        cls = type(self)
        if len(item) == 1:
            position = cls.shortcut_names.find(item)
            if 0 <= position < len(self._components):
                return self._components[position]

        msg = '{.__name__!r} object has not attribute {!r}'
        raise AttributeError(msg.format(cls, item))

    def __setattr__(self, name, value):
        cls = type(self)
        if len(name) == 1:
            # 限制修改'xyzt'单字母属性值
            if name in cls.shortcut_names:
                error = 'readonly attribute {attr_name!r}'
            elif name.islower():
                # 限制修改单字母(a-z)的属性值
                error = "can't set attributes 'a' to 'z' in {cls_name}!r"
            else:
                error = ''

            if error:
                msg = error.format(cls_name=cls, attr_name=name)
                raise AttributeError(msg)

        # 允许修改名字为其它值的属性
        super().__setattr__(name, value)

    def __eq__(self, other):
        # 如果分量太多,下面这种方式效率太低
        # return tuple(self) == tuple(other)
        if len(self) != len(other):
            return False

        for x, y in zip(self, other):
            if x != y:
                return False
        return True

    def __hash__(self):
        # 生成一个迭代器
        hashes = (hash(x) for x in self._components)
        return functools.reduce(operator.xor, hashes, 0)
        # 等价于下面的写法
        # return functools.reduce(lambda x,y : x *y ,hashes,0)

    # - 运算符重载
    def __neg__(self):
        return Vector(-x for x in self)

    # + 运算符重载
    def __pos__(self):
        return  Vector(x for x in self)

    def __add__(self,other):
        # 返回两个对象的配对元组迭代器
        try:
            pairs = itertools.zip_longest(self,other,fillvalue=0.0)
            return Vector(a + b for a,b in pairs)
        except TypeError:
            raise NotImplemented

    def __radd__(self, other):
        # 调用方式other.__add__(self)
        # 注意不要漏return
        return self + other

    def __mul__(self, scalar):
        # scalar参数的值要是数字
        if isinstance(scalar,numbers.Real):
            return Vector(n * scalar for n in self)
        else:
            return NotImplemented

    def __rmul__(self, scalar):
        return self * scalar


if __name__ == '__main__':

    # 乘法运算符测试
    v = Vector([1,2,3])
    print(v * 2) # # (2.0, 4.0, 6
    print(v * True) # (1.0, 2.0, 3.0)

    from fractions import  Fraction
    print(v * Fraction(1,3)) # (0.3333333333333333, 0.6666666666666666, 1.0)

4、比较运算符重载

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2024/2/25 12:48
# @Author  : Maple
# @File    : 04-比较运算符.py
# @Software: PyCharm

import functools
import itertools
import math
import numbers
from _ast import operator
from array import array


class Vector:

    typecode = 'd'
    shortcut_names = 'xyzt'

    def __init__(self, components):
        self._components = array(self.typecode, components)

    def __repr__(self):
        # 返回的components是str类型
        components = reprlib.repr(self._components)
        components = components[components.find('['):-1]
        return 'Vector({})'.format(components)

    def __str__(self):
        return str(tuple(self))

    def __iter__(self):
        return iter(self._components)

    def __bytes__(self):
        return (bytes([ord(self.typecode)]) + bytes(self._components))

    def __bool__(self):
        return bool(abs(self))

    def __abs__(self):
        return math.sqrt(sum(x * x for x in self._components))

    @classmethod
    def frombytes(cls, octets):
        typecode = chr(octets[0])
        memv = memoryview(octets[1:]).cast(typecode)
        return cls(memv)

    def __len__(self):
        return len(self._components)

    def __getitem__(self, index):
        cls = type(self)
        if isinstance(index, slice):
            return cls(self._components[index])
        elif isinstance(index, numbers.Integral):
            return self._components[index]
        else:
            msg = '{cls.__name__} indices must be integers'
            return TypeError(msg.format(cls=cls))

    def __getattr__(self, item):
        # 只有当v.x实例不存在x属性时,才会调用getattr
        cls = type(self)
        if len(item) == 1:
            position = cls.shortcut_names.find(item)
            if 0 <= position < len(self._components):
                return self._components[position]

        msg = '{.__name__!r} object has not attribute {!r}'
        raise AttributeError(msg.format(cls, item))

    def __setattr__(self, name, value):
        cls = type(self)
        if len(name) == 1:
            # 限制修改'xyzt'单字母属性值
            if name in cls.shortcut_names:
                error = 'readonly attribute {attr_name!r}'
            elif name.islower():
                # 限制修改单字母(a-z)的属性值
                error = "can't set attributes 'a' to 'z' in {cls_name}!r"
            else:
                error = ''

            if error:
                msg = error.format(cls_name=cls, attr_name=name)
                raise AttributeError(msg)

        # 允许修改名字为其它值的属性
        super().__setattr__(name, value)

    def __eq__(self, other):
        # 下面这种写法存在一个问题,比如Vector([1,2,3]) 和(1,2,3)会被判断成相等,但大部分情况下,我们应该是预期不相等的结果
        # return tuple(self) == tuple(other)

        # 改写:另外一个对象必须是Vector对象,才可能相等
        if isinstance(other,Vector):
            return (len(self) == len(other) and all( x == y for x in self for y in other))
        else:
            return NotImplemented

    def __hash__(self):
        # 生成一个迭代器
        hashes = (hash(x) for x in self._components)
        return functools.reduce(operator.xor, hashes, 0)
        # 等价于下面的写法
        # return functools.reduce(lambda x,y : x *y ,hashes,0)

    # - 运算符重载
    def __neg__(self):
        return Vector(-x for x in self)

    # + 运算符重载
    def __pos__(self):
        return  Vector(x for x in self)

    def __add__(self,other):
        # 返回两个对象的配对元组迭代器
        try:
            pairs = itertools.zip_longest(self,other,fillvalue=0.0)
            return Vector(a + b for a,b in pairs)
        except TypeError:
            raise NotImplemented

    def __radd__(self, other):
        # 调用方式other.__add__(self)
        # 注意不要漏return
        return self + other

    def __mul__(self, scalar):
        # scalar参数的值要是数字
        if isinstance(scalar,numbers.Real):
            return Vector(n * scalar for n in self)
        else:
            return NotImplemented

    def __rmul__(self, scalar):
        return self * scalar


if __name__ == '__main__':

    v = Vector([1,2,3])
    print(v == (1,2,3)) # False

5、增量赋值运算符重载

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2024/2/25 12:54
# @Author  : Maple
# @File    : 05-增量赋值运算符.py
# @Software: PyCharm
import abc


class Tombola(abc.ABC):
    @abc.abstractmethod
    def load(self,iterable):
        """从可迭代对象中加载元素"""

    @abc.abstractmethod
    def pick(self):
        """随机删除元素,然后将其返回

        如果实例为空,这个方法应该抛出LookupError
        """
    def loaded(self):
        """如果至少有一个元素,则返回True,否则返回False"""
        return bool(self.inspect())


    def inspect(self):
        """返回一个有序元组,由当前元素构成"""
        items = []
        while True:
            try:
                items.append(self.pick())
            except LookupError:
                break
        self.load(items)
        return tuple(sorted(items))


class AddableBingoCage(Tombola):


    def __init__(self,items):
        self._items = list(items)

    def pick(self):
        return  self._items.pop()


    def __add__(self,other):
        # 另外一个对象必须是Tombola类型,才可以使用 + 运算符
        if isinstance(other,Tombola):
            return AddableBingoCage(self.inspect() + other.inspect())
        else:
            return NotImplemented

    def __iadd__(self, other):
        if isinstance(other,Tombola):
            other_iterable = other.inspect()
        else:
            try:
                other_iterable = iter(other)
            except TypeError:
                cls_name = type(self).__name__
                msg = "right operand in += must be {!r} or an iterable"
                raise TypeError(msg.format(cls_name))
        self.load(other_iterable)
        # iadd是就地改变对象,所以要返回改变之后的对象本身
        return self

    def load(self,iterable):
        for i in iterable:
            self._items.append(i)

    def __iter__(self):
        return (x for x in self._items)

    # def __str__(self):
    #     return str(tuple(self))

    def __repr__(self):
        return 'Vector{}'.format(tuple(self))


    def inspect(self):
        return self._items


if __name__ == '__main__':

     # 1.就地修改(iadd)测试
     ab = AddableBingoCage([1,2,3])
     ab += [4,5]
     print(ab) # Vector(1, 2, 3, 4, 5)

     # ab += 1 #TypeError: right operand in += must be 'AddableBingoCage' or an iterable

     # 2.相加(add)测试
     #ab_new = ab + [1,2] # TypeError: unsupported operand type(s) for +: 'AddableBingoCage' and 'list'

     ab2 = AddableBingoCage([10,11])
     ab_new = ab + ab2
     print(ab_new) # Vector(1, 2, 3, 4, 5, 10, 11)

  • 9
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值