pytorch中的forward函数自动调用

文章解释了在Python中,如PyTorch的nn.Module中,实例化模型时如何通过__call__函数自动调用forward方法进行前向传播。此外,还介绍了__call__函数的使用,包括作为可调用对象、改变对象状态以及作为装饰器的应用。
摘要由CSDN通过智能技术生成

模型训练时,发现只需要实例化一个模型对象并传入对应参数即可自动调用forward函数,对于其中原理查了些资料。

forward的使用,实例化模型后,自动调用forward进行前向传播,当然效果其实与module.forward(data)一样。

class Module(nn.Module):
    def __init__(self):
        super(Module, self).__init__()
        # ......
       
    def forward(self, x):
        # ......
        return x

data = .....  #输入数据
# 实例化一个对象
module = Module()
# 前向传播
module(data)  
# 而不是使用下面的
# module.forward(data)   

其中隐含的原理在于,在python中定义一个class类时,会有许多魔法函数,其中__call__函数是本节关键。

class A():
    def __call__(self):
        print('i can be called like a function')
 
a = A()
a()

output:i can be called like a function

__call__里调用其他的函数:

class A():
    def __call__(self, param):
        
        print('i can called like a function')
        print('传入参数的类型是:{}   值为: {}'.format(type(param), param))
 
        res = self.forward(param)
        return res
 
    def forward(self, input_):
        print('forward 函数被调用了')
 
        print('in  forward, 传入参数类型是:{}  值为: {}'.format( type(input_), input_))
        return input_
 
a = A()
 
 
input_param = a('i')
print("对象a传入的参数是:", input_param)

这时,输出为:

i can called like a function
传入参数的类型是:<class ‘str’> 值为: i
forward 函数被调用了
in forward, 传入参数类型是:<class ‘str’> 值为: i
对象a传入的参数是: i

在构建网络模型时,比如class A(nn.module),其中nn.module中包含了__call__函数,并且该函数已经定义了forward函数,由于模型A继承了mnn.module,所以模型A同样具有__call__函数功能。

2、__call__函数详解

2.1、__call__魔法函数的使用

示例代码:

class A(object):
    def __init__(self, name, age):
        self.name = name
        self.age = age
 
    def __call__(self):
        print('my name is %s' % self.name)
        print('my age is %s' % self.age)
 
 
if __name__ == '__main__':
    a = A('dgw', 25, )
    a()

运行结果:

my name is dgw
my age is 25

将A实例化后(a = A('dgw', 25, )),此时直接调用实例a,即a()就是调用其__call__方法,这个函数使该对象A变成了一个可调用对象,可以调用,也可以通过__call__函数为它增加参数。

示例代码:

class A(object):
    def __init__(self, name, age):
        self.name = name
        self.age = age
 
    def __call__(self, male):
        print('my name is %s' % self.name)
        print('my age is %s' % self.age)
        print('my male is %s' % male)
 
 
if __name__ == '__main__':
    a = A('dgw', 25, )
    a('woman')

output:
my name is dgw
my age is 25
my male is woman

允许一个类的实例像函数一样被调用。实质上说,这意味着 x() 与 x.__call__() 是相同的。注意 __call__ 参数可变。这意味着你可以定义 __call__ 为其他你想要的函数,无论有多少个参数。

__call__ 在那些类的实例经常改变状态的时候会非常有效。调用这个实例是一种改变这个对象状态的直接和优雅的做法。
示例代码:

class A(object):
    def __init__(self, name, age, male):
        self.name = name
        self.age = age
        self.male = male
 
    def __call__(self, name, age):
        self.name, self.age = name, age
 
 
if __name__ == '__main__':
    a = A('dgw', 25, 'man')
    print(a.age, a.name)
    a('zhangsan', 52)
    print(a.name, a.age)
    print(a.age, a.name)
output:
    25 dgw
    zhangsan 52
    52 zhangsan

2.2 作为装饰器

class Decorator(object):
    def __init__(self, name):
        self.name = name
 
    def __call__(self, func):
        def wrapper(*args, **kwargs):
            print(f"before func {func.__name__}")
            result = func(*args, **kwargs)
            print(f"after func {func.__name__}")
            return result
        return wrapper
 
 
@Decorator(name='dgw')
def my_func(x, y=10):
    return x + y
 
 
if __name__ == '__main__':
    ret = my_func(5)
    print(ret)

output:

before func my_func
after func my_func
15

参考链接:

python中的__call__用法详解_def __call___IT之一小佬的博客-CSDN博客

pytorch中的forward函数详细理解-CSDN博客

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值