《Effective Python 2nd》——类与接口

引言

Python的类与继承机制,让我们很容易就能用对象表达出程序应有的行为,而且可以逐渐改进并扩展这些行为。

#37 用组合起来的类来实现多层结构,不要用嵌套的内置类型

Python内置的字典类型,很适合维护对象在生命期内的动态内部状态。所谓动态的,是指我们无法获知那套状态会用到哪些标识符。
例如,如果要用成绩册(Gradebook)记录学生的分数,而我们又没有办法提前确定这些学生的名字,那么受到记录的每位学生与各自的分数,对于Gradebook对象来说,就属于动态的内部状态。

为了实现这个需求,笔者定义了下面这样一个类。

class SimpleGradebook:
    def __init__(self):
        self._grades = {}

    def add_student(self, name):
        self._grades[name] = []

    def report_grade(self, name, score):
        self._grades[name].append(score)

    def average_grade(self, name):
        grades = self._grades[name]
        return sum(grades) / len(grades)

这个类用起来很简单。

book = SimpleGradebook()
book.add_student('Isaac Newton')
book.report_grade('Isaac Newton', 90)
book.report_grade('Isaac Newton', 95)
book.report_grade('Isaac Newton', 85)

print(book.average_grade('Isaac Newton'))

90.0

字典与相关的内置类型用起来很方便,但同时也容易遭到滥用导致代码出问题。例如,我们现在要扩展这个SimpleGradebook类的功能,让它按照科目保存成绩,而不是把所有科目的成绩存在一起。
通过修改_grades字典的用法,使它必须把键(学生的名字)与另一个字典相对应。那份小字典以各科的名称作键与一份列表对应起来,以保存学生在这一科的全部考试成绩。

from collections import defaultdict

class BySubjectGradebook:
    def __init__(self):
        self._grades = {}                       # 外面的字典

    def add_student(self, name):
        self._grades[name] = defaultdict(list)  # 里面的字典

程序写到现在,还算比较直观。但是接下来,我们要编写report_grade方法来记录某位学生在某科目上的一次成绩,并且要写average_grade方法计算某位学生的所有科目的平均成绩。这两个方法写起了就稍微有点复杂了。

from collections import defaultdict

class BySubjectGradebook:
    def __init__(self):
        self._grades = {}                       # 外面的字典

    def add_student(self, name):
        self._grades[name] = defaultdict(list)  # 里面的字典
        
    def report_grade(self, name, subject, grade):
        by_subject = self._grades[name]
        grade_list = by_subject[subject]
        grade_list.append(grade)

    def average_grade(self, name):
        by_subject = self._grades[name]
        total, count = 0, 0
        for grades in by_subject.values():
            total += sum(grades)
            count += len(grades)
        return total / count

扩展后的类,用起来依然比较容易。

book = BySubjectGradebook()
book.add_student('Albert Einstein')
book.report_grade('Albert Einstein', 'Math', 75)
book.report_grade('Albert Einstein', 'Math', 65)
book.report_grade('Albert Einstein', 'Gym', 90)
book.report_grade('Albert Einstein', 'Gym', 95)
print(book.average_grade('Albert Einstein'))
81.25

现在假设需求又变了,我们还要记录每次考试在科目里的权重。实现这项功能的一种办法就是改变里面那个小字典的用法,让它不要把成绩直接添加到与键名(科目名称)相对应的那份列表里,而是先用成绩与权重构成元组,然后把(score,weight)形式的元组添加到列表里。

class WeightedGradebook:
    def __init__(self):
        self._grades = {}

    def add_student(self, name):
        self._grades[name] = defaultdict(list)

    def report_grade(self, name, subject, score, weight):
        by_subject = self._grades[name]
        grade_list = by_subject[subject]
        grade_list.append((score, weight))

report_grade方法改起来似乎挺简单的,但是average_grade方法就比较难懂了。

class WeightedGradebook:
    def __init__(self):
        self._grades = {}

    def add_student(self, name):
        self._grades[name] = defaultdict(list)

    def report_grade(self, name, subject, score, weight):
        by_subject = self._grades[name]
        grade_list = by_subject[subject]
        grade_list.append((score, weight))
        
    def average_grade(self, name):
        by_subject = self._grades[name]

        score_sum, score_count = 0, 0
        for subject, scores in by_subject.items():
            subject_avg, total_weight = 0, 0
            for score, weight in scores:
                subject_avg += score * weight
                total_weight += weight

            score_sum += subject_avg / total_weight
            score_count += 1

        return score_sum / score_count

这个类用起来也变得困难了。

book = WeightedGradebook()
book.add_student('Albert Einstein')
book.report_grade('Albert Einstein', 'Math', 75, 0.05)
book.report_grade('Albert Einstein', 'Math', 65, 0.15)
book.report_grade('Albert Einstein', 'Math', 70, 0.80)
book.report_grade('Albert Einstein', 'Gym', 100, 0.40)
book.report_grade('Albert Einstein', 'Gym', 85, 0.60)
print(book.average_grade('Albert Einstein'))
80.25

如果遇到的是类似这种比较复杂的需求,那么不要再嵌套字典、元组、集合、列表等内置的类型了,而是应该编写一批新类并让这些类形成一套体系。

只要发现记录内部状态的代码开始变得复杂起来,就应该及时把这些代码拆分到多个类里。这样可以定义良好的接口,并且能够合理地封装数据。这种写法可以在接口与具体实现之间创建一层抽象。

把多层嵌套的内置类型重构为类体系

