Pyhton 对象名可以直接作为方法名调用

在建立CNN模型时,使用如下代码,在构建方法里面新建了一些对象,例如self.conv1,在下面的forward方法中直接把对象名作为方法名,传入变量x

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1=nn.Conv2d(3,6,5)
        self.pool=nn.MaxPool2d(2,2)
        self.conv2=nn.Conv2d(6,16,5)
        self.fc1=nn.Linear(5*5*6,120)
        self.fc2=nn.Linear(120,84)
        self.fc3=nn.Linear(84,10)
    def forward(self, x):
        x=self.conv1(x)
        x=self.conv1.forward(x)
        x=F.relu(x)
        x=self.pool(x)
        x=self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)  
        x=F.relu(self.fc1(x))
        x=F.relu(self.fc2(x))
        x=self.fc3(x)
        return x

实际上x = self.conv1(x) 等价于 x = self.conv1.forward(x)

原因:self.conv1=nn.Conv2d(3,6,5)

其中nn.Conv2d是一个类,继承关系Conv2d ——> _Convnd ——> Module。在建立Module类时,定义了__call()__,说明 这个类及其子类的实例 都是是可调用的

def __call__(self, *input, **kwargs):
    for hook in self._forward_pre_hooks.values():
        hook(self, input)
    if torch.jit._tracing:
        result = self._slow_forward(*input, **kwargs)
    else:
        result = self.forward(*input, **kwargs)
    for hook in self._forward_hooks.values():
        hook_result = hook(self, input, result)
        if hook_result is not None:
            raise RuntimeError(
                "forward hooks should never return any values, but '{}'"
                "didn't return None".format(hook))
    if len(self._backward_hooks) > 0:
        var = result
        while not isinstance(var, torch.Tensor):
            if isinstance(var, dict):
                var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
            else:
                var = var[0]
        grad_fn = var.grad_fn
        if grad_fn is not None:
            for hook in self._backward_hooks.values():
                wrapper = functools.partial(hook, self)
                functools.update_wrapper(wrapper, hook)
                grad_fn.register_hook(wrapper)
    return result

 

使用x=self.conv1(x) 等价于调用了forward(),是因为:

  1. Module定义了forward(),Conv2d, _Convnd 继承并改写了forward()
  2. 在Module的__call()__里面写了
....

result = self.forward(*input, **kwargs)

....

return result

 

所以,使用对象名作为方法名时,使用哪个方法,要看在 __call()__ 里面怎么写的。

在这里,调用的方法是 对象的 类的 forward()

 

做一个简单例子

class A():
    def __init__(self):
        self.a = 1
    def func(self, input):
        print('A_call '+ input)
    def __call__(self, *args, **kwargs):
        return self.func(*args)

class B(A):
    def __init__(self):
        super(B, self).__init__()
    def func(self, input):
        print('B_call ' + input)

a = A()
a('ok')

b = B()
b('ok')

输出:

A_call ok
B_call ok

  • 3
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值