torch.nn.Module.parameters()和torch.nn.Module.state_dict()的使用举例

链接: torch.nn.Module.parameters(recurse=True)
链接: torch.nn.Module.state_dict(destination=None, prefix=’’, keep_vars=False)

在这里插入图片描述

文档翻译:

parameters(recurse=True)
	Returns an iterator over module parameters.
	返回一个迭代器,该迭代器可以遍历模块的参数.
	This is typically passed to an optimizer.
	通常用该方法将参数传递给优化器.
	
	Parameters 参数
    recurse (bool)if True, then yields parameters of this 
    module and all submodules. Otherwise, yields only parameters
    that are direct members of this module.
    recurse (bool类型) - 如果rescue是True,那么yield迭代返回出这个模块
    以及该模块的所有子模块的参数. 否则,如果是False,那么只yield迭代返回
    出这个模块的直接成员.
    Yields 迭代返回
    Parameter – module parameter
    Parameter类型 - 模块的参数
    
	Example: 例子:

	>>> for param in model.parameters():
	>>>     print(type(param.data), param.size())
	<class 'torch.FloatTensor'> (20L,)
	<class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)

在这里插入图片描述

文档翻译:

state_dict(destination=None, prefix='', keep_vars=False)
	Returns a dictionary containing a whole state of the module.
	返回一个字典,该字典包含了模块的整个状态信息.
	Both parameters and persistent buffers (e.g. running averages)
	are included. Keys are corresponding parameter and buffer names.
	参数和持续性缓冲(比如:running averages)都会包含在该字典内.字典的关键字
	对应了参数和缓冲的名字.
	Returns 返回
    	a dictionary containing a whole state of the module
    	包含了模块的整个状态信息的字典.
    Return type 返回类型
    	dict 字典类型

	Example: 举例:

	>>> module.state_dict().keys()
	['bias', 'weight']

代码实验举例:

import torch
import torch.nn as nn


class Model4CXQ(nn.Module):
    def __init__(self):
        super(Model4CXQ, self).__init__()
        # super().__init__()
        self.attribute4cxq = nn.Parameter(torch.tensor(20200910.0))
        self.attribute4lzq = nn.Parameter(torch.tensor(20200.0))
        # self.attribute4scc = nn.Parameter(torch.Tensor(2.0))  # TypeError: new(): data must be a sequence (got float)
        # self.attribute4pq = nn.Parameter(torch.tensor(2))  # RuntimeError: Only Tensors of floating point dtype can require gradients
        self.attribute4zh = nn.Parameter(torch.Tensor(2))
        # self.attribute4yzb = nn.Parameter(torch.tensor(912.0))
        self.attribute4yzb = (torch.tensor(912.0))
        self.attribute4gcx = (torch.tensor(3))
        self.attribute4ymw = (torch.Tensor(3))

    def forward(self, x):
        pass


if __name__ == "__main__":
    model = Model4CXQ()
    print()
    print("打印参数".center(50,'-'))
    for param in model.parameters():
        print(param)
    print()
    print("打印字典".center(50,'-'))
    for k, v in model.state_dict().items():
        print(k, v)

    print()
    print("增加属性".center(50,'-'))
    attribute4cjhT = torch.Tensor(3)
    attribute4cjhP = nn.Parameter(torch.Tensor(2))
    model.attribute4cjhT = attribute4cjhT
    model.attribute4cjhP = attribute4cjhP
    print("打印字典".center(50,'-'))
    for k, v in model.state_dict().items():
        print(k, v)


    print()
    print("注册属性".center(50,'-'))
    attribute4cjhRT = torch.Tensor(3)
    attribute4cjhRP = nn.Parameter(torch.Tensor(2))
    # model.register_parameter('attribute4cjhRT', attribute4cjhRT)  
    # 上一行代码报错,注册的应该是Parameter类型,而不是FloatTensor类型
    # 报错信息:
    # TypeError: cannot assign 'torch.FloatTensor' object to parameter 'attribute4cjhRT' (torch.nn.Parameter or None required)
    model.register_parameter('登记注册attribute4cjhRP属性', attribute4cjhRP)

    print("打印字典".center(50,'-'))
    for k, v in model.state_dict().items():
        print(k, v)   

控制台输出:

Windows PowerShell
版权所有 (C) Microsoft Corporation。保留所有权利。

尝试新的跨平台 PowerShell https://aka.ms/pscore6

加载个人及系统配置文件用了 892 毫秒。
(base) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> conda activate ssd4pytorch1_2_0
(ssd4pytorch1_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq>  & 'D:\Anaconda3\envs\ssd4pytorch1_2_0\python.exe' 'c:\Users\chenxuqi\.vscode\extensions\ms-python.python-2020.12.424452561\pythonFiles\lib\python\debugpy\launcher' '51788' '--' 'c:\Users\chenxuqi\Desktop\News4cxq\test4cxq\test22.py'

-----------------------打印参数-----------------------
Parameter containing:
tensor(20200910., requires_grad=True)
Parameter containing:
tensor(20200., requires_grad=True)
Parameter containing:
tensor([1.4038e+13, 8.6419e+00], requires_grad=True)

-----------------------打印字典-----------------------
attribute4cxq tensor(20200910.)
attribute4lzq tensor(20200.)
attribute4zh tensor([1.4038e+13, 8.6419e+00])

-----------------------增加属性-----------------------
-----------------------打印字典-----------------------
attribute4cxq tensor(20200910.)
attribute4lzq tensor(20200.)
attribute4zh tensor([1.4038e+13, 8.6419e+00])
attribute4cjhP tensor([0., 0.])

-----------------------注册属性-----------------------
-----------------------打印字典-----------------------
attribute4cxq tensor(20200910.)
attribute4lzq tensor(20200.)
attribute4zh tensor([1.4038e+13, 8.6419e+00])
attribute4cjhP tensor([0., 0.])
登记注册attribute4cjhRP属性 tensor([0., 0.])
(ssd4pytorch1_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> 
  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值