有很多种办法可以实现重构。我们这里采用的办法是,先从依赖树的最底层做起,也就是考虑怎么记录某科目的单次考试成绩与权重。我们以元组的形式来保存单次考试成绩,并且把这种元素添加到列表里面。

grades = []grades.append((95, 0.45))grades.append((85, 0.55))total = sum(score * weight for score, weight in grades)total_weight = sum(weight for _, weight in grades)average_grade = total / total_weightprint(average_grade)
89.5

汇总每次考试的权重之和total_weight时,不需要关注具体的成绩。
这种写法的问题在于,元组里的元素只能通过位置区分。例如,如果还要把老师的评语记在每次考试的成绩旁边,那么早前使用二元组的那些代码就全都需要修改,因为现在必须采用包含三个元素的元组才行。

grades = []grades.append((95, 0.45, 'Great job'))grades.append((85, 0.55, 'Better next time'))total = sum(score * weight for score, weight, _ in grades)total_weight = sum(weight for _, weight, _ in grades)average_grade = total / total_weightprint(average_grade)
89.5

元组拖得太长,就跟字典套得太深一样,都不好维护。所以只要发现元组里的元素超过两个,就应该考虑其他办法了。

Python有个命名元组(namedtuple)类型,恰好可以满足这样的需求,这种命名元组类型很容易定义出小型的表示不可变的数据。

from collections import namedtupleGrade = namedtuple('Grade', ('score', 'weight'))

这样的类,既可以通过位置参数构造,也可以用关键字参数来创建。每个属性都有名字,可以根据属性名称访问字段,如果将来需求发生变化,也很容易把这种结构改成普通的类。

namedtuple的局限

  1. 实例的属性值仍然可以通过数字下标与迭代来访问,如果有人通过这种方式访问这些属性,将来不太容易把它转换为普通的类。

有了叫作Grade的命名元组,我们就你可以写出表示科目的Subject类。

class Subject:
    def __init__(self):
        self._grades = []

    def report_grade(self, score, weight):
        self._grades.append(Grade(score, weight))

    def average_grade(self):
        total, total_weight = 0, 0
        for grade in self._grades:
            total += grade.score * grade.weight
            total_weight += grade.weight
        return total / total_weight

然后,就可以写一个表示学生的Student类,用它来记录某位学生各科目的考试成绩。

class Student:
    def __init__(self):
        self._subjects = defaultdict(Subject)

    def get_subject(self, name):
        return self._subjects[name]

    def average_grade(self):
        total, count = 0, 0
        for subject in self._subjects.values():
            total += subject.average_grade()
            count += 1
        return total / count

最后,写这样一个表示成绩册的Gradebook容器类,把每位学生的名字与表示这位学生的Student对象关联起来。

class Gradebook:
    def __init__(self):
        self._students = defaultdict(Student)

    def get_student(self, name):
        return self._students[name]

这些类所占的篇幅虽然比原来那种写法长了一倍,但理解起来却要容易得多。

book = Gradebook()
albert = book.get_student('Albert Einstein')
math = albert.get_subject('Math')
math.report_grade(75, 0.05)
math.report_grade(65, 0.15)
math.report_grade(70, 0.80)
gym = albert.get_subject('Gym')
gym.report_grade(100, 0.40)
gym.report_grade(85, 0.60)
print(albert.average_grade())
80.25

#38 让简单的接口接受函数,而不是类的实例

Python有许多内置的API,都允许我们传入某个函数来定制它的行为。这种函数可以叫作挂钩(hook),API在执行的过程中,会回调这些挂钩函数。

例如,list类型的sort方法就带有可选的key参数,如果指定了这个参数,那么它就会安装你提供的挂钩函数来决定列表中每个元素的先后顺序。

names = ['Socrates', 'Archimedes', 'Plato', 'Aristotle']
names.sort(key=len)
print(names)
['Plato', 'Socrates', 'Aristotle', 'Archimedes']

在Python中,许多挂钩都是无状态的函数,带有明确的参数与返回值。挂钩用函数来描述,要比定义成类更简单。用作挂钩的函数与别的函数一样,都是Python里的头等对象,即,这些函数与方法可以像Python中其他值那样传递与引用。

例如,我们要定制defaultdict类的行为。这种数据结构允许调用者提供一个函数,用来在键名缺失的情况下,创建与这个键相对应的值。只要字典发现访问的键不存在,就会触发这个函数,以返回应该与键相关联的默认值。
下面定义一个log_missing函数作为键名缺失时的挂钩,该函数总是会把这种键的默认值设为0。

def log_missing():    print('Key added')    return 0

下面这段代码通过定制的defaultdict字典,把increments列表里面描述的增量添加到current这个普通字典所提供的初始量上面。

from collections import defaultdictcurrent = {'green': 12, 'blue': 3}increments = [    ('red', 5),    ('blue', 17),    ('orange', 9),]result = defaultdict(log_missing, current)print('Before:', dict(result))# 有两个键没在字典中,因此触发了两次for key, amount in increments:    result[key] += amountprint('After: ', dict(result))
Before: {'green': 12, 'blue': 3}Key addedKey addedAfter:  {'green': 12, 'blue': 20, 'red': 5, 'orange': 9}

通过log_missing这样的挂钩函数,我们很容易构建出便于测试的API,这种API可以把挂钩所实现的附加效果与数据本身所应具备的确定行为分开。

