问题
今天看自己调用EfficientNet的源码发现了这么一段
class DogeNet(nn.Module):
def __init__(self):
super(DogeNet, self).__init__()
model = EfficientNet.from_pretrained('efficientnet-b4')
model._fc = nn.Linear(1792, 2)
self.efficientnet = model
def forward(self, img):
out = self.efficientnet(img)
return out
model = torch.load('model/210efficient_b4.pth').to(device)
outputs = model(inputs)
这里就有了一个疑问,
model是一个对象,为什么可以直接接收input进行forward的操作?
推断一
这里猜测父类nn.Module的__init__方法有两个
一个是直接def __init__(self): 初始化方法,如
model = torch.load('model/210efficient_b4.pth').to(device)
一个是
def __init__(self,input):
return forward(input)
如outputs = model(inputs)
随后去验证猜想发现不是,nn.Module只有一个init方法
猜想二
随后开始找在哪里用了forward方法,找到了__call__函数
def __call__(self, *input, **kwargs):
for hook in self._forward_pre_hooks.values():
result = hook(self, input)
if result is not None:
if not isinstance(result, tuple):
result = (result,)
input = result
if torch._C._get_tracing_state():
result = self._slow_forward(*input, **kwargs)
else:
result = self.forward(*input, **kwargs)
for hook in self._forward_hooks.values():
hook_result = hook(self, input, result)
if hook_result is not None:
result = hook_result
if len(self._backward_hooks) > 0:
var = result
while not isinstance(var, torch.Tensor):
if isinstance(var, dict):
var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
else:
var = var[0]
grad_fn = var.grad_fn
if grad_fn is not None:
for hook in self._backward_hooks.values():
wrapper = functools.partial(hook, self)
functools.update_wrapper(wrapper, hook)
grad_fn.register_hook(wrapper)
return result
明显是该函数接收了input并返回前向传播的结果
这里这个函数明显赋予了实例化的对象可以直接被当作函数使用的功能
总结
在Python中,方法也是一种高等的对象。这意味着他们也可以被传递到方法中就像其他对象一样。这是一个非常惊人的特性。 在Python中,__call__让类的实例的行为表现的像函数一样,你可以调用他们,将一个函数当做一个参数传到另外一个函数中等等。这是一个非常强大的特性让Python编程更加舒适甜美。
简单说来,__call__相当于把()做了标识符重载,在使用类初始化对象的时候,调用类中的__init__方法,而程序中使用对象()的时候,调用类中的__call__方法,使得对象本身成为了函数。
ps:
另外有的人就要问了,我直接重写__init__函数不香吗
这里其实就犯了一个本质的错误。
init 是给类初始化用的,call是将类的实例可用作函数形式。
model = torch.load('model/210efficient_b4.pth').to(device)
outputs = model(inputs)
这里是model()而不是Module()