PyTorch模型的定义

#一、Pytorch模型定义方式

1.1模型定义 三种方式

Module类是torch.nn模块里提供的一个模型构造类(nn.module),是所有神经网络模块的基类,可以继承它来定义;模型定义主要包括两个主要部分:各部分的初始化(_init_);数据流向定义(forward);

基于nn.module,我们可以通过Sequencetial,modulelist和ModuleDict 三种方式定义;

1.1.1Sequential

对应模块为nn.Sequential()

当模型的前向计算为简单串联各个层的计算时,Sequential 类可以通过更简单的方式定义模型。它能接收一个子模块的有序字典(OrderedDict)或者一系列子模块作为参数来逐一添加Module的实例,而模型的前向计算就是将这些实例按添加的顺序逐一计算;

Sequential定义模型时只需将模型的层按顺序排列起来即可,根据层名不同,排列方式有两种:

直接排列

import torch.nn as nn

net = nn.Sequential (

         nn.Linear(784,256),

         nn.ReLu(),

         nn.Linear(256,10)

                    )

print(net)

Sequential (

   (0):Linear(in_features=784,out_features=256,bias=True)

   (1):ReLu()

   (2):Linear(in_features=256,out_features=10,bias=True)

   )

使用OrderedDict:

import collections

import torch.nn as nn

net2 = nn.Sequential ( collections.OrderedDict([

        ('fc1',nn.Linear(784,256),

        ('relu1',nn.ReLu()),

        ( 'fc2', nn.Linear(256,10))

       ]))

print(net2)

Sequential (

   (fc1):Linear(in_features=784,out_features=256,bias=True)

   (relu1):ReLu()

   (fc2):Linear(in_features=256,out_features=10,bias=True)

   )

此定义方式简单易读;但也会使模型定义丧失灵活性;

1.1.2ModuleList(模块为nn.ModuleList)

ModuleList接受一个子模块(或层)的列表作为输入,然后也能类似List那样进行append和extend操作。子模块的权重也会自动添加到网络;

net = nn.ModuleList ([ nn.Linear(784,256) ,  nn.ReLu()])

net.append(  nn.Linear(256,10))

print(net[-1])

print(net)

Linear (in_features=256,out_features=10,bias=True)

ModuleList(

   (0):Linear(in_features=784,out_features=256,bias=True)

   (1):ReLu()

   (2):Linear(in_features=256,out_features=10,bias=True)

   )

nn.ModuleList没有定义一个网络,只是将不同模块储存在一起;还需要经过forward函数指定各个层的先后顺序才算完成模型定义;用for循环即可实现。

1.1.3ModuleDict

其与ModuleList作用类似,只是ModuleList能更方便的为神经网络的层添加名称;

#二、利用模块块快速搭建复杂网络

2.1U-net介绍

U-net是分割模型的杰作,在以医学影像为代表的诸多领域有着广泛应用;组成该模块主要有以下几个部分:

1)每个子块内部的两次卷积;(Double Convolution)

2)左侧模型块之间的下采样连接,即最大池化;(Max pooling);

3)Up sampling;

4)输出层处理;(out convolution)

还包括模块之间的横向连接,输入和U-et底部的连接扥计算;

在实现该模块时 不必把每一层按序排列显式写出,应先定义模型块,再定义模型块直接的连接顺序和计算方法。基础部件对应上述四个模型块;(具体实现代码略)

使用写好的模型块,可以方便地组装U-net模型,通过模块化的方式实现代码复用;

#三、PyTorch修改模型(代码太多略敲)

3.1修改模型层

3.2添加额外输入

3.3添加额外输出

#四、模型的保存与读取

4.1模型存储格式

存储模型主要采用pkl,pt,pth三种格式;

4.2模型存储内容

一个PyTorch模型主要包含两个部分:模型结构和权重。其中模型是继承nn.Module的类,权重的数据结构是一个字典(key 是层名,value是权重向量),存储也分为两种形式:存储整个模型和只存储权重;

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值