例如,假设我们要在传给defaultdict的挂钩里面,统计它总共遇到了多少次键名缺失的情况。要实现这项功能,其中一个办法是采用有状态的闭包。下面就定义一个辅助函数,把missing闭包当作挂钩传给defaultdict字典,以便为缺失的键提供默认值。

def increment_with_report(current, increments):    added_count = 0    def missing():        nonlocal added_count  # 有状态的闭包        added_count += 1        return 0    result = defaultdict(missing, current)    for key, amount in increments:        result[key] += amount    return result, added_count

运行这个辅助函数处理前面的数据,可以得到预期的结果。统计键名缺失次数所用的added_count状态是由missing挂钩维护的,这体现了把简单函数传给接口的另一好处,也就是方便稍后添加新的功能,因为我们可以把实现这项功能所用的状态隐藏在这个简单的闭包里面。

result, count = increment_with_report(current, increments)assert count == 2

与无状态的闭包函数相比,用有状态的闭包作为挂钩写出来的代码会难懂一些。为了让代码更清晰,可以专门定义一个小类,把原本由闭包所维护的状态给封装起来。

class CountMissing:    def __init__(self):        self.added = 0    def missing(self):        self.added += 1        return 0

在Python中,方法与函数都是头等的对象,因此可以直接通过对象引用它所属的CountMissing类里的missing方法,并把这个方法传给defaultdict充当挂钩,让字典可以用这个挂钩制作默认值。在Python中,这种通过对象实例而引用的方法,很容易就能通过参数传给API当挂钩函数使用。

counter = CountMissing()result = defaultdict(counter.missing, current)  # 方法引用for key, amount in increments:    result[key] += amountassert counter.added == 2

为了让这个类的意义更加明确,可以给它定义名为__call__的特殊方法。这会让这个类的对象能够像函数那样得到调用。同时,也让内置的callable函数能够针对这种实例返回True值,用以表示这个实例与普通的函数或方法类似,都是可调用的。
凡是能够在后面加()执行的对象,都叫作callable

class BetterCountMissing:    def __init__(self):        self.added = 0    def __call__(self):        self.added += 1        return 0counter = BetterCountMissing()assert counter() == 0assert callable(counter)

下面,就用这样的BetterCountMissing实例给defaultdict当挂钩。

counter = BetterCountMissing()result = defaultdict(counter, current) # 基于__call__for key, amount in increments:    result[key] += amountassert counter.added == 2

如果某个类定义了__call__特殊方法,那么它的实例就可以像普通的Python函数那样调用。

#39 通过@classmethod多态来构造同一体系中的各类对象

在Python中,不仅对象支持多态,类也支持多态。

这里说的对象支持多态,可理解为在超类对象上面调用实例方法,实际触发的是子类对象的同名实例方法;
类支持多态,可理解为在超类上面调用类方法,实际触发的是子类的同名类方法。

多态机制使同一体系中的多个类可以按照各自独有的方式来实现同一个方法,这意味着这些类都可以满足同一套接口,或者都可以当作某个抽象类来使用,同时,它们又能在这个前提下,实现各自的功能。

例如,要实现一套MapReduce(映射-归纳/映射-化简)流程,并且以一个通用的类来表示输入数据。于是,我们定义这样一个InputData类,并把read方法留给子类去实现。

class InputData:    def read(self):        raise NotImplementedError

然后,编写一个具体的InputData子类,例如,可以从磁盘文件中读取数据的PathInputData类。

class PathInputData(InputData):    def __init__(self, path):        super().__init__()        self.path = path    def read(self):        with open(self.path) as f:            return f.read()

通用的InputData类以后可能会有很多个像PathInputData这样的子类,每个子类都会实现标准的read接口,并按照各自的方式把需要处理的数据读取过来。

除了输入数据要通用,我们还想让处理MapReduce任务的工作节点也能有一套通用的抽象接口,这样不同的Worker就可以通过这套标准的接口来消耗输入数据。

class Worker:    def __init__(self, input_data):        self.input_data = input_data        self.result = None    def map(self):        raise NotImplementedError    def reduce(self, other):        raise NotImplementedError

下面,我们定义一种具体的Worker子类,使它按照特定的方式实现MapReduce。这里统计每份数据里的换行符个数,然后把所有的统计值汇总起来。

class LineCountWorker(Worker):    def map(self):        data = self.input_data.read()        self.result = data.count('\n')    def reduce(self, other):        self.result += other.result

这样似乎不错,但是如何把这些组件拼接起来。输入数据与工作节点都有各自的类体系,而且这两套体系也抽象出了合理的接口。
然后,它们都必须落实到具体的对象上面,只有构造除了具体对象,才能写出有用的程序。
最简单的办法,是编写几个辅助函数,手动构建这些对象,并把它们连接起来。例如,可以采用下面的辅助函数读取目录中的内容,并给目录下每份文件构造一个PathInputData实例。

import os

def generate_inputs(data_dir):
    for name in os.listdir(data_dir):
        yield PathInputData(os.path.join(data_dir, name))

接下来,再编写一个辅助函数,针对generate_inputs返回的每个InputData实例分别创建相应的LineCountWorker对象。

def create_workers(input_list):
    workers = []
    for input_data in input_list:
        workers.append(LineCountWorker(input_data))
    return workers

然后,将这些Worker实例的映射(map)工作分发到多个线程中去执行。反复调用reduce,把这些Worker计算出的结果合并成一个值。

from threading import Thread

def execute(workers):
    threads = [Thread(target=w.map) for w in workers]
    for thread in threads: thread.start()
    for thread in threads: thread.join()

    first, *rest = workers
    for worker in rest:
        first.reduce(worker)
    return first.result

