模型训练时,发现只需要实例化一个模型对象并传入对应参数即可自动调用forward函数,对于其中原理查了些资料。
forward的使用,实例化模型后,自动调用forward进行前向传播,当然效果其实与module.forward(data)一样。
class Module(nn.Module):
def __init__(self):
super(Module, self).__init__()
# ......
def forward(self, x):
# ......
return x
data = ..... #输入数据
# 实例化一个对象
module = Module()
# 前向传播
module(data)
# 而不是使用下面的
# module.forward(data)
其中隐含的原理在于,在python中定义一个class类时,会有许多魔法函数,其中__call__函数是本节关键。
class A():
def __call__(self):
print('i can be called like a function')
a = A()
a()
output:i can be called like a function
__call__里调用其他的函数:
class A():
def __call__(self, param):
print('i can called like a function')
print('传入参数的类型是:{} 值为: {}'.format(type(param), param))
res = self.forward(param)
return res
def forward(self, input_):
print('forward 函数被调用了')
print('in forward, 传入参数类型是:{} 值为: {}'.format( type(input_), input_))
return input_
a = A()
input_param = a('i')
print("对象a传入的参数是:", input_param)
这时,输出为:
i can called like a function
传入参数的类型是:<class ‘str’> 值为: i
forward 函数被调用了
in forward, 传入参数类型是:<class ‘str’> 值为: i
对象a传入的参数是: i
在构建网络模型时,比如class A(nn.module),其中nn.module中包含了__call__函数,并且该函数已经定义了forward函数,由于模型A继承了mnn.module,所以模型A同样具有__call__函数功能。
2、__call__函数详解
2.1、__call__魔法函数的使用
示例代码:
class A(object):
def __init__(self, name, age):
self.name = name
self.age = age
def __call__(self):
print('my name is %s' % self.name)
print('my age is %s' % self.age)
if __name__ == '__main__':
a = A('dgw', 25, )
a()
运行结果:
my name is dgw
my age is 25
将A实例化后(a = A('dgw', 25, )),此时直接调用实例a,即a()就是调用其__call__方法,这个函数使该对象A变成了一个可调用对象,可以调用,也可以通过__call__函数为它增加参数。
示例代码:
class A(object):
def __init__(self, name, age):
self.name = name
self.age = age
def __call__(self, male):
print('my name is %s' % self.name)
print('my age is %s' % self.age)
print('my male is %s' % male)
if __name__ == '__main__':
a = A('dgw', 25, )
a('woman')
output:
my name is dgw
my age is 25
my male is woman
允许一个类的实例像函数一样被调用。实质上说,这意味着 x() 与 x.__call__() 是相同的。注意 __call__ 参数可变。这意味着你可以定义 __call__ 为其他你想要的函数,无论有多少个参数。
__call__ 在那些类的实例经常改变状态的时候会非常有效。调用这个实例是一种改变这个对象状态的直接和优雅的做法。
示例代码:
class A(object):
def __init__(self, name, age, male):
self.name = name
self.age = age
self.male = male
def __call__(self, name, age):
self.name, self.age = name, age
if __name__ == '__main__':
a = A('dgw', 25, 'man')
print(a.age, a.name)
a('zhangsan', 52)
print(a.name, a.age)
print(a.age, a.name)
output:
25 dgw
zhangsan 52
52 zhangsan
2.2 作为装饰器
class Decorator(object):
def __init__(self, name):
self.name = name
def __call__(self, func):
def wrapper(*args, **kwargs):
print(f"before func {func.__name__}")
result = func(*args, **kwargs)
print(f"after func {func.__name__}")
return result
return wrapper
@Decorator(name='dgw')
def my_func(x, y=10):
return x + y
if __name__ == '__main__':
ret = my_func(5)
print(ret)
output:
before func my_func
after func my_func
15
参考链接: