深入剖析pytorch nn.Module源码

视频推荐:​7、深入剖析PyTorch nn.Module源码_哔哩哔哩_bilibili

PyTorch的nn.Module是一个基类,用于定义神经网络模型。它提供了一些方法和属性,使得我们可以方便地定义、训练和使用神经网络模型。我们定义的模型也该继承这个类(class)

import torch.nn as nn

import torch.nn.functional as F



class Model(nn.Module):

    def __init__(self):

        super().__init__()

        self.conv1 = nn.Conv2d(1, 20, 5)

        self.conv2 = nn.Conv2d(20, 20, 5)



    def forward(self, x):

        x = F.relu(self.conv1(x))

        return F.relu(self.conv2(x))

如上,nn.conv2d将二维子module加入到父module中去

nn.Module的主要功能如下:

1. 初始化:在初始化时,创建了一些有序字典,用于存储模型的参数、缓冲区、子模块、反向和前向钩子,并设置了模型的训练状态。

2. \_\_call\_\_方法:该方法实现了模型的前向传播。在前向传播之前,会调用注册的前向钩子函数;在前向传播之后,会调用注册的反向钩子函数。

3. forward方法:该方法是一个抽象方法,需要用户根据具体的模型结构来实现。在该方法中定义了模型的前向传播过程。

4. register_buffer方法:用于注册缓冲区,缓冲区是一些不需要进行梯度更新的张量,比如移动平均值。

5. register_parameter方法:用于注册模型的参数,参数是需要进行梯度更新的张量。

6. add_module(name, module)方法:用于添加子模块。

将子模块添加到当前模块。

可以使用给定名称将模块作为属性进行访问。

参数:

namestr) – 子模块的名称。子模块可以是 使用给定名称从此模块访问

module(模块)– 要添加到模块的子模块。

7. apply方法:递归地对模型及其子模块应用一个函数。

        一般是初始化一个模型的参数的,示例如下


@torch.no_grad()

def init_weights(m):  #此处m类型就是个module类型

    print(m)

    if type(m) == nn.Linear:

        m.weight.fill_(1.0)

        print(m.weight)

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))

net.apply(init_weights)

#Linear(in_features=2, out_features=2, bias=True)

#Parameter containing:

#tensor([[1., 1.],

#        [1., 1.]], requires_grad=True)

#Linear(in_features=2, out_features=2, bias=True)

#Parameter containing:

#tensor([[1., 1.],

#        [1., 1.]], requires_grad=True)

#Sequential(

#  (0): Linear(in_features=2, out_features=2, bias=True)

#  (1): Linear(in_features=2, out_features=2, bias=True)

#)

8. cuda方法:将模型及其子模块的参数和缓冲区移动到GPU上。

9. cpu方法:将模型及其子模块的参数和缓冲区移动到CPU上。

10. \_apply方法:递归地对模型及其子模块的参数和缓冲区应用一个函数。

11. state_dict方法:返回模型的状态字典,包含模型的参数和缓冲区。

12. load_state_dict方法:加载模型的状态字典,用于恢复模型的参数和缓冲区。

通常用torch.self保存多个函数,其中就含有当前参数和buffer,就可以用该函数导入参数和buffer

13. parameters方法:返回模型的参数。

14. named_parameters方法:返回模型的参数及其名称。

15. buffers方法:返回模型的缓冲区。


for buf in model.buffers():

    print(type(buf), buf.size())

#<class 'torch.Tensor'> (20L,)

#<class 'torch.Tensor'> (20L, 1L, 5L, 5L)

16. named_buffers方法:返回模型的缓冲区及其名称。

17. children方法:返回模型的子模块。

18. named_children方法:返回模型的子模块及其名称。

19. train方法:设置模型的训练状态。

20. eval方法:设置模型的评估状态。

21. zero_grad方法:将模型的参数的梯度置零。

22. \_\_setattr\_\_方法:用于设置模型的属性。

23. \_\_getattr\_\_方法:用于获取模型的属性。 

如何保存数据

使用pytorch save checkpoint

# 1. Import necessary libraries for loading our data
# For this recipe, we will use torch and its subsidiaries torch.nn and torch.optim.

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
# 2. Define and initialize the neural network
# For sake of example, we will create a neural network for training images. To learn more see the Defining a Neural Network recipe.

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
print(net)
# 3. Initialize the optimizer
# We will use SGD with momentum.

optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# 4. Save the general checkpoint
# Collect all relevant information and build your dictionary.

# Additional information
EPOCH = 5
PATH = "model.pt"
LOSS = 0.4

torch.save({
            'epoch': EPOCH,  #定义训练周期
            'model_state_dict': net.state_dict(),  #做测试,网络所有参数和buffer存放地
            'optimizer_state_dict': optimizer.state_dict(), #优化器所有参数和buffer存放地
            'loss': LOSS, #模型平均函数损失值
            }, PATH)
# 5. Load the general checkpoint
# Remember to first initialize the model and optimizer, then load the dictionary locally.

model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict']) #读取
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])#读取
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# # - or -
# model.train()
# You must call model.eval() to set dropout and batch normalization layers to evaluation mode before running inference. Failing to do this will yield inconsistent inference results.
#
# If you wish to resuming training, call model.train() to ensure these layers are in training mode.
#
# Congratulations! You have successfully saved and loaded a general checkpoint for inference and/or resuming training in PyTorch.
#
# Total running time of the script: ( 0 minutes 0.000 seconds)

训练结果如下:

 了解pytorch parameter

可以跳转PyTorch中的torch.nn.Parameter() 详解_Adenialzz的博客-CSDN博客,详细了解

以下为官方介绍

torch.nn.parameter.Parameter(data=Nonerequires_grad=True)[SOURCE]

A kind of Tensor that is to be considered a module parameter.

Parameters are Tensor subclasses, that have a very special property when used with Module s - when they’re assigned as Module attributes they are automatically added to the list of its parameters, and will appear e.g. in parameters() iterator. Assigning a Tensor doesn’t have such effect. This is because one might want to cache some temporary state, like last hidden state of the RNN, in the model. If there was no such class as Parameter, these temporaries would get registered too.

Parameters:

深入体会register_parameter以及与parameter的不同

Python网络训练 torch.nn.Parameter和register_parameter的用法和例子_longjiaxin1314的博客-CSDN博客

以上为学习笔记,若有错误请指正

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值