最后,编写一个函数,将刚才那三个环节串起来。

def mapreduce(data_dir):
    inputs = generate_inputs(data_dir)
    workers = create_workers(inputs)
    return execute(workers)

可以看到,该函数可以很好地处理随机制造出的这批输入文件。

import os
import random

def write_test_files(tmpdir):
    os.makedirs(tmpdir)
    for i in range(100):
        with open(os.path.join(tmpdir, str(i)), 'w') as f:
            f.write('\n' * random.randint(0, 100))

tmpdir = 'test_inputs'
write_test_files(tmpdir)

result = mapreduce(tmpdir)
print(f'There are {result} lines')
There are 4996 lines

然后这样做有个大问题,就是mapreduce函数根本不通用。假如要使用其他的InputDataWorker子类,那就必须修改generate_inputscreate_workersmapreduce代码。

这个问题的根本原因在于,构造对象的办法不够通用。Python中最好能够通过类方法多态(class method polymorphism)来解决。这种多态与InputData.read所体现的实例方法多态(instance method polymorphism)很像,只不过它针对的是类,而不是这些类的对象。

我们现在运用方法多态来实现MapReduce流程所用到的这些类。首先改写InputData类,把generate_inputs方法放到该类里面并声明成通用的@classmethod,这样它所欲子类都可以通过同一个接口来新建具体的InputData实例。

class GenericInputData:
    def read(self):
        raise NotImplementedError

    @classmethod
    def generate_inputs(cls, config):
        raise NotImplementedError

新的generate_inputs方法带有一个叫作config的字典参数,调用者可以把一系列匹配信息放到字典中,让具体的GenericInputData子类去解读。例如,PathInputData这个子类就会通过data_dir键从字典里寻找含有输入文件的那个目录。

class PathInputData(GenericInputData):    def __init__(self, path):        super().__init__()        self.path = path    def read(self):        with open(self.path) as f:            return f.read()    @classmethod    def generate_inputs(cls, config):        data_dir = config['data_dir']        for name in os.listdir(data_dir):            yield cls(os.path.join(data_dir, name))

然后,可以用类似的思路改写前面的Worker类。把叫作create_workers的辅助方法移动到这个类里面,也声明为@classmethod
新方法的input_class参数将会是GenericInputData的某个子类,我们要通过这个参数触发那个子类的generate_inputs方法,以创建出Worker所需的输入信息。
然后通过cls(input_data)这个通用的形式来调用构造函数,这样创建的实例,其类型是cls所表示的具体GenericWorker子类。

class GenericWorker:    def __init__(self, input_data):        self.input_data = input_data        self.result = None    def map(self):        raise NotImplementedError    def reduce(self, other):        raise NotImplementedError    @classmethod    def create_workers(cls, input_class, config):        workers = []        for input_data in input_class.generate_inputs(config):            workers.append(cls(input_data))        return workers

上面的代码创建输入信息时,用的是input_class.generate_inputs这样的写法,这么写正是为了触发类多态机制,以便将generate_inputs派发到input_class所表示的那个实际子类上面。
还要注意的是,在构造GenericWorker的子类对象时,用的是cls(...)这样的通用写法,而没有直接调用__init__方法。

接下来要修改具体的Worker类。

class LineCountWorker(GenericWorker):    def map(self):        data = self.input_data.read()        self.result = data.count('\n')    def reduce(self, other):        self.result += other.result

最后,重写mapreduce函数,让它通过worker_class.create_workers来创建工作节点,这样它就变得通用了。

def mapreduce(worker_class, input_class, config):    workers = worker_class.create_workers(input_class, config)    return execute(workers)

这次调用mapreduce时,必须多传几个参数,因为现在是通用的函数,必须把实际的输入数据与实际的工作节点告诉它。

config = {'data_dir': tmpdir}result = mapreduce(LineCountWorker, PathInputData, config)print(f'There are {result} lines')
There are 4996 lines

这样我们能随意编写其他的GenericInputDataGenericWorker子类,而不用再花时间去调整它们之间的拼接代码。

#40 通过super初始化超类

以前有种简单的写法,能在子类里面执行超类的初始化逻辑,那就是直接在超类名称上调用__init__方法并把子类实例传进去。

class MyBaseClass:    def __init__(self, value):        self.value = valueclass MyChildClass(MyBaseClass):    def __init__(self):        MyBaseClass.__init__(self, 5)    def times_two(self):        return self.value * 2foo = MyChildClass()assert foo.times_two() == 10

这个办法可以应对比较简单的类体系,但是在其他的情况下容易出现问题。

如果某个类继承了多个超类,那么直接调用超类的__init__方法会让代码产生误会。

直接调用__init__方法所产生的第一个问题在于,超类的构造逻辑不一定会按照它们在子类class语句中的声明顺序执行。例如,在MyBaseClass之外再定义两个类,让它们也分别去操纵本实例的value字段。

class TimesTwo:    def __init__(self):        self.value *= 2class PlusFive:    def __init__(self):        self.value += 5

下面这子类继承了刚才那三个类,而且它在class语句里指定的超类顺序与它执行那些超类的__init__时所用的顺序一致。

class OneWay(MyBaseClass, TimesTwo, PlusFive):    def __init__(self, value):        MyBaseClass.__init__(self, value)        TimesTwo.__init__(self)        PlusFive.__init__(self)

这样写,程序会按正常顺序初始化那几个超类。

