Pytorch—模型构造(Module)
1. 模型构造(Module类)
1.1. Module简介
在这个系列之前的文章中,在自定义模型的时候,始终都都需要继承nn.Module类。这里我们对于这个类进行一下解释。
众所周知,Pytorch是基于动态图的模型搭建方式,我们可以随机的在网络中添加或者删除网络层。在搭建我们自己的网络结构的时候,我们需要基础nn.Module类作为父类。然后在我们自定义类的内部添加不同的网络层。其中nn.Module类是nn模块中提供的一个模型构造类,是所有神经网络模块的基础类。我们需要继承它来定义我们的自己的网络结构。在自定义网络结构的时候,我们需要对于Module类内的"_init_"方法和“forward”方法进行重载。它们的作用是具体的定义我们的网络结构和相关参数。以及定义网络的前向传播方式。举个例子来说:
import torch
from torch import nn
class MLP(nn.Module):
def __init__(self,**kwargs):
'''
**kwargs :自定义网络结构需要的参数
'''
super(MLP,self).__init__(**kwargs) #初始化父类
self.hidden = nn.Linear(k,q) #定义一个输入维度为k,输出维度为q的线性隐藏层
self.act = nn.ReLU() # 定义激活函数
self.output = nn.Linear(q,m) #定义一个输入维度为q,输出维度为m的输出线性层
def forward(self,inputs):
'''
定义前向传播过程
inputs: 输入数据
'''
a = self.act(self.hidden(inputs))
return self.output(a)
上面定义的是一个简单的感知机网络,在自定义的网络中,我们不需要自定义反向传播函数。系统将通过自动求梯度而自动生成反向传播需要的backward函数。
在定义完网络结构之后,我们需要对网络结构进行实例化。举上面的例子来说
X = torch.rand(2,784) #生成随机数据
net = MLP() #实例化网络结构
print(net)
net(X) #对网络进行前向传播
这里在将MLP实例化成net之后,通过Module类的"_call_"函数,可以直接使用“net(X)”来对数据进行前向传播(调用forward函数)。
2. Module的子类
2.1 Sequential子类
这个子类的基本使用我们在之前的文章中已经介绍了,现在我们来回顾一下这个子类的使用。首先,这是一个容器类,当模型的前向计算为简单的串联(堆叠网络层)的时候,可以通过Sequential类以更加简单的方式来定义模型。Sequential可以接收一个子模块的有序词典或者一系列的子模块作为参数来逐一的添加Module的子类的实例。在前向计算的时候,可以将这些实例按照添加的顺序之一计算,向前传播。这里实现一个MySequential类,其机制和Sequential类似。
#encoding=utf-8
import torch
import torch.nn as nn
from collections import OrderedDict
class MySequential(nn.Module):
def __init__(self,*args):
super(MySequential,self).__init__()
if len(args) == 1 and isinstance(args[0],OrderedDict): # 如果传入的是一个有序的dict
for key,module in args[0].items():
self.add_module(key,module)
else:
for idx,module in enumerate(args[0]): #传入的一些Module
self.add_module(str(idx),module)
def forward(self,input):
# self._modules 返回一个OrderedDict,保证能够按照成员添加时的顺序进行遍历
for module in self._modules.values():
input = module(input)
return input
if __name__ == '__main__':
X = torch.rand(2,784)
net = MySequential([nn.Linear(784,256),nn.ReLU(),nn.Linear(256,19)])
print(net)
net(X)
2.2 ModuleList类
ModuleList类接收一个子模块的列表作为输入,然后也可以类似List那样进行append和extend操作。类似于我们建立一个list,list内部中的每一个元素一个网络层。举一个例子来说。
net = nn.ModuleList([nn.Linear(784,256),nn.ReLU()])
net.append(nn.Linear(256,10))
print(net[-1])
print(net)
这里值得注意的是,无论是Sequential和ModuleList都可以以list的形式来构建复杂的网络结构。但是二者是有一定区别的,首先Sequential是可以表示一种顺序结构的,这种顺序结构会使得前一个层的输出作为下一个层的输入。而ModuleList仅仅是存储各个模块的list,这些模块之间没有联系也没有顺序。在Sequential中需要注意前后两层的输出和输入的维度。其次,在进行forward的时候,Sequential自动定义了按照顺序的forward的过程。而ModuleList需要自定义网络应该如何进行前向传播。我们举一个例子:
import torch
import torch.nn as nn
class MyModuleList(nn.Module):
def __init__(self):
super(MyModuleList,self).__init__()
self.linears = nn.ModuleList([nn.Linear(10,10) for i in range(10)])
def forward(self,x):
for i,l in enumerate(self.linears):
x = self.linears[i // 2](x) + l(x)
return x
可以看到,这里的前向传播方式没有项Sequential顺序传播,这里使用的是我们自己定义的传播方式。
与此同时,ModuleList中的每一个网络层的参数都会被加入到整个网络结构中,可以使用Optimizer进行优化,而使用Python定义的list是不会一起将所有参数进行添加的。举一个例子:
class Module_ModuleList(nn.Module):
def __init__(self):
super(Module_ModuleList,self).__init__()
self.linears = nn.ModuleList([nn.Linear(10,10)])
class Module_List(nn.Module):
def __init__(self):
super(Module_List,self)
self.linears = [nn.Linear(10,10)]
if __name__ == '__main__':
net1 = Module_ModuleList()
net2 = Module_List()
print("net1: ")
for p in net1.parameters():
print(p.size())
print("net2 : ")
for p in net2.parameters():
print(p)
(这里第二种输出会报错,没有参数)。而第一种输出会将List内的所有参数输出。
2.3 ModuleDict子类
ModuleDict接收一个子模块的字典作为输入,然后按照类似于字典的形式添加访问操作,举一个常见的例子:
net = nn.ModuleDict({
'linear': nn.Linear(784, 256),
'act': nn.ReLU(),
})
net['output'] = nn.Linear(256, 10) # 添加
print(net['linear']) # 访问
print(net.output)
print(net)
和ModuleList类似的是,ModuleDict实例仅仅是存放了一些模块的字典,并没有定义forward函数,前向传播的方式需要我们自己定义。同样,ModuleDict也会自动的将内部的参数添加到网络结构的内部。
3 构造一个复杂的模型实例
下面的我们来构建一个稍微复杂的实例来总结一下上面的内容。
#encoding=utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
class FancyMLP(nn.Module):
def __init__(self):
super(FancyMLP,self).__init__()
#这里需要注意,rand_weight不是参数,是一个常量
self.rand_weight = torch.rand((20,20),requires_grad=False)
self.linear = nn.Linear(20,20)
def forward(self,x):
x = self.linear(x)
x = F.relu(torch.mm(x,self.rand_weight.data)+1)
x = self.linear(x)
while x.norm().item() > 1:
x / 2
if x.norm().item() < 0.8:
x *= 10
return x.sum()
#定义网络叠加结构
class NestMLP(nn.Module):
def __init__(self,**kwargs):
super(NestMLP,self).__init__()
self.net = nn.Sequential(nn.Linear(40,30),nn.ReLU())
def forward(self,x):
return self.net(x)
net = nn.Sequential(NestMLP(),nn.Linear(30,20),FancyMLP())
if __name__ == '__main__':
X = torch.rand(2, 20)
print(net)
net(X)
4. 参考
- 动手学深度学习—Pytorch版