torch.nn.Parameter使用举例

参考链接: torch.nn.Parameter
在这里插入图片描述
原文及翻译:

Parameters  
Parameters类型

class torch.nn.Parameter
torch.nn.Parameter类型

    A kind of Tensor that is to be considered a module parameter.
    这个类型是一种特殊的张量Tensor类型,该类型被认为是模块的参数.
    Parameters are Tensor subclasses, that have a very special 
    property when used with Module s - when they’re assigned as Module
    attributes they are automatically added to the list of its 
    parameters, and will appear e.g. in parameters() iterator. 
    Assigning a Tensor doesn’t have such effect. This is because one 
    might want to cache some temporary state, like last hidden state 
    of the RNN, in the model. If there was no such class as Parameter,
    these temporaries would get registered too.
    torch.nn.Parameter 是 Tensor 的子类, 当torch.nn.Parameter 用于模块
    中时具有一个特殊的特性,当它们被赋值给模块的属性时,它们会自动地被添加到
    模块的参数列表中,并且通过parameter()之类的迭代器来访问到它们.相反,如果
    仅仅将一个张量赋值给模型的属性,不会有这样的效果.这是因为有人可能想要缓存
    模型一些临时性状态,比如RNN的最后一个隐藏状态(last hidden state of 
    the RNN).如果没有Parameter这样的类型,那么这些临时数据也会被登记注册.
    
    

    Parameters  参数
            data (Tensor) – parameter tensor.
            data (Tensor张量类型) – 参数的张量.

            requires_grad (bool, optional)if the parameter requires
            gradient. See Excluding subgraphs from backward for more 
            details. Default: True
            requires_grad (布尔类型, 可选) – 如果参数需要求梯度的话就使用
            这个参数.在以下链接[Excluding subgraphs from backward(从backward中排除子图)]
            上可以查看更多详细信息.默认值时True.
            (https://pytorch.org/docs/1.2.0/notes/autograd.html#excluding-subgraphs)
            


解释:

它是Tensor的一个子类,它可以被作为module的参数,把它赋值给module的属性,
那么会自动被添加到module的参数,也就是出现在parameters()迭代器中.
如果不使用Parameter类型,仅仅使用Tensor,
把Tensor类赋值给module的属性,不会有这种效果,
不会出现在parameter()迭代器中.

实验:

import torch
import torch.nn as nn


class Model4CXQ(nn.Module):
    def __init__(self):
        super(Model4CXQ, self).__init__()
        # super().__init__()
        self.attribute4cxq = nn.Parameter(torch.tensor(20200910.0))
        self.attribute4lzq = nn.Parameter(torch.tensor(20200.0))
        # self.attribute4scc = nn.Parameter(torch.Tensor(2.0))  # TypeError: new(): data must be a sequence (got float)
        # self.attribute4pq = nn.Parameter(torch.tensor(2))  # RuntimeError: Only Tensors of floating point dtype can require gradients
        self.attribute4zh = nn.Parameter(torch.Tensor(2))
        # self.attribute4yzb = nn.Parameter(torch.tensor(912.0))
        self.attribute4yzb = (torch.tensor(912.0))
        self.attribute4gcx = (torch.tensor(3))
        self.attribute4ymw = (torch.Tensor(3))

    def forward(self, x):
        pass


if __name__ == "__main__":
    model = Model4CXQ()
    print()
    print("打印参数".center(50,'-'))
    for param in model.parameters():
        print(param)
    print()
    print("打印字典".center(50,'-'))
    for k, v in model.state_dict().items():
        print(k, v)

控制台输出:

Windows PowerShell
版权所有 (C) Microsoft Corporation。保留所有权利。

尝试新的跨平台 PowerShell https://aka.ms/pscore6

加载个人及系统配置文件用了 861 毫秒。
(base) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> conda activate ssd4pytorch1_2_0
(ssd4pytorch1_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq>  & 'D:\Anaconda3\envs\ssd4pytorch1_2_0\python.exe' 'c:\Users\chenxuqi\.vscode\extensions\ms-python.python-2020.12.424452561\pythonFiles\lib\python\debugpy\launcher' '56980' '--' 'c:\Users\chenxuqi\Desktop\News4cxq\test4cxq\test8.py'

-----------------------打印参数-----------------------
Parameter containing:
tensor(20200910., requires_grad=True)
Parameter containing:
tensor(20200., requires_grad=True)
Parameter containing:
tensor([1.1673e-42, 0.0000e+00], requires_grad=True)

-----------------------打印字典-----------------------
attribute4cxq tensor(20200910.)
attribute4lzq tensor(20200.)
attribute4zh tensor([1.1673e-42, 0.0000e+00])
(ssd4pytorch1_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq>
  • 3
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值