pytorch模型定义的方式

最重要的是自己根据需要灵活选取模型定义方式
必要的知识回顾
Module 类是 torch.nn 模块里提供的一个模型构造类 (nn.Module),是所有神经⽹网络模块的基
类,我们可以继承它来定义我们想要的模型;
PyTorch模型定义应包括两个主要部分:各个部分的初始化(init);数据流向定义
(forward)
基于nn.Module,我们可以通过Sequential,ModuleList和ModuleDict三种方式定义PyTorch模型。
下面我们就来逐个探索这三种模型定义方式。
Sequential
对应模块为nn.Sequential()。
当模型的前向计算为简单串联各个层的计算时, Sequential 类可以通过更加简单的方式定义模型。它可
以接收一个子模块的有序字典(OrderedDict) 或者一系列子模块作为参数来逐一添加 Module 的实例,
⽽模型的前向计算就是将这些实例按添加的顺序逐⼀计算。
例如:

// A sequential example
// An highlighted block
class MySequential(nn.Module): 
	from collections import OrderedDict 
	def __init__(self, *args): 
	   super(MySequential, self).__init__() 
	   if len(args) == 1 and isinstance(args[0], OrderedDict): # 如果传入的是一个 OrderedDict 
	      for key, module in args[0].items():
	          self.add_module(key, module) # add_module方法会将module添加进 self._modules(一个OrderedDict) 
	   else: # 传入的是一些Module 
	      for idx, module in enumerate(args):
	           self.add_module(str(idx), module) 
	def forward(self, input): # self._modules返回一个 OrderedDict,保证会按照成员添加时的顺序遍历成 
	    for module in self._modules.values():
	        input = module(input)
	        return input

两种方法使用sequential来定义模型

直接排列

使用OrderedDict

缺点:

使用Sequential也会使得模型定义丧失灵活性,比如需要在模
型中间加入一个外部输入时就不适合用Sequential的方式实现。使用时需根据实际需求加以选择。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值