学习PyTorch时,对如下代码感到疑惑:
net = MobileNetV2(num_classes=5) #实例类,创建类的对象net
logits = net(images) #前向传播
第二句一直不太理解,MobileNetV2类中用于前向传播的函数是forward(),一般来说,如果想调用forward函数,需要通过net.forward(images)调用,为什么net(images)就能直接调用类中的forward()函数?
后来找到MobileNetV2的父类Module,发现Module类也有一个forward()函数,然后发现了Module类的__call__()方法调用了forward函数。
#MobileNetV2类中__call__()方法部分代码
def __call__(self, *input, **kwargs):
for hook in self._forward_pre_hooks.values():
result = hook(self, input)
if result is not None:
if not isinstance(result, tuple):
result = (result,)
input = result
if torch._C._get_tracing_state():
result = self._slow_forward(*input, **kwargs)
else:
result = self.forward(*input, **kwargs)
for hook in self._forward_hooks.values():
所以猜想和__call__有关,于是写了一个测试代码,代码如下:
class module:
def __call__(self,a):
print("调用父类的__call__()函数", a)
y=self.forward(a,2)
def forward(self,input,input2):
print("调用父类的forward()函数",input,input2)
class res(module):
def forward(self,x,y):
print("调用子类的forward()函数",x,y)
net=res()
net(1)
输出结果如下:
调用父类的__call__()函数 1
调用子类的forward()函数 1 2
由上述代码和结果可以发现,当执行net(1)时,会先调用父类的__call__()函数,然后调用子类的forward()函数。
后来查阅资料发现,这种结果是在继承机制和内置方法__call__()的共同作用下取得的:
1、子类继承父类,这样当子类对象调用某一函数时,Python解释器会先去子类找该函数;如果找不到,它还会去父类中去找;如果子类有,就不会去父类找了。
2、对象通过__call__(slef, [,*args [,**kwargs]])方法可以模拟函数的行为,如果一个对象x提供了该方法,就可以像函数一样使用它,也就是说x(arg1, arg2…) 等同于调用x.call(self, arg1, arg2) 。