foo = OneWay(5)print('First ordering value is (5 * 2) + 5 =', foo.value)
First ordering value is (5 * 2) + 5 = 15

但如果子类在class语句里指定的超类顺序,与它执行那些超类的__init__时的顺序不同,那么运行结果就会让人困惑。

class AnotherWay(MyBaseClass, PlusFive, TimesTwo):    def __init__(self, value):        MyBaseClass.__init__(self, value)        TimesTwo.__init__(self) # 与继承的顺序不同        PlusFive.__init__(self)

这样写,会使代码很难理解。实际上,程序依照的是__init__的调用顺序,而不是class语句中的声明顺序。

bar = AnotherWay(5)print('Second ordering value is', bar.value) # 5 x 2 + 5
Second ordering value is 15

直接调用__init__所产生的第二个问题在于,无法正确处理菱形问题。这种继承指的是子类通过类体系里两条不同路径的类继承了同一个超类。
例如,下面先从MyBaseClass派生出两个子类。

class TimesSeven(MyBaseClass):    def __init__(self, value):        MyBaseClass.__init__(self, value)        self.value *= 7class PlusNine(MyBaseClass):    def __init__(self, value):        MyBaseClass.__init__(self, value)        self.value += 9

然后,定义最终的子类,让它分别继承上面两个类,这样MyBaseClass就会出现在菱形体系的顶端。

class ThisWay(TimesSeven, PlusNine):    def __init__(self, value):        TimesSeven.__init__(self, value)        PlusNine.__init__(self, value)foo = ThisWay(5)print('Should be (5 * 7) + 9 = 44 but is', foo.value)
Should be (5 * 7) + 9 = 44 but is 14

ThisWay调用第二个超类的__init__时,那个方法会再度触发MyBaseClass__init__,导致self.value重置为5。所以,最终的结果为5 + 9 = 14

为了解决这些问题,Python内置了super函数并且规定了标准的方法解析顺序(method resolution order,MRO)。super能够确保菱形继承体系中的共同超类只初始化一次。MRO可以确定超类之间的初始化顺序。

下面再穿件一套菱形的类体系,但是这次,我们改用super()来调用超类的初始化逻辑。

class MyBaseClass:    def __init__(self, value):        self.value = valueclass TimesSevenCorrect(MyBaseClass):    def __init__(self, value):        super().__init__(value)        self.value *= 7class PlusNineCorrect(MyBaseClass):    def __init__(self, value):        super().__init__(value)        self.value += 9

位于菱形结构顶端的MyBaseClass会率先初始化,而且只会初始化一次。接下来,程序会参照菱形底端那个子类在class语句里声明超类时的顺序,来执行菱形结构中部的那两个超类。

class GoodWay(TimesSevenCorrect, PlusNineCorrect):    def __init__(self, value):        super().__init__(value)foo = GoodWay(5)print('Should be 7 * (5 + 9) = 98 and is', foo.value)
Should be 7 * (5 + 9) = 98 and is 98

这个执行顺序,似乎与看上去的想法。实际上,这两个超类之间的初始化顺序,要由子类的MRO确定,它可以通过mro方法来查询。

mro_str = '\n'.join(repr(cls) for cls in GoodWay.mro())print(mro_str)
<class '__main__.GoodWay'><class '__main__.TimesSevenCorrect'><class '__main__.PlusNineCorrect'><class '__main__.MyBaseClass'><class 'object'>

调用GoodWay(5)时,会先触发TimesSevenCorrect.__init__,进而触发PlusNineCorrect.__init__,而这又会触发MyBaseClass.__init__。程序到达菱形结果的顶端后,开始执行MyBaseClass的初始化逻辑,然后按照与刚才相反的顺序,依次执行PlusNineCorrectTimesSevenCorrectGoodWay的初始化逻辑。

所以,程序先在MyBaseClass中把value设成5,然后在PlusNineCorrect中给它加9,得到14,接着又在TimesSevenCorrect将它乘7,得到98

除了可以应对菱形继承结构,通过super()调用__init__,可使代码更容易维护。

super函数也可以用双参数的形式调用,第一个参数表示从这个类型开始按照方法解析顺序MRO向上搜索,而解析顺序则要由第二个参数所在类型的__mro__决定。

例如,按照下面这种写法,如果在super所返回的内容上调用__init__方法,那么程序会从ExplicitTrisect类型开始(不含该类型本身)按照MRO向上搜索,直到找到这样的__init__方法为止,而解析顺序是由第二个参数所属的类型ExplicitTrisect决定的,所以解析顺序是ExplicitTrisect -> MyBaseClass -> object

class ExplicitTrisect(MyBaseClass):    def __init__(self, value):        super(ExplicitTrisect, self).__init__(value)        self.value /= 3assert ExplicitTrisect(9).value == 3

一般来说,在类的__init__方法里面通过super初始化实例时,不需要采用双参数的形式,而是可以直接采用不带参数的写法调用super,这样Python编译器会自动将__class__self当成参数传递进去。所以,下面这两种写法跟刚才那种写法是同一个意思。

class AutomaticTrisect(MyBaseClass):    def __init__(self, value):        super(__class__, self).__init__(value)        self.value /= 3class ImplicitTrisect(MyBaseClass):    def __init__(self, value):        super().__init__(value)        self.value /= 3assert ExplicitTrisect(9).value == 3assert AutomaticTrisect(9).value == 3assert ImplicitTrisect(9).value == 3

只有一种情况需要明确给super指定参数,就是:我们想从子类里面访问超类对某项功能所做的实现方案,而那种方案可能已经被子类覆盖掉了。

#41 考虑用mix-in类来表示可组合的功能

基本原则是尽量少用多重继承。

如果既要通过多重继承来方便地封装逻辑,又想避开可能出现的问题,那么就应该把有待继承的类写成mix-in类。这种类只提供一小套方法给子类去沿用,而不定义自己实例级别的属性,也不需要__init__构造函数。

在Python里面很容易编写mix-in,因为无论对象是什么类型,我们都可以方便地检视它当前的状态。这种动态监测机制,让我们只需要把通用的功能在mix-in实现一遍即可,将来也可以把这项功能应用到其他许多类里面。

例如,现在要实现这样一个功能,把内存中的Python对象表示成字典形式以便做序列化处理。不妨将这项功能写为通用代码,以供其他类使用。

为了演示这种做法,我们定义下面这个mix-in,让它提供名为to_dict的public方法。凡是想支持这项功能的类,都可以从mix-in继承。

class ToDictMixin:    def to_dict(self):        return self._traverse_dict(self.__dict__)

具体的实现代码写得很直观,我们可以通过instance函数动态地检视值的类型,并利用hasattr函数判断值里面有没有叫作__dict__的字典。

   
class ToDictMixin:
    def to_dict(self):
        return self._traverse_dict(self.__dict__)
    
    def _traverse_dict(self, instance_dict):
        output = {}
        for key, value in instance_dict.items():
            output[key] = self._traverse(key, value)
        return output

    def _traverse(self, key, value):
        if isinstance(value, ToDictMixin):
            return value.to_dict()
        elif isinstance(value, dict):
            return self._traverse_dict(value)
        elif isinstance(value, list):
            return [self._traverse(key, i) for i in value]
        elif hasattr(value, '__dict__'):
            return self._traverse_dict(value.__dict__)
        else:
            return value

下面以二叉树为例,演示如何表示二叉树的BinaryTree类具备刚才那个mix-in所提供的的功能。

class BinaryTree(ToDictMixin):
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right

定义了这样的BinaryTree类猴,很容易就能把二叉树里面那些相互关联的Python对象转换成字典的形式。

from pprint import pprint
tree = BinaryTree(10,
    left=BinaryTree(7, right=BinaryTree(9)),
    right=BinaryTree(13, left=BinaryTree(11)))

pprint(tree.to_dict())

{'left': {'left': None,
          'right': {'left': None, 'right': None, 'value': 9},
          'value': 7},
 'right': {'left': {'left': None, 'right': None, 'value': 11},
           'right': None,
           'value': 13},
 'value': 10}

mix-in最妙的地方在于,子类既可以沿用它所提供的的功能,又可以对其中一些地方做自己的处理。

例如,我们从BinaryTree派生了一个子类,让这种特殊的BinaryTreeWithParent二叉树能够把指向上级节点的引用保存下来。
但问题是,这种二叉树的to_dict方法是从ToDictMixin继承来的,它所触发的_traverse方法,在面对循环引用时,会无休止地递归下去。

class BinaryTreeWithParent(BinaryTree):
    def __init__(self, value, left=None,
                 right=None, parent=None):
        super().__init__(value, left=left, right=right)
        self.parent = parent

为了避免无限循环,我们可以覆盖BinaryTreeWithParent._traverse方法,让它对指向上级节点的引用做专门处理,而对于其他的值,则继续沿用从mix-in继承的_traverse逻辑。
下面这段代码,首先判断当前值是不是指向上级节点的引用。如果是,就直接返回上级节点的value值;如果不是,那就通过内置的super函数沿用由mix-in超类所给出默认实现方案。

class BinaryTreeWithParent(BinaryTree):
    def __init__(self, value, left=None,
                 right=None, parent=None):
        super().__init__(value, left=left, right=right)
        self.parent = parent
        
    def _traverse(self, key, value):
        if (isinstance(value, BinaryTreeWithParent) and
                key == 'parent'):
            return value.value  # 防止循环
        else:
            return super()._traverse(key, value)

现在调用BinaryTreeWithParent.to_dict就没有问题了。

root = BinaryTreeWithParent(10)
root.left = BinaryTreeWithParent(7, parent=root)
root.left.right = BinaryTreeWithParent(9, parent=root.left)
pprint(root.to_dict())
{'left': {'left': None,
          'parent': 10,
          'right': {'left': None, 'parent': 7, 'right': None, 'value': 9},
          'value': 7},
 'parent': None,
 'right': None,
 'value': 10}

只要BinaryTreeWithParent._traverse没问题,带有BinaryTreeWithParent属性的其他类就可以直接继承ToDictMixin,这样的话,程序在把这种对象转换成字典时,会自动对其中的BinaryTreeWithParent属性做出正确处理。

class NamedSubTree(ToDictMixin):
    def __init__(self, name, tree_with_parent):
        self.name = name
        self.tree_with_parent = tree_with_parent

my_tree = NamedSubTree('foobar', root.left.right)
pprint(my_tree.to_dict())  
{'name': 'foobar',
 'tree_with_parent': {'left': None, 'parent': 7, 'right': None, 'value': 9}}

多个mix-in可以组合起来用。例如,我们要再写一个mix-in,让所有的类都可以通过继承它来实现JSON序列化功能。
在编写这个mix-in时,假设继承了它的那个类肯定有自己的to_dict方法。

import json

