参考链接: torch.nn.Parameter
原文及翻译:
Parameters
Parameters类型
class torch.nn.Parameter
torch.nn.Parameter类型
A kind of Tensor that is to be considered a module parameter.
这个类型是一种特殊的张量Tensor类型,该类型被认为是模块的参数.
Parameters are Tensor subclasses, that have a very special
property when used with Module s - when they’re assigned as Module
attributes they are automatically added to the list of its
parameters, and will appear e.g. in parameters() iterator.
Assigning a Tensor doesn’t have such effect. This is because one
might want to cache some temporary state, like last hidden state
of the RNN, in the model. If there was no such class as Parameter,
these temporaries would get registered too.
torch.nn.Parameter 是 Tensor 的子类, 当torch.nn.Parameter 用于模块
中时具有一个特殊的特性,当它们被赋值给模块的属性时,它们会自动地被添加到
模块的参数列表中,并且通过parameter()之类的迭代器来访问到它们.相反,如果
仅仅将一个张量赋值给模型的属性,不会有这样的效果.这是因为有人可能想要缓存
模型一些临时性状态,比如RNN的最后一个隐藏状态(last hidden state of
the RNN).如果没有Parameter这样的类型,那么这些临时数据也会被登记注册.
Parameters 参数
data (Tensor) – parameter tensor.
data (Tensor张量类型) – 参数的张量.
requires_grad (bool, optional) – if the parameter requires
gradient. See Excluding subgraphs from backward for more
details. Default: True
requires_grad (布尔类型, 可选) – 如果参数需要求梯度的话就使用
这个参数.在以下链接[Excluding subgraphs from backward(从backward中排除子图)]
上可以查看更多详细信息.默认值时True.
(https://pytorch.org/docs/1.2.0/notes/autograd.html#excluding-subgraphs)
解释:
它是Tensor的一个子类,它可以被作为module的参数,把它赋值给module的属性,
那么会自动被添加到module的参数,也就是出现在parameters()迭代器中.
如果不使用Parameter类型,仅仅使用Tensor,
把Tensor类赋值给module的属性,不会有这种效果,
不会出现在parameter()迭代器中.
实验:
import torch
import torch.nn as nn
class Model4CXQ(nn.Module):
def __init__(self):
super(Model4CXQ, self).__init__()
# super().__init__()
self.attribute4cxq = nn.Parameter(torch.tensor(20200910.0))
self.attribute4lzq = nn.Parameter(torch.tensor(20200.0))
# self.attribute4scc = nn.Parameter(torch.Tensor(2.0)) # TypeError: new(): data must be a sequence (got float)
# self.attribute4pq = nn.Parameter(torch.tensor(2)) # RuntimeError: Only Tensors of floating point dtype can require gradients
self.attribute4zh = nn.Parameter(torch.Tensor(2))
# self.attribute4yzb = nn.Parameter(torch.tensor(912.0))
self.attribute4yzb = (torch.tensor(912.0))
self.attribute4gcx = (torch.tensor(3))
self.attribute4ymw = (torch.Tensor(3))
def forward(self, x):
pass
if __name__ == "__main__":
model = Model4CXQ()
print()
print("打印参数".center(50,'-'))
for param in model.parameters():
print(param)
print()
print("打印字典".center(50,'-'))
for k, v in model.state_dict().items():
print(k, v)
控制台输出:
Windows PowerShell
版权所有 (C) Microsoft Corporation。保留所有权利。
尝试新的跨平台 PowerShell https://aka.ms/pscore6
加载个人及系统配置文件用了 861 毫秒。
(base) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> conda activate ssd4pytorch1_2_0
(ssd4pytorch1_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> & 'D:\Anaconda3\envs\ssd4pytorch1_2_0\python.exe' 'c:\Users\chenxuqi\.vscode\extensions\ms-python.python-2020.12.424452561\pythonFiles\lib\python\debugpy\launcher' '56980' '--' 'c:\Users\chenxuqi\Desktop\News4cxq\test4cxq\test8.py'
-----------------------打印参数-----------------------
Parameter containing:
tensor(20200910., requires_grad=True)
Parameter containing:
tensor(20200., requires_grad=True)
Parameter containing:
tensor([1.1673e-42, 0.0000e+00], requires_grad=True)
-----------------------打印字典-----------------------
attribute4cxq tensor(20200910.)
attribute4lzq tensor(20200.)
attribute4zh tensor([1.1673e-42, 0.0000e+00])
(ssd4pytorch1_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq>