最近在改写网络的过程中,发现自己对于super()继承的概念掌握的不清楚,引发了网络训练的相关似是而非的问题,有意思的是到目前为止这些问题仍然有待于理清。
另外本次完成关于深度学习网络的第一次改写,整体还是比较顺利的,但是也暴露出对于深度学习库相关语法的不熟悉,有待进一步加强掌握。
1、super及其背后的类继承方法
1.1 super与单继承、多继承
super() 函数是用于调用父类(超类)的一个方法。而且:Python3.x 和 Python2.x 的一个显著区别是:Python 3 可以使用直接使用 super().xxx 代替 super(Class, self).xxx。
super 是用来解决多重继承问题的,直接用类名调用父类方法在使用单继承的时候没问题,但是如果使用多继承,会涉及到查找顺序(MRO)、重复调用(钻石继承)等种种问题。我们最为常见的是super调用父类的init()函数,但是实际上super可以调用父类任意的已定义的函数。
比如说直接调用父类:
class Base:
def __init__(self):
print('Base.__init__')
class A(Base):
def __init__(self):
Base.__init__(self)
print('A.__init__')
以上方法在以上简单的类继承问题中没有问题,但是涉及到多继承时可能会遭遇到多次初始化基类的问题。
class Base:
def __init__(self):
print('Base.__init__')
class A(Base):
def __init__(self):
Base.__init__(self)
print('A.__init__')
class B(Base):
def __init__(self):
Base.__init__(self)
print('B.__init__')
class C(A,B):
def __init__(self):
A.__init__(self)
B.__init__(self)
print('C.__init__')
在以上示例中,会发现base类遇到两次初始化过程,而super函数的存在就是为了解决此类问题,避免基类的重复调用。
MRO 就是类的方法解析顺序表, 其实也就是继承父类方法时的顺序表。上面程序打印的结果是:
c = C()
print结果:
Base.__init__
B.__init__
A.__init__
C.__init__
涉及到的继承方法就是钻石继承问题(MRO/Diamond Inheritance):
以上的类列表总结之后可以写作:
print(C.__mro__)
(<class '__main__.C'>, <class '__main__.A'>, <class '__main__.B'>,
<class '__main__.Base'>, <class 'object'>)
一句话总结:父类总是出现在子类后,若有多个父类,其相对顺序保持不变,在定义子类调参时,父类的顺序就已经决定了MRO顺序。结合C++函数内部的构造函数与析构函数的调用顺序:构建一个子类对象时,先执行父类的构造函数,再执行子类的构造函数;析构一个子类对象时,先执行子类的析构函数,再执行父类的析构函数。则不难理解为什么以上的C初始化顺序为:Base-B-A-C;我们理解一下按照这个顺序的函数覆盖历程,也就是相同命名的函数也会是按照C-A-B-Base的覆盖顺序进行函数的覆盖与定义操作。
整个寻找父类顺序表的过程有点类似于二叉树的深度优先的非重复检索历程
因此:MRO列表中的类顺序会让你定义的任意类层级关系变得有意义。当你使用 super() 函数时,Python会在MRO列表上继续搜索下一个类。 只要每个重定义的方法统一使用 super() 并只调用它一次, 那么控制流最终会遍历完整个MRO列表,每个方法也只会被调用一次。
1.2 super牵扯到的实例self
class A:
def __init__(self):
self.n = 2
def add(self, m):
print('self is {0} @A.add'.format(self))
self.n += m
class B(A):
def __init__(self):
self.n = 3
def add(self, m):
print('self is {0} @B.add'.format(self))
super().add(m)
self.n += 3
b = B()
b.add(2)
print(b.n):
self is <__main__.B object at 0x106c49b38> @B.add
self is <__main__.B object at 0x106c49b38> @A.add
8
以上代码说明两个问题
1、super().add(m) 确实调用了父类 A 的 add 方法。
2、super().add(m) 调用父类方法 def add(self, m) 时, 此时父类中 self 并不是父类的实例而是子类的实例, 所以 super().add(m) 之后 self.n 的结果是 5 而不是 4 。
需要尤其注意到实例化的子类方法,self是类的本身实例变量,self.function时是首先调用自身的方法,如果自身没有再去父类中找;super是直接从父类中找方法,同时按照MRO顺序不断地找到最接近现定义类的方法。
类的方法(函数)和普通函数没有区别,可以用默认参数、可变参数或者关键字参数(*args是可变参数,args接收的是一个tuple,**kw是关键字参数,kw接收的是一个dict)。
2、深度学习网络的改写历程
惭愧,炼丹一年以来,对于亲手搭建网络架构的事情做的其实不多;自己更多涉略的是loss修改、CV代码修改、框架代码搭建、数据集finetune或者小修小改网络架构,对于网络架构搭建或者另起炉灶涉略的确实不多。
2.1 关于init 与 forward对照
之前自己认为init部分涉及到的网络架构需要与forward部分完全对应起来,但是实际上真正的网络结构是以forward部分为准,而并不是以init部分为准。
from torch import nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(20, 120)
self.fc2 = nn.Linear(120, 64)
self.fc3 = nn.Linear(64, 1)
self.fc4 = nn.Linear(1, 1)
self.drop = nn.Dropout(0.3)
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.drop(x)
x = self.fc3(x)
return x
if __name__ == '__main__':
net = Net()
print(net)
with or without self.fc4的情况下,forward部分不变动,但是打印出来的net网络结构不同。因此我们实际上的网络结构是以forward部分定义为准的。而init部分因为网络定义都是继承自nn.module,所以直接打印网络模型的时候就是打印nn.module的结构。
如果想要冻结部分训练模型,个人感觉可以在forward部分使用 with torch.no_grad():
,或者在init()函数定义部分使用with torch.no_grad()
可以冻结部分模型参数,大概就是在Fastai中可以见到的freeze()函数。
参照Fastai的原代码,其实可以学习到整个freeze()和unfreeze()代码规范。
def freeze_to(self, n:int)->None:
"Freeze layers up to layer group `n`."
if hasattr(self.model, 'reset'): self.model.reset()
for g in self.layer_groups[:n]:
for l in g:
if not self.train_bn or not isinstance(l, bn_types): requires_grad(l, False)
for g in self.layer_groups[n:]: requires_grad(g, True)
self.create_opt(defaults.lr)
def freeze(self)->None:
"Freeze up to last layer group."
assert(len(self.layer_groups)>1)
self.freeze_to(-1)
def unfreeze(self):
"Unfreeze entire model."
self.freeze_to(0)
由此延伸开来的知识点非常多,包括forward由__call__函数调用,以及torch内部关于nn.module的使用,我贴一些经典的CSDN报告在本文章下面。
nn.Modlue及nn.Linear 源码理解
python class 中 的__call__方法
前向传播函数forward
nn.Sequential讲解
一言以蔽之,Python中有一个有趣的语法,只要定义类型的时候,实现__call__函数,这个类型就成为可调用的。