class JsonMixin:
    @classmethod
    def from_json(cls, data):
        kwargs = json.loads(data)
        return cls(**kwargs)

    def to_json(self):
        return json.dumps(self.to_dict())

JsonMixin既定义了实例方法,也定义了类方法。于是,继承了这个mix-in的其他类也会拥有这两种行为。
在本例中,继承JsonMixin的类只需要提供to_dict方法以及能接受关键字参数的__init__方法即可。

有了这样两个mix-in,我们很容易就能创建一套含有工具类的体系,让其中的各种类型都可以把对象序列化成JSON格式并且能够根据JSON格式的数据创建这样的对象。
而这只需要开发者按照固定的样式多写一点代码即可。

例如,可以用这样一套由数据类所构成的体系表示数据中心的各种设备与它们之间的结构关系。

class DatacenterRack(ToDictMixin, JsonMixin):
    def __init__(self, switch=None, machines=None):
        self.switch = Switch(**switch)
        self.machines = [
            Machine(**kwargs) for kwargs in machines]

class Switch(ToDictMixin, JsonMixin):
    def __init__(self, ports=None, speed=None):
        self.ports = ports
        self.speed = speed

class Machine(ToDictMixin, JsonMixin):
    def __init__(self, cores=None, ram=None, disk=None):
        self.cores = cores
        self.ram = ram
        self.disk = disk

这样,我们很容易就能根据JSON格式的信息把这些对象还原出来,另外,也可以把它们再序列化成JSON格式。

serialized = """{    "switch": {"ports": 5, "speed": 1e9},    "machines": [        {"cores": 8, "ram": 32e9, "disk": 5e12},        {"cores": 4, "ram": 16e9, "disk": 1e12},        {"cores": 2, "ram": 4e9, "disk": 500e9}    ]}"""# 反序列化成DatacenterRack对象deserialized = DatacenterRack.from_json(serialized)# 序列化成JSON格式roundtrip = deserialized.to_json()assert json.loads(serialized) == json.loads(roundtrip)

对于JsonMixin这样的mix-in来说,即便直接继承它的那个类还通过类体系中的其他更高层类型间接地继承了它,程序依然能正常运行,因为Python可以把相关的方法正确地派发给JsonMixin类。

#42 优先考虑用public属性表示应受保护的数据,不要用private属性表示

Python类的属性只有两种访问级别,publicprivate

class MyObject:    def __init__(self):        self.public_field = 5        self.__private_field = 10        def get_private_field(self):        return self.__private_field

public属性能公开访问

foo = MyObject()assert foo.public_field == 5

如果属性名以两个下划线开头,那么即为private字段。

assert foo.get_private_field() == 10

但如果直接访问private字段,那就会抛出异常。

foo.__private_field
---------------------------------------------------------------------------AttributeError                            Traceback (most recent call last)<ipython-input-5-a888a87e4048> in <module>----> 1 foo.__private_field


AttributeError: 'MyObject' object has no attribute '__private_field'

类方法可以访问本类的private属性,因为类方法也是在这个类的范围里声明的。

class MyOtherObject:    def __init__(self):        self.__private_field = 71        @classmethod    def get_private_field_of_instance(cls, instance):        return instance.__private_field    bar = MyOtherObject()assert bar.get_private_field_of_instance(bar) == 71

private字段也只能给这个类自己用,子类不能访问超类的private字段。

class MyParentObject:    def __init__(self):        self.__private_field = 71    class MyChildObject(MyParentObject):    def get_private_field(self):        return self.__private_fieldbaz = MyChildObject()baz.get_private_field()
---------------------------------------------------------------------------AttributeError                            Traceback (most recent call last)<ipython-input-8-2c64f8b0f537> in <module>      8       9 baz = MyChildObject()---> 10 baz.get_private_field()


<ipython-input-8-2c64f8b0f537> in get_private_field(self)      5 class MyChildObject(MyParentObject):      6     def get_private_field(self):----> 7         return self.__private_field      8       9 baz = MyChildObject()


AttributeError: 'MyChildObject' object has no attribute '_MyChildObject__private_field'

这种防止其他类访问private属性的功能,其实仅仅是通过变换属性名称而实现的。比如MyChildObject.get_private_field这样的方法想要访问__private_field属性时,它会把下划线和类名加载这个属性名称的前面,所以代码上实际上访问的是_MyChildObject__private_field。了解了这种规则,我们就可以从任何一个类里面访问private属性。

assert baz._MyParentObject__private_field == 71

查看该对象的属性字典,也可以发现private属性的名称其实是变换后的名称存储的。

print(bar.__dict__)
{'_MyOtherObject__private_field': 71}

为了减少在不知情情况下访问内部数据而造成的损伤,Python开发者会按照风格指南里面建议的方式来给字段命名。以单下划线开头的字段(如_protected_field),习惯上叫作受保护(protected)的字段。

只有一种情况是可以考虑用private属性解决的,就是子类属性有可能与超类重名的清下。

class ApiClass:    def __init__(self):        self._value = 5        def get(self):        return self._valueclass Child(ApiClass):    def __init__(self):        super().__init__()        self._value = 'hello' # 冲突了,覆盖掉父类        a = Child()print(f'{a.get()} and {a._value} should be different')
hello and hello should be different

属性名越常见,越容易冲突。为了减少冲突,我们可以把超类的属性设计成private属性,使子类的属性名不太可能与超类重复。

