Python类中self参数 / __ init__ ()方法 /super(Model, self).__init__()是什么

在使用tensorFlow2搭建神经网络模型的时候,除了Sequential的方法,还有就是自己写模型class,然后初始化。自己写模型类的优势就是可以自定义层与层之间的连接关系,自定义数据流x的流向。

这是为鸢尾花数据做的一个简单基础的神经网络模型代码:

class IrisModel(Model): #继承TF的Model类
    def __init__(self): #在这里定义网络结构块
        super(IrisModel, self).__init__()
        self.d1 = Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2()) #定义网络结构块

    def call(self, x): #在这里实现前向传播
        y = self.d1(x) #调用网络结构块,实现前向传播
        return y

model = IrisModel() #类实例化

我在第一次学习用类来写这个代码的时候,有几个地方觉得很奇怪,不习惯。

1、为什么要定义 def __init__()

2、为什么要有 super(MyModel ,self).__init__()

3、为什么到处都有一个self? 又不能省略,又感觉很累赘,它到底意义何在?

我本来也是打算自己写文章说说这点的,但是发现网上已经有人写了,并且写的很细致和清晰了,所以,我就直接转载他的文章吧。

博文:self参数 - __ init__ ()方法 super(Net, self).__init__()是什么

博主Chou_pijiang,原文链接:https://blog.csdn.net/zyh19980527/article/details/107206483

总结下来,有这样几点:

1、self参数

self指的是实例Instance本身,在Python类中规定,函数的第一个参数是实例对象本身,并且约定俗成,把其名字写为self,也就是说,类中的方法的第一个参数一定要是self,而且不能省略。
我觉得关于self有三点是很重要的:

self指的是实例本身,而不是类
self可以用this替代,但是不要这么去写
类的方法中的self不可以省略

 

2、__ init__ ()方法

在python中创建类后,通常会创建一个 __ init__ ()方法,这个方法会在创建类的实例的时候自动执行。 __ init__ ()方法必须包含一个self参数,而且要是第一个参数

 __ init__ ()方法在实例化的时候就已经自动执行了,但是如果不是 __ init__ ()方法,那肯定就只有调用才执行。如果 __ init__ ()方法中还需要传入另一个参数name,但是我们在创建Bob的实例的时候没有传入name,那么程序就会报错, 说我们少了一个__ init__ ()方法的参数,因为__ init__ ()方法是会在创建实例的过程中自动执行的,这个时候发现没有name参数,肯定就报错了。

那么什么需要在__ init__ ()方法中定义?就是当我们认为一些属性、操作是在创建实例的时候就有的时候,就应该把这个量定义在__ init__ ()方法中。我们写神经网络的代码的时候,一些网络结构的设置,也最好放在__ init__ ()方法中。

3、super(MyModel, self).__init__()

简单理解就是子类把父类的__init__()放到自己的__init__()当中,这样子类就有了父类的__init__()的那些东西。

Net类继承nn.Module,super(Net, self).__init__()就是对继承自父类nn.Module的属性进行初始化。而且是用nn.Module的初始化方法来初始化继承的属性。

子类继承了父类的所有属性和方法,父类属性自然会用父类方法来进行初始化。
当然,如果初始化的逻辑与父类的不同,不使用父类的方法,自己重新初始化也是可以的。

写神经网络模型的时候,父类Modle中还有一些属性我们不用再自己手写了,直接将其初始化就好,另外Model也是有继承的父类network.Network,这里也有一些我们不知道、不知如何自己写的属性和方法。所以最好都加上 super(MyModel,self).__init__()这样的父类初始化命令行。

参考TensorFLow官方API中Model的代码,我们来看看Model初始化__init__()中都定义了什么:https://github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/python/module/module.py#L286-L317

class Module(tracking.AutoTrackable):
  def __init__(self, name=None):
    if name is None:  #模型名称
      name = camel_to_snake(type(self).__name__)
    else:
      if not valid_identifier(name):
        raise ValueError(
            "%r is not a valid module name. Module names must be valid Python "
            "identifiers (e.g. a valid class name)." % name)

    self._name = name
    if tf2.enabled():
      with ops.name_scope_v2(name) as scope_name:  #TF命名空间
        self._name_scope = ops.name_scope_v2(scope_name)
    else:
      with ops.name_scope(name, skip_on_eager=False) as scope_name:
        self._scope_name = scope_name

  • 21
    点赞
  • 55
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值