Python __call__详解

问题

今天看自己调用EfficientNet的源码发现了这么一段

class DogeNet(nn.Module):
    def __init__(self):
        super(DogeNet, self).__init__()


        model = EfficientNet.from_pretrained('efficientnet-b4')
        model._fc = nn.Linear(1792, 2)
        self.efficientnet = model

    def forward(self, img):
        out = self.efficientnet(img)
        return out

model = torch.load('model/210efficient_b4.pth').to(device)
outputs = model(inputs)

这里就有了一个疑问,

model是一个对象,为什么可以直接接收input进行forward的操作?

推断一

这里猜测父类nn.Module的__init__方法有两个

一个是直接def __init__(self): 初始化方法,如

model = torch.load('model/210efficient_b4.pth').to(device)

一个是

def __init__(self,input):

return forward(input)

如outputs = model(inputs)

随后去验证猜想发现不是,nn.Module只有一个init方法

猜想二

随后开始找在哪里用了forward方法,找到了__call__函数

   def __call__(self, *input, **kwargs):
        for hook in self._forward_pre_hooks.values():
            result = hook(self, input)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                input = result
        if torch._C._get_tracing_state():
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)
        for hook in self._forward_hooks.values():
            hook_result = hook(self, input, result)
            if hook_result is not None:
                result = hook_result
        if len(self._backward_hooks) > 0:
            var = result
            while not isinstance(var, torch.Tensor):
                if isinstance(var, dict):
                    var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
                else:
                    var = var[0]
            grad_fn = var.grad_fn
            if grad_fn is not None:
                for hook in self._backward_hooks.values():
                    wrapper = functools.partial(hook, self)
                    functools.update_wrapper(wrapper, hook)
                    grad_fn.register_hook(wrapper)
        return result

明显是该函数接收了input并返回前向传播的结果

这里这个函数明显赋予了实例化的对象可以直接被当作函数使用的功能

总结

在Python中,方法也是一种高等的对象。这意味着他们也可以被传递到方法中就像其他对象一样。这是一个非常惊人的特性。 在Python中,__call__让类的实例的行为表现的像函数一样,你可以调用他们,将一个函数当做一个参数传到另外一个函数中等等。这是一个非常强大的特性让Python编程更加舒适甜美。

简单说来,__call__相当于把()做了标识符重载,在使用类初始化对象的时候,调用类中的__init__方法,而程序中使用对象()的时候,调用类中的__call__方法,使得对象本身成为了函数。

ps:

另外有的人就要问了,我直接重写__init__函数不香吗

这里其实就犯了一个本质的错误。

init 是给类初始化用的,call是将类的实例可用作函数形式。

model = torch.load('model/210efficient_b4.pth').to(device)
outputs = model(inputs)

这里是model()而不是Module()

  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值