def register_to_config(init):
    r"""
    Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
    automatically sent to `self.register_to_config`. To ignore a specific argument accepted by the init but that
    shouldn't be registered in the config, use the `ignore_for_config` class variable.

    Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the
    init! They are still forwarded to `self.register_to_config`.

    Args:
        init: the ``__init__`` method being decorated. Its first parameter is treated as ``self`` and is never
            registered.

    Returns:
        The wrapped ``__init__`` (metadata copied from ``init`` via ``functools.wraps``).

    Raises:
        RuntimeError: at call time, if the instance does not inherit from ``ConfigMixin``.
    """

    @functools.wraps(init)
    def inner_init(self, *args, **kwargs):
        # Private kwargs (leading underscore) go only to the config; public ones are passed to `init`.
        init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
        config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
        if not isinstance(self, ConfigMixin):
            raise RuntimeError(
                f"`@register_to_config` was applied to {self.__class__.__name__} init method, but this class does "
                "not inherit from `ConfigMixin`."
            )

        ignore = getattr(self, "ignore_for_config", [])
        signature = inspect.signature(init)
        # Every parameter after `self`, in declaration order. Ignored names are kept here on purpose so that
        # positional arguments line up with the correct parameter names before filtering.
        all_parameters = list(signature.parameters.items())[1:]

        new_kwargs = {}
        # Map positional arguments onto their parameter names, skipping ignored ones.
        for arg, (name, _param) in zip(args, all_parameters):
            if name not in ignore:
                new_kwargs[name] = arg

        # Then add the remaining parameters from keyword arguments, falling back to the declared default.
        for name, param in all_parameters:
            if name not in ignore and name not in new_kwargs:
                new_kwargs[name] = init_kwargs.get(name, param.default)

        # Private kwargs first, then public/positional values (the latter win on a key clash).
        new_kwargs = {**config_init_kwargs, **new_kwargs}
        self.register_to_config(**new_kwargs)
        init(self, *args, **init_kwargs)

    return inner_init
#######################################################################
from functools import wraps
def a_decorator(func):
    """Wrap *func* so that a message is printed before and after every call.

    ``functools.wraps`` copies ``func``'s ``__name__``, ``__doc__`` and other metadata onto the
    wrapper, so the decorated function still introspects as the original.

    Args:
        func: the callable to decorate.

    Returns:
        The wrapper. It forwards all positional and keyword arguments to ``func`` and returns
        whatever ``func`` returns.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # No docstring here: @wraps(func) would overwrite wrapper.__doc__ with func.__doc__ anyway.
        print('in wrapper, before func() ...')
        # Extend some capabilities of func (here: just show which function is being wrapped).
        print(func)
        result = func(*args, **kwargs)
        print('in wrapper, after func() ...')
        # Propagate the result; without this every decorated call would return None.
        return result

    return wrapper
@a_decorator
def first_function():
    """This is docstring for first function"""
    # The docstring above is printed verbatim by the demo script, so its text must stay unchanged.
    # Emit a fixed label so the demo output shows exactly when the wrapped body runs.
    label = "first function"
    print(label)
@a_decorator
def second_function(a):
    """This is docstring for second function"""
    # The docstring above is printed verbatim by the demo script, so its text must stay unchanged.
    # Print a label, then echo the single argument to show that the wrapper forwards it.
    header = "second function"
    print(header)
    print("a=", a)
# Demo: for each decorated function, show that @wraps preserved its name and docstring, then call it.
# A separator line is printed between the two demos (not before the first one).
for position, (demo_fn, demo_args) in enumerate(((first_function, ()), (second_function, (12,)))):
    if position > 0:
        print("=" * 30)
    print(demo_fn.__name__)
    print(demo_fn.__doc__)
    demo_fn(*demo_args)
###############
python3.8/site-packages/diffusers/configuration_utils.py
register_to_config对应a_decorator
init对应func
inner_init对应wrapper
@functools.wraps(init)对应@wraps(func)
signature = inspect.signature(init) 取得被装饰的 init 的函数签名(参数名及其默认值)。据此,同一个装饰器可以针对不同模型的 init 取得各自所需的参数,并做相应的预处理:把位置参数对齐到对应的参数名,为未传入的参数补上默认值。完成这些准备工作后,再调用各个模型自己的 init,从而用同一个装饰器实现不同模型的初始化。
getattr(self, "register_to_config")(**new_kwargs)(等价于 self.register_to_config(**new_kwargs))的作用是:对每个需要初始化的模型类,取得它的 register_to_config 方法(该类必须继承自 ConfigMixin),并用 new_kwargs 中的参数给该类的各个配置参数赋值,即把这些参数登记到配置中。例如,当 Transformer2DModel 需要 init 时,先获得 Transformer2DModel 的 register_to_config,在 register_to_config() 中用 new_kwargs 给 Transformer2DModel 的各个参数赋值,然后才调用 Transformer2DModel 自己的 init() 方法完成初始化。
getattr(self, "register_to_config")(**new_kwargs)
init(self, *args, **init_kwargs)
init
<function Transformer2DModel.__init__ at 0x7f2625550c10>
print(kwargs)
{'num_attention_heads': 8, 'attention_head_dim': 40, 'in_channels': 320,
'out_channels': None, 'num_layers': 1, 'dropout': 0.0, 'norm_num_groups': 32,
'cross_attention_dim': 768, 'attention_bias': False, 'sample_size': None,
'num_vector_embeds': None, 'patch_size': None, 'activation_fn': 'geglu',
'num_embeds_ada_norm': None, 'use_linear_projection': False, 'only_cross_attention': False,
'upcast_attention': False, 'norm_type': 'layer_norm', 'norm_elementwise_affine': True}
None
print(kwargs.items())
dict_items([('num_attention_heads', 8), ('attention_head_dim', 40), ('in_channels', 320),
('out_channels', None), ('num_layers', 1), (&#