09模型创建步骤与nn.Module

本文详细介绍了网络模型的创建步骤,包括数据准备、模型定义、损失函数和优化器的选择,以及通过LeNet模型进行示例说明。此外,深入探讨了nn.Module在PyTorch中的作用,它作为神经网络的基础模块,包含parameters、modules、buffers等属性,以及如何管理子模块和实现forward()函数。
摘要由CSDN通过智能技术生成

一、网络模型创建步骤

1.1 模型训练步骤

  • 数据
  • 模型
  • 损失函数
  • 优化器
  • 迭代训练

1.2 模型创建步骤

在这里插入图片描述

1.3 模型构建两要素:

在这里插入图片描述

1.4 模型创建示例——LeNet

LeNet模型结构图:
在这里插入图片描述
LeNet计算图:
在这里插入图片描述
LeNet模型部分代码:

class LeNet(nn.Module):
    def __init__(self, classes):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, classes)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

说明:在__init__()中实现子模块的构建,在forward中实现子模块的拼接,所以在pytorch中,前向传播过程就是子模块的拼接过程

二、nn.Module

2.1 torch.nn

torch.nn:Pytroch中的神经网络模块,主要包括以下四个子模块
在这里插入图片描述

2.2 nn.Module

  • parameters: 存储管理nn.Parameter类
  • modules : 存储管理nn.Module类
  • buffers: 存储管理缓冲属性, 如BN层中的running_mean
  • ***_hooks: 存储管理钩子函数

8个有序字典:
在这里插入图片描述

说明: 自定义模型时,init()方法会继承父类nn.Module的__init__()函数

class LeNet(nn.Module):
    def __init__(self, classes):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, classes)

在nn.module的__init__()方法中,主要是_construct()

    def __init__(self):
        self._construct()
        # initialize self.training separately from the rest of the internal
        # state, as it is managed differently by nn.Module and ScriptModule
        self.training = True

_construct()创建了8个有序字典

    def _construct(self):
        """
        Initializes internal Module state, shared by both nn.Module and ScriptModule.
        """
        torch._C._log_api_usage_once("python.nn_module")
        self._backend = thnn_backend
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._state_dict_hooks = OrderedDict()
        self._load_state_dict_pre_hooks = OrderedDict()
        self._modules = OrderedDict()

说明: 自定义模型和其他网络层模块都直接或间接的继承nn.Module,所以都会有这8个有序字典,而其子模块,参数或者其他属性会存储在相应的有序字典中

2.3 nn.Module总结

  • 一个module可以包含多个子module
  • 一个module相当于一个运算,必须实现forward()函数
  • 每个module都有8个字典管理它的属性
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值