pytorch模型定义
module 类别是torch.nn 提供的模型构造类,(nn.module)所有神经网络的基类,可以继承他来定义我们想要的模型。
pytorch模型定义应包括两个部分:初始化__init__;数据流向定义forward。
.基于nn.module 我们通过sequential,moduleList,ModuleDict三个方式来定义Pytorch模型
一:基础层搭建
- sequential
前向计算为简单串联各个层时,sequential类可以更加简单定义模型。
例子
class Mysequential(nn.module):
from collections import OrdereDict
def __init__(self,*args):
super(Mysequential,self).init__()
if len(args)==1 and isinstance(args[0],OrderedDict):
for key, module in args[0].items():
self.add_module(key,module)
else:
for idx, module in enumerate(args):
self.add_module(str(idx),module)
def forward(self,input):
for module in self.modules.values():
input=module(input