class ApiClass:    def __init__(self):        self.__value = 5        def get(self):        return self.__valueclass Child(ApiClass):    def __init__(self):        super().__init__()        self._value = 'hello'        a = Child()print(f'{a.get()} and {a._value} should be different')
5 and hello should be different

#43 自定义的容器类型应该从collections.abc继承

下面我们自定义一种list类型,继承自list,提供了frequency方法,计算每个元素出现的次数。

class FrequencyList(list):    def __init__(self, members):        super().__init__(members)        def frequency(self):        counts = {}        for item in self:            counts[item] = counts.get(item, 0) + 1        return counts

继承list类,就可以自动获得标注的Python列表所具有的各项功能。这样就可以像使用普通列表那样使用FrequencyList了。

foo = FrequencyList(['a', 'b', 'a', 'c', 'a', 'd'])print('Length is',len(foo))foo.pop()print('After pop:', repr(foo))print('Frequency:', foo.frequency())
Length is 6
After pop: ['a', 'b', 'a', 'c', 'a']
Frequency: {'a': 3, 'b': 1, 'c': 1}

有时,某个对象所属的类本身虽然不是list的子类,但我们还是想让它能像list那样,可以通过下标来访问。
例如,表示二叉树节点的BinaryNode类就不是list的子类,我们想让它能像序列那样通过下标来访问。

class BinaryNode:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right

此时就需要实现一些特殊的实例方法。当通过下标访问序列中的元素时:

bar = [1, 2, 3]
bar[0]
1

Python会把访问操作解读为:

bar.__getitem__(0)

所以我们可以定义__getitem__方法:

class IndexableNode(BinaryNode):
    
    def _traverse(self):
        if self.left is not None:
            yield from self.left._traverse()
        yield self
        if self.right is not None:
            yield from self.right._traverse()
    
    def __getitem__(self, index):
        for i, item in enumerate(self._traverse()):
            if i == index:
                return item.value
        raise IndexError(f'Index {index} is out of range')

我们可以像使用BinaryNode那样,用这种定制过的IndexableNode对象来构造二叉树。

tree = IndexableNode(
    10,
    left=IndexableNode(
        5,
        left=IndexableNode(2),
        right=IndexableNode(
            6,
            right=IndexableNode(7))),
    right=IndexableNode(
        15,
        left=IndexableNode(11)))

我们既可以像访问列表那样访问,又可以通过leftright属性来访问。

print('LRR is', tree.left.right.right.value)
print('Index 0 is', tree[0])
print('Index 1 is', tree[1])
print('11 in the tree?', 11 in tree)
print('17 in the tree?', 17 in tree)
print('Tree is', list(tree))

LRR is 7
Index 0 is 2
Index 1 is 5
11 in the tree? True
17 in the tree? False
Tree is [2, 5, 6, 7, 10, 11, 15]

问题是,除了通过下标索引,list实例还支持其他的功能,所以只实现__getitem__这样一个特殊方法是不够的。例如,我们不能获取长度:

len(tree)
---------------------------------------------------------------------------

TypeError                                 Traceback (most recent call last)

<ipython-input-10-dc0343ec22f7> in <module>
----> 1 len(tree)


TypeError: object of type 'IndexableNode' has no len()

要想让定制的二叉树支持内置的len函数,必须再实现一个特殊方法,__len__方法。

class SequenceNode(IndexableNode):
    def __len__(self):
        for count, _ in enumerate(self._traverse(), 1):
            pass
        return count

tree = SequenceNode(
    10,
    left=SequenceNode(
        5,
        left=SequenceNode(2),
        right=SequenceNode(
            6,
            right=SequenceNode(7))),
    right=SequenceNode(
        15,
        left=SequenceNode(11)))

print('Tree length is', len(tree))

Tree length is 7

哪怕是这样,我们还是无法让这种二叉树具备列表所应支持的全套功能。
为了方便大家定制容器,Python内置的collections.abc模块定义了一系列抽象基类(abstract base class),把每种容器类型应该提供的所有常用方法都写了出来。
我们只需要从这样的抽象基类里面继承就好。

from collections.abc import Sequence

class BadType(Sequence):
    pass

foo = BadType()
---------------------------------------------------------------------------

TypeError                                 Traceback (most recent call last)

<ipython-input-12-997f61651b50> in <module>
      4     pass
      5 
----> 6 foo = BadType()


TypeError: Can't instantiate abstract class BadType with abstract methods __getitem__, __len__

如果忘记实现某些必须的方法,Python还会“友好地”给你提示。
一旦这些必备的方法都实现好了,我们就可以从collections.abc模块的抽象基类里面继承了。

例如下面这个BetterNode二叉树类就是正确的,因为它通过继承前面的类实现了序列容器所应支持的全部必备方法,至于其他一些方法(如indexcount)则会由Sequence这个抽象基类自动帮我们实现。

class BetterNode(SequenceNode, Sequence):
    pass

tree = BetterNode(
    10,
    left=BetterNode(
        5,
        left=BetterNode(2),
        right=BetterNode(
            6,
            right=BetterNode(7))),
    right=BetterNode(
        15,
        left=BetterNode(11)))

print('Index of 7 is', tree.index(7))
print('Count of 10 is', tree.count(10))
print('Tree length is', len(tree))
Index of 7 is 3
Count of 10 is 1
Tree length is 7

collections.abc模块要求子类必须实现某些特殊方法,另外,Python在比较或排序对象时,还会用到其他一些特殊方法,无论定制的是不是容器类,有时为了支持某些功能,都必须定义相关的特殊方法才行。


  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

愤怒的可乐

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值