Python是面向对象的编程语言,它提供了多重继承的代码复用机制。但是,我们应该尽量避免使用多重继承。
如果一定要利用多重继承所带来的便利及封装性,那就考虑用它来编写mix-in类。mix-in是一种小型的类,它只定义了其他类可能需要提供的一套附加方法,而不定义自己的实例属性。此外,它也不要求使用者调用自己的__init__()构造器。
由于Python程序可以方便地查看各类对象的当前状态,所以编写mix-in比较容易。我们可以在mix-in类里面通过动态检测机制先编写一套通用的功能代码,稍后再利用继承机制将其应用到其他很多类上面。
例如,要把内存中的Python对象转换为字典形式,以便将其序列化。下列代码定义了实现该功能所用的mix-in类,在其中添加了一个public方法,其他任何类可以通过继承这个mix-in类来具备此功能:
class ToDictMixin:
def to_dict(self):
return self._traverse_dict(self.__dict__)
def _traverse_dict(self, instance_dict):
output = {}
for key, value in instance_dict.items():
output[key] = self._traverse(key, value)
return output
def _traverse(self, key, value):
if isinstance(value, ToDictMixin): # 如果value是一个支持to_dict方法的对象(TodictMixin的子类)
return value.to_dict()
elif isinstance(value, dict): # 否则如果value是一个字典对象
return self._traverse_dict(value)
elif isinstance(value, list): # 否则如果value是一个列表对象
return [self._traverse(key, i) for i in value]
elif hasattr(value, '__dict__'): # 否则如果value是一个具有__dict__属性的对象
return self._traverse_dict(value.__dict__)
else:
return value
下面演示如何用mix-in把二叉树表示为字典:
class BinaryTree(ToDictMixin):
def __init__(self, value, left=None, right=None):
self.value = value
self.left = left
self.right = right
tree = BinaryTree(10,
left=BinaryTree(7, right=BinaryTree(9)),
right=BinaryTree(13, left=BinaryTree(11)))
print(tree.to_dict())
>>>
{'left': {'left': None,
'right': {'left': None, 'right': None, 'value': 9},
'value': 7},
'right': {'left': {'left': None, 'right': None, 'value': 11},
'right': None,
'value': 13},
'value': 10}
mix-in的最大优势在于,你可以把通用的功能做成插件,并在需要时覆写这些行为。例如我定义了一个二叉树的子类,每个节点存放着它的父节点的引用。假如采用默认的ToDictMixin.to_dict()来处理它,程序会因为循环引用而陷入死循环。
class BinaryTreeWithParent(BinaryTree):
def __init__(self, value, left=None,
right=None, parent=None):
super().__init__(value, left=left, right=right)
self.parent = parent
解决办法是在BinaryTreeWithParent类里覆写ToDictMixin._traverse()方法,令该方法只处理与序列化有关的值,从而使mix-in的实现代码不会陷入死循环。下面覆写的这个_traverse()方法,不再遍历parent节点,而只是把parent节点对应的value插入到最终生成的字典里面。
def _traverse(self, key, value):
if isinstance(value, BinaryTreeWithParent) and
key == 'parent'):
return value.value # Prevent cycles
else:
return super()._traverse(key, value)
现在调用BinaryTreeWithParent.to_dict()是不会有问题的,因为程序已经不再追踪导致循环引用的那个parent属性了。
root = BinaryTreeWithParent(10)
root.left = BinaryTreeWithParent(7, parent=root)
root.left.right = BinaryTreeWithParent(9, parent=root.left)
print(root.to_dict())
>>>
{'left': {'left': None,
'parent': 10,
'right': {'left': None,
'parent': 7,
'right': None,
'value': 9},
'value': 7},
'parent': None,
'right': None,
'value': 10}
定义了BinaryTreeWithParent._traverse()方法后,如果其他类的某个属性是BinaryTreeWithParent类型,那么ToDictMixin会自动地处理好这些属性。
class NamedSubTree(ToDictMixin):
def __init__(self, name, tree_with_parent):
self.name = name
self.tree_with_parent = tree_with_parent
my_tree = NamedSubTree('foobar', root.left.right)
print(my_tree.to_dict)) # No infinite loop
>>>
{'name': 'foobar',
'tree_with_parent': {'left': None,
'parent': 7,
'right': None,
'value': 9}}
多个mix-in之间也可以相互组合。例如,可以编写这样一个mix-in,它能够为任意类提供通用的JSON序列化功能。我们可以假定:继承了mix-in的那个类,会提供名为to_dict的方法(此方法可能是那个类通过多重继承ToDictMixin而具备的,也可能不是)。
class JsonMixin(object):
@classmethod
def from_json(cls, data):
kwargs = json.loads(data) # 把JSON格式的字符串转换为Python的字典对象
return cls(**kwargs)
def to_json(self):
return json.dumps(self.to_dict())
JsonMixin既定义了类方法又定义了实例方法。Mixin能让你添加任何一种行为。在这个例子中,JsonMixin的唯一要求是这个子类有to_dict实例方法,并且它的__init__方法接受关键字参数。
有了这样的mix-in之后,我们只需编写极少量的代码,就可以通过继承体系,轻松创建出相关的工具类,以便实现序列化数据以及从JSON中读取数据的功能。例如,如下这些类代表数据中心的拓扑结构:
class DatacenterRack(ToDictMixin, JsonMixin):
def __init__(self, switch=None, machines=None):
self.switch = Switch(**switch)
self.machines = [Machine(**kwargs) for kwargs in machines]
class Switch(ToDictMixin, JsonMixin):
def __init__(self, ports=None, speed=None):
self.ports = ports
self.speed = speed
class Machine(ToDictMixin, JsonMixin):
def __init__(self, cores=None, ram=None, disk=None):
self.cores = cores
self.ram = ram
self.disk = disk
对这样的类进行序列化,以及从JSON中加载它,都是比较简单的。下面这段代码,会重复执行序列化及反序列化操作,以验证这两个功能有没有正确地实现出来。
serialized = """{
"switch": {"ports": 5, "speed": 1e9},
"machines": [
{"cores": 8, "ram": 32e9, "disk": 5e12},
{"cores": 4, "ram": 16e9, "disk": 1e12},
{"cores": 2, "ram": 4e9, "disk": 500e9}
]
}"""
deserialized = DatacenterRack.from_json(serialized)
roundtrip = deserialized.to_json()
print(roundtrip)
assert json.loads(serialized) == json.loads(roundtrip)
>>>
{"switch": {"ports": 5, "speed": 1000000000.0}, "machines": [{"cores": 8, "ram": 32000000000.0, "disk": 5000000000000.0}, {"cores": 4, "ram": 16000000000.0, "disk": 1000000000000.0}, {"cores": 2, "ram": 4000000000.0, "disk": 500000000000.0}]}