Pytorch框架学习 -2 torch.nn.modules.Module(nn.Module)理解

本文深入探讨PyTorch框架中的nn.Module类,包括其基本参数、初始化函数、forward方法、注册器以及如何在实际模型中应用。通过简单和复杂的例子,解释了Module类在模型定义和训练过程中的核心作用。
摘要由CSDN通过智能技术生成

Pytorch框架学习 -2 torch.nn.modules.Module(nn.Module)理解

最简单的例子

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

分析

  1. 一个Pytroch模型应该以类的形式出现
  2. Pytorch训练模型应该是nn.Module的子类
  3. 一个训练模型包含经过初始化和前向传播两个过程

初始化模型是为了注册参数,保证模型能够正常处理这些重要参数,显然是必要
不同神经网络的前向传播过程肯定要自己定义,否则这个模型就失去了独特性

部分源码:

基本参数

class Module:
    dump_patches: bool = False
    _version: int = 1
    training: bool
dump_patches

当调用.to()|.cuda()的时候,将参数也将转化为gpu类型

_version

用于之后函数比较版本

training

使用train(mode)方法时修改,默认在init时变为True

主要影响bn和dropout等在网络训练和评估时使用方法不一样的功能

初始化函数

    def __init__(self):
        torch._C._log_api_usage_once("python.nn_module")

        self.training = True
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._non_persistent_buffers_set = set()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._state_dict_hooks = OrderedDict()
        self._load_state_dict_pre_hooks = OrderedDict()
        self._modules = OrderedDict()
_parameters

保存当前module的训练参数

_buffers

保存当前moduile的非训练参数

_modu
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值