深度之眼Pytorch框架训练营第四期——模型创建与nn.Module

模型创建与nn.Module

1、模型创建步骤

模型的创建示意图如下:
在这里插入图片描述
从上图中可以看出,模型的创建与权值初始化共同构成了模型,模型的创建只要包括了:

  • 构建网络层:卷积层,池化层,激活函数等;
  • 拼接网络层:网络层有构建网络层后,需要进行网络层的拼接,拼接成 L e N e t LeNet LeNet A l e x N e t AlexNet AlexNet R e s N e t ResNet ResNet
    创建好模型后,需要对模型进行权值初始化,PyTorch中的初始化方法主要有:XavierKaiming,均匀分布,正态分布等方法。
2、nn.Module
  • 第一部分中讲到的模型的创建权值初始化PyTorch中均需要通过nn.Module来完成,nn.Module是整个模块的根基
  • nn.Moduletorch.nn中的模块,torch.nn中一共有四个模块,如下图所示:

在这里插入图片描述

  • nn.Module中有八个重要的属性用于管理整个模型:
  • parameters: 存储管理nn.Parameter
  • modules:存储管理nn.Module
  • buffers:存储管理缓冲属性,如BN层中的running_mean
  • ***_hooks:共有5个,存储管理钩子函数
self._parameters = OrderedDict()
self._buffers = OrderedDict() 
self._backward_hooks = OrderedDict() 
self._forward_hooks = OrderedDict() 
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict() 
self._load_state_dict_pre_hooks = OrderedDict() 
self._modules = OrderedDict()
3、以LeNet模型为例探究nn.Module
  • 如图所示,LeNet由很多网络层构成,包括两个卷积层,两个池化层和三个全连接层LeNet: Conv1 -> pool1 -> Conv2 -> pool2 -> fc1 -> fc2 -> fc3

在这里插入图片描述

  • 将上图转为一个计算图的形式,如下图所示,计算图有两个主要的概念:一个是节点一个是边,节点就是张量数据,边就是运算,在图中就是箭头
    在这里插入图片描述
  • 构建模型有两要素,第一是构建子模块,比如LeNet是由很多网络层构成的,所以首先得构建子模块中的网络层;构建好网络层后,第二是拼接子模块,按照一定拓扑结构拼接子模块就可以得到模型,构建子模块需要用到__init__()函数,而拼接子模块需要用到forward()函数,下面针对这两个函数进行讲解
(1)初始化部分:__init__()
class LeNet(nn.Module):
    def __init__(self, classes):
        super(LeNet, self).__init__()    # 继承父类nn.Module的初始化
        self.conv1 = nn.Conv2d(3, 6, 5)    # 卷积层,卷积核为5*5,输入通道为3,输出通道为6
        self.conv2 = nn.Conv2d(6, 16, 5)    # 卷积层
        self.fc1 = nn.Linear(16*5*5, 120)c    # 全连接层
        self.fc2 = nn.Linear(120, 84)	# 全连接层
        self.fc3 = nn.Linear(84, classes)	# 全连接层

#####(2)拼接部分:forward()

def forward(self, x):
    out = F.relu(self.conv1(x))  # import torch.nn.functional as F
    out = F.max_pool2d(out, 2)
    out = F.relu(self.conv2(out))
    out = F.max_pool2d(out, 2)
    out = out.view(out.size(0), -1)
    out = F.relu(self.fc1(out))
    out = F.relu(self.fc2(out))
    out = self.fc3(out)
    return out
(3)nn.Module的属性构建

nn.Module的属性构建会在module类中进行属性赋值的时候会被setattr()函数拦截,在这个函数当中会判断即将要赋值的数据类型是否是nn.parameters类,如果是的话就会存储到parameters字典中;如果是module类就会存储到modul字典中

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值