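Some very capable classmates in our group chat were discussing the following question.
The source code is taken from:
https://github.com/open-mmlab/mmcv/blob/master/mmcv/cnn/utils/flops_counter.py
What is function.__get__ doing in the function below?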
```python
def add_flops_counting_methods(net_main_module):
    # adding additional methods to the existing module object,
    # this is done this way so that each function has access to self object
    net_main_module.start_flops_count = start_flops_count.__get__(  # <---- what do these __get__ calls mean?
        net_main_module)
    net_main_module.stop_flops_count = stop_flops_count.__get__(
        net_main_module)
    net_main_module.reset_flops_count = reset_flops_count.__get__(
        net_main_module)
    net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(  # noqa: E501
        net_main_module)

    net_main_module.reset_flops_count()

    return net_main_module
```
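To see the same pattern without PyTorch or mmcv, here is a minimal pure-Python sketch. The class Net and the functions start_count, get_count and add_counting_methods are hypothetical stand-ins (not mmcv APIs) for the network module and the start_flops_count-style functions. The sketch only shows that __get__ binds a module-level function to one particular object.

```python
class Net:
    """Hypothetical stand-in for an nn.Module instance."""

    def __init__(self, name):
        self.name = name


def start_count(self):
    # 'self' is whatever object this function was bound to via __get__
    self.counting = True
    self.count = 0


def get_count(self):
    return self.name, self.count


def add_counting_methods(net):
    # Bind each free function to this specific instance, mirroring
    # what add_flops_counting_methods() does with start_flops_count etc.
    net.start_count = start_count.__get__(net)
    net.get_count = get_count.__get__(net)
    return net


a = add_counting_methods(Net("a"))
b = Net("b")

a.start_count()                   # no explicit argument: 'a' is passed as self
print(a.get_count())              # ('a', 0)
print(hasattr(b, "start_count"))  # False: only 'a' was patched
```

The bound methods are stored in the instance's own __dict__. Only that object gains them, and the Net class and its other instances are left untouched.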
The start_flops_count function itself is defined like this:
```python
def start_flops_count(self):
    """Activate the computation of mean flops consumption per image.

    A method to activate the computation of mean flops consumption per image.
    which will be available after `add_flops_counting_methods()` is called on a
    desired net object. It should be called before running the network.
    """
    add_batch_counter_hook_function(self)

    def add_flops_counter_hook_function(module):
        if is_supported_instance(module):
            if hasattr(module, '__flops_handle__'):
                return

            else:
                handle = module.register_forward_hook(
                    MODULES_MAPPING[type(module)])

                module.__flops_handle__ = handle

    self.apply(partial(add_flops_counter_hook_function))
```
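Intuitively, the effect is like writing a method outside the class and attaching it to the object afterwards. The code comment says as much:
adding additional methods to the existing module object,
this is done this way so that each function has access to self object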
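The comment's point about access to self is easiest to see by comparing a plain assignment with a bound one. Here is a small sketch using a hypothetical class Box and function describe. A function stored directly on an instance is just an ordinary attribute, so Python does not pass the instance in automatically. Binding the function with __get__ fixes that.

```python
class Box:
    def __init__(self, value):
        self.value = value


def describe(self):
    return f"Box({self.value})"


box = Box(3)

# Plain assignment: a function stored in an *instance* __dict__ is not bound
box.describe_plain = describe
try:
    box.describe_plain()
except TypeError as exc:
    print("TypeError:", exc)  # complains that the 'self' argument is missing

# Binding through the descriptor protocol: 'box' becomes self
box.describe = describe.__get__(box)
print(box.describe())  # Box(3)
```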
So what exactly does this __get__ do?
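See the official documentation (the Descriptor HowTo Guide):
https://docs.python.org/zh-cn/3/howto/descriptor.html
(I somehow missed the later part of that page and went digging through the CPython source instead... It was my classmate who found the relevant section.)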
```python
class Function:
    ...

    def __get__(self, obj, objtype=None):
        "Simulate func_descr_get() in Objects/funcobject.c"
        if obj is None:
            return self
        return MethodType(self, obj)
```
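The emulation above can be checked against real functions. In this small sketch, the class C and method greet are hypothetical. With obj=None, the function comes back unchanged. With an instance, you get a bound method whose __func__ and __self__ point back to the original pieces. Ordinary attribute lookup on an instance performs this same call behind the scenes.

```python
class C:
    def greet(self, name):
        return f"hi {name}"


c = C()
func = C.__dict__["greet"]  # the raw function stored on the class

# obj is None -> the function itself is returned
print(func.__get__(None, C) is func)  # True

# obj is an instance -> a bound method
bound = func.__get__(c, C)
print(bound.__func__ is func, bound.__self__ is c)  # True True

# c.greet goes through the same descriptor call, so the results match
print(bound("x") == c.greet("x"))  # True
```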
And what is MethodType? The same guide emulates it like this:
```python
class MethodType:
    "Emulate PyMethod_Type in Objects/classobject.c"

    def __init__(self, func, obj):
        self.__func__ = func
        self.__self__ = obj

    def __call__(self, *args, **kwargs):
        func = self.__func__
        obj = self.__self__
        return func(obj, *args, **kwargs)
```
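In CPython, the real PyMethod_Type is exposed as types.MethodType. The __get__ calls in add_flops_counting_methods() could therefore also be written with it. Here is a minimal sketch using a hypothetical class Counter and function increment:

```python
import types


class Counter:
    def __init__(self):
        self.n = 0


def increment(self, step=1):
    self.n += step
    return self.n


counter = Counter()

# Two equivalent ways to bind 'increment' to 'counter'
via_get = increment.__get__(counter)
via_type = types.MethodType(increment, counter)

print(type(via_get) is types.MethodType)  # True
via_get(2)        # counter.n is now 2
via_type()        # counter.n is now 3: same function, same instance
print(counter.n)  # 3
```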
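In effect, this binds the argument to the function as its first parameter and returns a callable without running it yet. It feels a lot like: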
```python
from functools import partial
partial(func, arg)
```
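The partial analogy holds for calling, but the two objects are not the same kind of thing. This short sketch uses a hypothetical function add. Both approaches pre-fill the first positional argument. A bound method records that argument as __self__, while a partial object keeps it in its .args tuple.

```python
from functools import partial


def add(a, b):
    return a + b


bound = add.__get__(10)    # binds a=10, much like partial(add, 10)
filled = partial(add, 10)

print(bound(5), filled(5))  # 15 15
print(bound.__self__)       # 10: the bound method stores it as __self__
print(filled.args)          # (10,): the partial stores it in .args
```

There is also a practical difference. __get__ fixes exactly one leading positional argument (the self slot), while partial can fix any number of positional and keyword arguments.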
Here is another example:
```python
def print_char(char):
    print(char)
    return char

print_A = print_char.__get__("A")
print_A()  # prints: A
```