0. Abstract
最近看到 Python 的装饰器,于是初步调查了其用法,主要包括两类:
- decorator
- register
1. Decorator
直接看一个简单的例子:
import time
# 自定义装饰器函数,用于计算函数执行时间
def calculate_time(func): # 调用被装饰的函数, 实际会重定向到此处
def wrapper(*args, **kwargs):
start_time = time.time()
result = func(*args, **kwargs) # 在这调用被装饰函数
end_time = time.time()
print(f'Function {func.__name__} took {end_time - start_time} seconds to execute.')
return result
return wrapper # 返回对象是一个函数, 实际执行者
# 使用装饰器修饰函数
@calculate_time
def my_function(n):
result = 0
for i in range(n):
result += i
return result
# 调用被装饰后的函数
print('Result:', my_function(100000))
输出:
Function my_function took 0.003997802734375 seconds to execute.
Result: 4999950000
当用 @calculate_time
装饰函数 my_function(n)
后,调用该函数时,会重定向到装饰器函数 calculate_time(func)
,其参数是被装饰的函数对象 my_function
,返回对象是其内部定义的 wrapper
。
那为什么实际的执行者是 wrapper
呢?我们可以看到,其参数和 my_function
是一样的,所以可以怀疑装饰器干的事就是“偷梁换柱”,当调用 my_function(100000)
时,实际被装饰器改成了 wrapper(100000)
。
1.1 验证“偷梁换柱”
通过两种方式验证:(1) 修改装饰器的返回对象;(2) 修改 wrapper
的函数标签。
- 修改装饰器的返回对象
我们让装饰器返回一个简单的打印函数print_n
,而不是wrapper
:
import time
def print_n(n):
print(n)
return '没想到执行的是我吧!'
# 自定义装饰器函数,用于计算函数执行时间
def calculate_time(func): # 调用被装饰的函数, 实际会重定向到此处
def wrapper(*args, **kwargs):
start_time = time.time()
result = func(*args, **kwargs) # 在这调用被装饰函数
end_time = time.time()
print(f'Function {func.__name__} took {end_time - start_time} seconds to execute.')
return result
return print_n # 返回对象是一个函数, 实际执行者
# 使用装饰器修饰函数
@calculate_time
def my_function(n):
result = 0
for i in range(n):
result += i
return result
# 调用被装饰后的函数
print('Result:', my_function(100000)) # 实际执行 print_n(100000)
输出:
100000
Result: 没想到执行的是我吧!
实际的执行者是 print_n
,也就是说“偷梁换柱”的结果是将 my_function
换成了 print_n
。
- 修改
wrapper
的函数标签
给wrapper
添加两个参数:
def calculate_time(func):
def wrapper(a, b, *args, **kwargs): # 添加了两个参数
start_time = time.time()
result = func(*args, **kwargs)
end_time = time.time()
print(f'Function {func.__name__} took {end_time - start_time} seconds to execute.')
return result
return wrapper
执行报错:
TypeError: calculate_time.<locals>.wrapper() missing 1 required positional argument: 'b'
即,执行函数 wrapper
时,少了一个参数 b
,可见其只得到了一个参数 100000
,再次证明程序执行了 wrapper
,“偷梁换柱”。
小结:装饰器的作用就是“偷梁换柱”,调用被装饰函数时会替换一个签名一致的其他函数对象,在这个对象中,你可以添加一些类似计算执行时间之类的额外功能,当然,记得调用被装饰函数哦!
1.2 再举一个例子
def lazy_property(fn):
attr_name = '_lazy_' + fn.__name__
@property
def _lazy_property(self): # 实际被调用者
print('实际执行的是我')
if not hasattr(self, attr_name):
setattr(self, attr_name, fn(self)) # 在这执行一次
return getattr(self, attr_name)
return _lazy_property
class MyClass:
def __init__(self):
self._data = None
@lazy_property
def data(self):
print('Calculating data...')
self._data = [1, 2, 3]
return self._data
# 创建对象
obj = MyClass()
# 访问属性
print(obj.data) # 第一次访问,会计算并缓存数据
print(obj.data) # 第二次访问,直接返回缓存数据
其实我们原本会这么做:
def data(self):
if self._data is None:
print('Calculating data...')
self._data = [1, 2, 3]
return self._data
也能实现冷启动的属性访问。定义一个装饰器会使代码变得更简洁,因为它把一些逻辑(if
)移动到了装饰器中的打包函数内。
2. Register
直接看例子:
def register_kl(type_p, type_q):
if not isinstance(type_p, type) and issubclass(type_p, Distribution):
raise TypeError(
f"Expected type_p to be a Distribution subclass but got {type_p}"
)
if not isinstance(type_q, type) and issubclass(type_q, Distribution):
raise TypeError(
f"Expected type_q to be a Distribution subclass but got {type_q}"
)
def decorator(fun):
_KL_REGISTRY[type_p, type_q] = fun
_KL_MEMOIZE.clear() # reset since lookup order may have changed
return fun
return decorator
@register_kl(Normal, Normal)
def _kl_normal_normal(p, q):
var_ratio = (p.scale / q.scale).pow(2)
t1 = ((p.loc - q.loc) / q.scale).pow(2)
return 0.5 * (var_ratio + t1 - 1 - var_ratio.log())
这是 torch.distributions.kl
包中的注册器,作用是向字典 _KL_REGISTRY = {(Distribution,Distribution): kl_function}
中添加 KL 散度计算函数,用于计算不同分布(class Distribution
子类) 之间的 KL 散度。上诉代码 @register_kl(Normal, Normal)
是计算两个高斯分布之间的 KL 散度。
注意到,register_kl
的参数和返回值都与装饰器不同了,它需要参数((Normal, Normal)
),返回值是一个装饰器。在 register_kl
函数体内实现了一些参数检查,而定义在其内部的装饰器 decorator
才真正地实现了主要功能:把函数存进字典 _KL_REGISTRY
中。
需要注意的是,函数 _kl_normal_normal
是以 _
开头的,说明其不对外开放。要想计算分布之间的 KL 散度,需要调用函数:
def kl_divergence(p: Distribution, q: Distribution) -> torch.Tensor:
try:
fun = _KL_MEMOIZE[type(p), type(q)]
except KeyError:
fun = _dispatch_kl(type(p), type(q))
_KL_MEMOIZE[type(p), type(q)] = fun
if fun is NotImplemented:
raise NotImplementedError(
f"No KL(p || q) is implemented for p type {p.__class__.__name__} and q type {q.__class__.__name__}"
)
return fun(p, q)
这个函数试图根据分布类型从 _KL_REGISTRY
中找到最具体的 KL 散度计算函数,并调用它完成 KL 散度的计算。
小结: 注册器比装饰器多了一层,其内部定义了装饰器。
2.1 既然是不给调用的函数,那注册器的工作机制就不会是“偷梁换柱”了吧?
依然是“偷梁换柱”,只不过装饰器外多了一层注册器。没有对函数 _kl_normal_normal
的调用,但会以其他的形式调,如上面的 kl_divergence
。当
from torch.distributions import kl
的时候,就会自动执行注册器,将各种被 @register_kl(***, ***)
修饰的函数添加到字典 _KL_REGISTRY
中,我们添加一些 print()
语句来验证一下:
def register_kl(type_p, type_q):
print('检查 p, q 类型', type_p, type_q) # 测试
if not isinstance(type_p, type) and issubclass(type_p, Distribution):
raise TypeError(
f"Expected type_p to be a Distribution subclass but got {type_p}"
)
if not isinstance(type_q, type) and issubclass(type_q, Distribution):
raise TypeError(
f"Expected type_q to be a Distribution subclass but got {type_q}"
)
def decorator(fun):
_KL_REGISTRY[type_p, type_q] = fun
print('注册函数', fun.__qualname__) # 测试
_KL_MEMOIZE.clear() # reset since lookup order may have changed
return fun
return decorator
再 import
的话,则输出:
检查 p, q 类型 <class 'torch.distributions.bernoulli.Bernoulli'> <class 'torch.distributions.bernoulli.Bernoulli'>
注册函数 _kl_bernoulli_bernoulli
检查 p, q 类型 <class 'torch.distributions.beta.Beta'> <class 'torch.distributions.beta.Beta'>
注册函数 _kl_beta_beta
检查 p, q 类型 <class 'torch.distributions.binomial.Binomial'> <class 'torch.distributions.binomial.Binomial'>
注册函数 _kl_binomial_binomial
...
看来 kl
包内的 @register_kl(***, ***)
都执行了。没错!在没有调用 _kl_binomial_binomial
的情况下,装饰器 decorator
也执行了。
2.2 自定义了新的分布的话(subclass Distribution),可自行注册
@register_kl(VonMisesFisher, HypersphericalUniform)
def _kl_vmf_uniform(vmf, hyu):
return -vmf.entropy() + hyu.entropy()
其中 VonMisesFisher, HypersphericalUniform
是自定义的 Distribution
的子类,该函数实现了两个分布之间 KL 散度的计算。注册之后,再调用 kl.kl_divergence
就可以实现 KL 散度计算了。
问: 为何不直接调用 _kl_vmf_uniform(vmf, hyu)
呢?
答: 大概是为了代码的泛化性和简洁性吧。
2.3 其实“偷梁换柱”的概念还在
def sub(a, b):
return a - b
def register_operator(type_a, type_b):
print('检查类型')
print(type_a, type_b)
def decorator(fun):
print('注册函数', fun.__qualname__)
return sub # 注意这里返回了 sub
return decorator
@register_operator(int, int)
def _add(a, b):
print("Adding")
print(_add(1, 2)) # output: -1
更改 decorator
的返回值为 sub
,输出:
检查类型
<class 'int'> <class 'int'>
记录映射_add
注册函数 _add
-1
可见,注册其照常执行,而 _add(a, b)
内的 print("Adding")
并未执行,而是执行了 sub(1, 2)
,偷梁换柱的机制依然还在。
补充
装饰器和注册器的执行时机:
当 import
其所在模块时,会执行模块的内容,主要是函数定义:
@calculate_time # 执行这一行时会执行装饰器
def my_function(n):
result = 0
for i in range(n):
result += i
return result
此时会执行 calculate_time
,验证代码:
# 自定义装饰器函数,用于计算函数执行时间
def calculate_time(func):
print('执行吗')
def wrapper(*args, **kwargs): # 调用被装饰的函数, 实际会重定向到此处
start_time = time.time()
result = func(*args, **kwargs) # 在这调用被装饰函数
end_time = time.time()
print(f'Function {func.__name__} took {end_time - start_time} seconds to execute.')
return result
return wrapper # 返回对象是一个函数, 实际执行者
会输出:
执行吗
但不会执行 wrapper
【没法执行啊,都没参数】。
注册器也一样,注册器和其内的装饰器也会执行,验证代码:
######################### register.py #########################
def sub(a, b):
return a - b
def register_operator(type_a, type_b):
print('检查类型')
print(type_a, type_b)
def decorator(fun):
print(f'记录映射{fun.__qualname__}')
return sub # 注意这里返回了 sub
return decorator
@register_operator(int, int)
def _add(a, b):
print("Adding")
当 import register
时,输出:
检查类型
<class 'int'> <class 'int'>
记录映射_add
小结: 装饰器和注册器在 import
的时候都会执行。