起因
我前段时间在学 pytorch, 在网上看到这样一段代码:
import torch
from torch import nn
class MLP(nn.Module):
def __init__(self,**kwargs):
super(MLP,self).__init__(**kwargs)
self.hidden = nn.Linear(784,256)
self.act = nn.ReLU()
self.output = nn.Linear(256,10)
#所以这个网络的结构是 784 -> 256 -> 10, 中间用 ReLU 激活
def forward(self,x): # x是输入
o = self.act(self.hidden(x))
return self.output(o)
当时首先让我困惑的是这个 super() 方法,其次是这个神秘的参数**kwargs。在网上查了半天之后,super() 是懂了,但似乎没有人把 **kwargs 讲得足够清楚。后来我自己通过代码实践理解了它的用途。这两个东西都和类的继承有很大关联。
先来说 super().
SUPER()
这个 super() 很简单。假设我们现在有两个类,一个是子类,一个是父类,我们要想在子类直接调用父类的方法(也就是函数),就得用super(). 举个例子:
class class1:
def __init__(self): # 初始化
self.height = 0
self.speed = 0
class class2(class1): # class2 继承 class1
def __init__(self):
super(class2, self).__init__() # 使用 class1 的__init__() 方法
这里定义了两个类,class1 和 class2。 在声明一个class2 的对象时,会直接调用class1 的 __init__() 方法。即:
object = class2()
print(object.speed)
输出:
0
由于在 class2 的初始化方法中调用了 class1 的初始化方法,我们哪怕没有在 class2 的初始化方法中定义 speed 这个属性,由于 class1 的初始化方法定义了 speed 属性,object 依旧拥有这一属性。更进一步,我们看看这样的写法会有什么样的结果:
class class2(class1):
def __init__(self):
super(class2, self).__init__()
self.speed = 1
self.height = 1
object = class2()
print(object.speed)
输出:
1
这很好理解,在class1 的初始化方法运行过后,我们把 speed 和 height 的值覆盖了一遍。
**kwargs
这个东西就稍微有一些麻烦。**kwargs 的本质上是一个字典, 里面会装你喂给函数的参数名称以及它们对应的值。使用 .get() 就可以获得这些参数的值。我认为比起讲解,例子可能更加有力。下面是它的基础用法:
def cout(**kwargs):
for i in kwargs:
print(i, ":", kwargs.get(i))
cout(num1=1, num2=2, num3=3)
结果:
num1 : 1
num2 : 2
num3 : 3
与之相对的还有一个 *args:
def cout(*args):
for i in args:
print(i)
cout(1,2,3)
结果:
1
2
3
下面我们来看看它是怎么在继承中发挥作用的。我们先来定义一个类:
class class1:
def __init__(self, trueSpeed = 100):#默认为100
self.speed = trueSpeed
self.height = 0
object = class1() #使用默认值
print("speed by defalt:", object.speed)
object = class1(trueSpeed = 78) #手动设定
print("Manually setting speed:",object.speed)
输出:
speed by defalt: 100
Manually setting speed: 78
我们发现,可以通过设定 trueSpeed = 78来手动指定我们要的值。
现在有另外一个类:
class class2(class1): # class2 继承 class1
def __init__(self):
super(class2,self).__init__()
object = class2()
print("speed by defalt:", object.speed)
object = class2(trueSpeed = 78)
print("Manually setting speed:",object.speed)
输出:
speed by defalt: 100
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[148], line 8
5 object = class2()
6 print("speed by defalt:", object.speed)
----> 8 object = class2(trueSpeed = 78)
9 print("Manually setting speed:",object.speed)
TypeError: class2.__init__() got an unexpected keyword argument 'trueSpeed'
我们发现使用默认的值来初始化是没有问题的,但是我们没有办法手动设定值了。如果我们既想要在 class2 中使用 class1 的 __init__() 方法,又想指定 trueSpeed 的值,应该怎么办呢?这时候我们很自然地想到了 **kwargs, 它能够装我们喂给它的参数,不是吗?那么只需要这么写:
class class2(class1):
def __init__(self, **kwargs):
super(class2,self).__init__(**kwargs)
object = class2() # 默认,kwargs为空
print("speed by defalt:", object.speed)
object = class2(trueSpeed = 78) # kwargs 装载了 trueSpeed 的信息
print("Manually setting speed:",object.speed)
输出:
speed by defalt: 100
Manually setting speed: 78
在这段代码中,trueSpeed =78 的信息先是交给了class2 的 __init__()方法,然后class2 的 __init__() 方法把 trueSpeed=78 的信息交给了 class1 的 __init__() 方法,从而达成了我们的目的。
回到开始
现在再来看这段代码:
import torch
from torch import nn
class MLP(nn.Module):
def __init__(self,**kwargs):
super(MLP,self).__init__(**kwargs)
self.hidden = nn.Linear(784,256)
self.act = nn.ReLU()
self.output = nn.Linear(256,10)
#所以这个网络的结构是 784 -> 256 -> 10, 中间用 ReLU 激活
def forward(self,x): # x是输入
o = self.act(self.hidden(x))
return self.output(o)
其实很简单,我们只需要记得要是想在子类里使用父类的方法但同时想要指定一些参数的值,就要用这个**kwargs.
在这段代码中,我们自定了一个叫做 MLP 的类,在初始化时会通过 super() 调用 nn.Module 的初始化方法。而 **kwarg 允许我们在这一过程中指定一些参数的值。