文章目录
0. 前言
按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。
1. 关于forward的两个小问题
1.1 为什么都用def forward,而不改个名字?
在Pytorch建立神经元网络模型的时候,经常用到forward方法,表示在建立模型后,进行神经元网络的前向传播。说的直白点,forward就是专门用来计算给定输入,得到神经元网络输出的方法。
在代码实现中,也是用def forward
来写forward前向传播的方法,我原来以为这是一种约定熟成的名字,也可以换成任意一个自己喜欢的名字。
但是看的多了之后发现并非如此:Pytorch对于forward方法赋予了一些特殊“功能”
(这里不禁再吐槽,一些看起来挺厉害的Pytorch“大神”,居然不知道这个。。。只能草草解释一下:“就是这样的。。。”)
1.2 forward有什么特殊功能?
第一条:.forward()可以不写
我最开始发现forward()的与众不同之处就是在此,首先举个例子:
import torch.nn as nn
class test(nn.Module):
def __init__(self, input):
super(test,self).__init__()
self.input = input
def forward(self,x):
return self.input * x
T = test(8)
print(T(6))
# print(T.forward(6))
--------------------------运行结果-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:/Users/Lenovo/Desktop/DL/pythonProject/tt.py
48
Process finished with exit code 0
可以发现,T(6)是可以输出的!而且不用指定,默认了调用forward方法
。当然如果非要写上.forward()这也是可以正常运行的,和不写是一样的。
如果不调用Pytorch(正常的Python语法规则),这样肯定会报错的
# import torch.nn as nn #不再调用torch
class test():
def __init__(self, input):
self.input = input
def forward(self,x):
return self.input * x
T = test(8)
print(T.forward(6))
print("************************")
print(T(6))
--------------------------运行结果-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:/Users/Lenovo/Desktop/DL/pythonProject/tt.py
48
************************
Traceback (most recent call last):
File "C:\Users\Lenovo\Desktop\DL\pythonProject\tt.py", line 77, in <module>
print(T(6))
TypeError: 'test' object is not callable
Process finished with exit code 1
这里会报:‘test’ object is not callable
因为class不能被直接调用,不知道你想调用哪个方法。
第二条:优先运行forward方法
如果在class中再增加一个方法:
import torch.nn as nn
class test(nn.Module):
def __init__(self, input):
super(test,self).__init__()
self.input = input
def byten(self):
return self.input * 10
def forward(self,x):
return self.input * x
T = test(8)
print(T(6))
print(T.byten())
--------------------------运行结果-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:/Users/Lenovo/Desktop/DL/pythonProject/tt.py
48
80
Process finished with exit code 0
可以见到,在class中有多个method的时候,如果不指定method,forward是会被优先执行的。
2. 总结
在Pytorch中,forward方法是一个特殊的方法,被专门用来进行前向传播。
20230605 更新
应评论要求,增加forward的官方定义,这块我就不搬运PyTorch官网的内容了,直接传送门走你:nn.Module.forward。
20230919 大更新
首先非常感谢大家喜欢本文!这篇文章本来是我自己的“随手记”没想到有这么多C友浏览过!
其实在写完本文后我是有些遗憾的,因为本文仅是用了实验的方法探索出了.forward()
的表象,而它的运作机理却没有说明白,知其然不知其所以然!
在此感谢下面 Mr·小鱼 的评论给了我启迪,因为魔术方法__call__()
的特性确实很符合.forward()
的表象,但是我对着nn.Module
的源码一脸茫然,因为源码中压根没有__call__()
方法的定义!!
于是我抱着试试的心态,在PyTorch官网上查了下PyTorch的历史版本,这一查确实查到了线索:
下面是从PyTorch的上古版本v0.1.12中截取forward()
和__call__()
方法的源码:
class Module(object):
#...中间不相关代码省略...
def forward(self, *input):
"""Defines the computation performed at every call.
Should be overriden by all subclasses.
"""
raise NotImplementedError
#...中间不相关代码省略...
def __call__(self, *input, **kwargs):
result = self.forward(*input, **kwargs)
for hook in self._forward_hooks.values():
hook_result = hook(self, input, result)
if hook_result is not None:
raise RuntimeError(
"forward hooks should never return any values, but '{}'"
"didn't return None".format(hook))
var = result
while not isinstance(var, Variable):
var = var[0]
creator = var.creator
if creator is not None and len(self._backward_hooks) > 0:
for hook in self._backward_hooks.values():
wrapper = functools.partial(hook, self)
functools.update_wrapper(wrapper, hook)
creator.register_hook(wrapper)
return result
我们可以看到在__call__()
方法中直接把方法self.forward()
作为函数的返回值,由于魔术方法__call__()
可以被自动调用,这也就解释了为什么forward()
可以自动运行。
至于该方法中的其他内容,都是与hook钩子函数的操作相关,这部分暂不做探索。。。
那我们回到现在的版本(我现在使用的是1.8.1):
通过源码可以看到经历了多个版本的更迭,forward()
和__call__()
居然改名字了!!
forward: Callable[..., Any] = _forward_unimplemented
...
__call__ : Callable[..., Any] = _call_impl
这里使用了类型注解(Type Annotation),用于指定变量或方法的类型。以forward: Callable[..., Any] = _forward_unimplemented
这行代码为例,其作用和含义如下:
forward: Callable[..., Any]
:- 表示
forward
是一个可调用对象(Callable
),即一个函数或方法。 Callable[..., Any]
是类型注解,表示forward
可以接受任意参数(...
),并返回任意类型的值(Any
)。
- 表示
= _forward_unimplemented
:- 表示
forward
的默认实现是_forward_unimplemented
,这是一个占位符方法,通常用于提示用户需要重写forward
方法。
- 表示
这也就是为什么我之前在源码中没找到这两个方法定义的原因。。。准确来说这里也不能说是改名字了,而是多了一个名字,至于PyTorch为什么会有这样的更改,我确实也没想到原因。。。
其中_forward_unimplemented()
倒是没变:
def _forward_unimplemented(self, *input: Any) -> None:
r"""Defines the computation performed at every call.
Should be overridden by all subclasses.
.. note::
Although the recipe for forward pass needs to be defined within
this function, one should call the :class:`Module` instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
"""
raise NotImplementedError
而_call_impl()
相比于上古版本,已经复杂到了令人发指的地步!
def _call_impl(self, *input, **kwargs):
# Do not call functions when jit is used
full_backward_hooks, non_full_backward_hooks = [], []
if len(self._backward_hooks) > 0 or len(_global_backward_hooks) > 0:
full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()
for hook in itertools.chain(
_global_forward_pre_hooks.values(),
self._forward_pre_hooks.values()):
result = hook(self, input)
if result is not None:
if not isinstance(result, tuple):
result = (result,)
input = result
bw_hook = None
if len(full_backward_hooks) > 0:
bw_hook = hooks.BackwardHook(self, full_backward_hooks)
input = bw_hook.setup_input_hook(input)
if torch._C._get_tracing_state():
result = self._slow_forward(*input, **kwargs)
else:
result = self.forward(*input, **kwargs)
for hook in itertools.chain(
_global_forward_hooks.values(),
self._forward_hooks.values()):
hook_result = hook(self, input, result)
if hook_result is not None:
result = hook_result
if bw_hook:
result = bw_hook.setup_output_hook(result)
# Handle the non-full backward hooks
if len(non_full_backward_hooks) > 0:
var = result
while not isinstance(var, torch.Tensor):
if isinstance(var, dict):
var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
else:
var = var[0]
grad_fn = var.grad_fn
if grad_fn is not None:
for hook in non_full_backward_hooks:
wrapper = functools.partial(hook, self)
functools.update_wrapper(wrapper, hook)
grad_fn.register_hook(wrapper)
self._maybe_warn_non_full_backward_hook(input, result, grad_fn)
return result
其变复杂的原因是各种钩子函数_hook的调用,有兴趣的童鞋可以参考这篇文章:pytorch 中_call_impl()函数。这部分绝对是超纲了!
最后我想再做几个实验加深理解:
实验①
import torch.nn as nn
class test(nn.Module):
def __init__(self, input):
super(test,self).__init__()
self.input = input
def forward(self,x):
return self.input * x
T = test(8)
print(T.__call__(6))
--------------------------运行结果-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:\Users\Lenovo\Desktop\DL\Pytest\calc_graph\test.py
48
Process finished with exit code 0
这里T.__call__(6)
写法等价于T(6)
实验②
import torch.nn as nn
class test(nn.Module):
def __init__(self, input):
super(test,self).__init__()
self.input = input
def forward(self,x):
return self.input * x
T = test(8)
print(T.forward(6))
--------------------------运行结果-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:\Users\Lenovo\Desktop\DL\Pytest\calc_graph\test.py
48
Process finished with exit code 0
这里T.forward(6)
的写法虽然也能正确地计算出结果,但是不推荐这么写,理由我将在下面的【20250411更新】中说明。
我原以为这会导致
__call__()
调用一遍forward()
,然后手动又调用了一遍forward()
,造成forward()
的重复计算,浪费计算资源。但实际并不是这么回事!
实验③
import torch.nn as nn
class test(nn.Module):
def __init__(self, input):
super(test,self).__init__()
self.input = input
# def forward(self,x):
# return self.input * x
T = test(8)
print(T())
--------------------------运行结果-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:\Users\Lenovo\Desktop\DL\Pytest\calc_graph\test.py
Traceback (most recent call last):
File "C:\Users\Lenovo\Desktop\DL\Pytest\calc_graph\test.py", line 11, in <module>
print(T())
File "D:\Users\Lenovo\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\Users\Lenovo\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 201, in _forward_unimplemented
raise NotImplementedError
NotImplementedError
forward()
是必须要写的,因为__call__()
要自动调用forward()
。如果压根不写forward()
,__call__()
将无方法可以调用。按照forward()
的源码,这里会raise NotImplementedError
。
至此,我觉得PyTorch中的forward应该算是全说明白了。。。
20250411 更新
在上文中留下了一个问题:在代码中T.forward(6)
和T(6)
的结果是一致的,那为什么不推荐T.forward(6)
的写法?
这里我直接上结论:因为T(6)
的.forward()
方法是通过一系列的钩子(hook)实现的,而这些钩子可能还会实现一些其他的功能,为了保证计算的完整性(算完所有的钩子),需要通过 __call__
调用.forward()
调用方式 | 钩子是否触发 | 原因 |
---|---|---|
T(6) | ✅ 是 | 通过 __call__ 调用.forward() ,PyTorch自动触发注册的钩子。 |
T.forward(6) | ❌ 否 | 直接调用 .forward() ,绕过 PyTorch 的钩子机制。不推荐这么写。 |
这里,可以通过源码中的.register_forward_hook()
方法再设计一个实验。
我们假设创造了一个可以优化计算过程的.forward_hook_opt()
方法,就可以把它“钩在”钩子上:
import torch.nn as nn
class test(nn.Module):
def __init__(self, input):
super(test,self).__init__()
self.input = input
def forward(self,x):
return self.input * x
T = test(8)
def forward_hook_opt(module, input, output):
print(".forward() was optimized~")
hook_handle = T.register_forward_hook(forward_hook_opt)
T(6) #钩子生效
print("---------------------------------------------")
T.forward(6) #钩子未生效
hook_handle.remove()
输出:
.forward() was optimized~
---------------------------------------------
Process finished with exit code 0
这样我们就看出了触发和不触发钩子的区别了~
至此,我觉得🔥PyTorch中的forward我总算是懂一点了。。。