在PyTorch的很多函数中都会包含 forward()
函数,但是也没见哪里调用过forward()
函数,不免让人产生疑惑
想要了解 forward()
函数的作用,首先要了解 Python 中的 __ call __
函数,
__ call __
函数的作用是能够让python中的类能够像方法一样被调用,通过下边的例子来理解一下:
class X(object):
def __init__(self, a, b, range):
self.a = a
self.b = b
self.range = range
def __call__(self, a, b):
self.a = a
self.b = b
print('__call__ with ({}, {})'.format(self.a, self.b))
def __del__(self, a, b, range):
del self.a
del self.b
del self.range
调用:
>>> xInstance = X(1, 2, 3)
>>> xInstance(1,2)
__call__ with (1, 2)
xInstance = X(1, 2, 3)
这句代码实例化了xInstance类型的对象X,在实例化时调用了__init__(self, a, b, range)
函数;xInstance(1,2)
代码使用类+参数的语法 直接调用了__call__(self, a, b)
函数,而不需要使用X.__call__(self, a, b)
语法,这就是__call__(self, a, b)
函数的作用,能够让python中的类能够像方法一样被调用
因为 PyTorch 中的大部分方法都继承自 torch.nn.Module,而 torch.nn.Module 的__call__(self)
函数中会返回 forward()函数
的结果,因此PyTroch中的 forward()函数
等于是被嵌套在了__call__(self)
函数中;因此forward()函数
可以直接通过类名被调用,而不用实例化对象