在建立CNN模型时,使用如下代码,在构建方法里面新建了一些对象,例如self.conv1,在下面的forward方法中直接把对象名作为方法名,传入变量x
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1=nn.Conv2d(3,6,5)
self.pool=nn.MaxPool2d(2,2)
self.conv2=nn.Conv2d(6,16,5)
self.fc1=nn.Linear(5*5*6,120)
self.fc2=nn.Linear(120,84)
self.fc3=nn.Linear(84,10)
def forward(self, x):
x=self.conv1(x)
x=self.conv1.forward(x)
x=F.relu(x)
x=self.pool(x)
x=self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x=F.relu(self.fc1(x))
x=F.relu(self.fc2(x))
x=self.fc3(x)
return x
实际上x = self.conv1(x) 等价于 x = self.conv1.forward(x)
原因:self.conv1=nn.Conv2d(3,6,5)
其中nn.Conv2d是一个类,继承关系Conv2d ——> _Convnd ——> Module。在建立Module类时,定义了__call()__,说明 这个类及其子类的实例 都是是可调用的
def __call__(self, *input, **kwargs):
for hook in self._forward_pre_hooks.values():
hook(self, input)
if torch.jit._tracing:
result = self._slow_forward(*input, **kwargs)
else:
result = self.forward(*input, **kwargs)
for hook in self._forward_hooks.values():
hook_result = hook(self, input, result)
if hook_result is not None:
raise RuntimeError(
"forward hooks should never return any values, but '{}'"
"didn't return None".format(hook))
if len(self._backward_hooks) > 0:
var = result
while not isinstance(var, torch.Tensor):
if isinstance(var, dict):
var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
else:
var = var[0]
grad_fn = var.grad_fn
if grad_fn is not None:
for hook in self._backward_hooks.values():
wrapper = functools.partial(hook, self)
functools.update_wrapper(wrapper, hook)
grad_fn.register_hook(wrapper)
return result
使用x=self.conv1(x) 等价于调用了forward(),是因为:
- Module定义了forward(),Conv2d, _Convnd 继承并改写了forward()
- 在Module的__call()__里面写了
....
result = self.forward(*input, **kwargs)
....
return result
所以,使用对象名作为方法名时,使用哪个方法,要看在 __call()__ 里面怎么写的。
在这里,调用的方法是 对象的 类的 forward()
做一个简单例子
class A():
def __init__(self):
self.a = 1
def func(self, input):
print('A_call '+ input)
def __call__(self, *args, **kwargs):
return self.func(*args)
class B(A):
def __init__(self):
super(B, self).__init__()
def func(self, input):
print('B_call ' + input)
a = A()
a('ok')
b = B()
b('ok')
输出:
A_call ok
B_call ok