链接: torch.nn.Module.parameters(recurse=True)
链接: torch.nn.Module.state_dict(destination=None, prefix=’’, keep_vars=False)
文档翻译:
parameters(recurse=True)
Returns an iterator over module parameters.
返回一个迭代器,该迭代器可以遍历模块的参数.
This is typically passed to an optimizer.
通常用该方法将参数传递给优化器.
Parameters 参数
recurse (bool) – if True, then yields parameters of this
module and all submodules. Otherwise, yields only parameters
that are direct members of this module.
recurse (bool类型) - 如果rescue是True,那么yield迭代返回出这个模块
以及该模块的所有子模块的参数. 否则,如果是False,那么只yield迭代返回
出这个模块的直接成员.
Yields 迭代返回
Parameter – module parameter
Parameter类型 - 模块的参数
Example: 例子:
>>> for param in model.parameters():
>>> print(type(param.data), param.size())
<class 'torch.FloatTensor'> (20L,)
<class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
文档翻译:
state_dict(destination=None, prefix='', keep_vars=False)
Returns a dictionary containing a whole state of the module.
返回一个字典,该字典包含了模块的整个状态信息.
Both parameters and persistent buffers (e.g. running averages)
are included. Keys are corresponding parameter and buffer names.
参数和持续性缓冲(比如:running averages)都会包含在该字典内.字典的关键字
对应了参数和缓冲的名字.
Returns 返回
a dictionary containing a whole state of the module
包含了模块的整个状态信息的字典.
Return type 返回类型
dict 字典类型
Example: 举例:
>>> module.state_dict().keys()
['bias', 'weight']
代码实验举例:
import torch
import torch.nn as nn
class Model4CXQ(nn.Module):
def __init__(self):
super(Model4CXQ, self).__init__()
# super().__init__()
self.attribute4cxq = nn.Parameter(torch.tensor(20200910.0))
self.attribute4lzq = nn.Parameter(torch.tensor(20200.0))
# self.attribute4scc = nn.Parameter(torch.Tensor(2.0)) # TypeError: new(): data must be a sequence (got float)
# self.attribute4pq = nn.Parameter(torch.tensor(2)) # RuntimeError: Only Tensors of floating point dtype can require gradients
self.attribute4zh = nn.Parameter(torch.Tensor(2))
# self.attribute4yzb = nn.Parameter(torch.tensor(912.0))
self.attribute4yzb = (torch.tensor(912.0))
self.attribute4gcx = (torch.tensor(3))
self.attribute4ymw = (torch.Tensor(3))
def forward(self, x):
pass
if __name__ == "__main__":
model = Model4CXQ()
print()
print("打印参数".center(50,'-'))
for param in model.parameters():
print(param)
print()
print("打印字典".center(50,'-'))
for k, v in model.state_dict().items():
print(k, v)
print()
print("增加属性".center(50,'-'))
attribute4cjhT = torch.Tensor(3)
attribute4cjhP = nn.Parameter(torch.Tensor(2))
model.attribute4cjhT = attribute4cjhT
model.attribute4cjhP = attribute4cjhP
print("打印字典".center(50,'-'))
for k, v in model.state_dict().items():
print(k, v)
print()
print("注册属性".center(50,'-'))
attribute4cjhRT = torch.Tensor(3)
attribute4cjhRP = nn.Parameter(torch.Tensor(2))
# model.register_parameter('attribute4cjhRT', attribute4cjhRT)
# 上一行代码报错,注册的应该是Parameter类型,而不是FloatTensor类型
# 报错信息:
# TypeError: cannot assign 'torch.FloatTensor' object to parameter 'attribute4cjhRT' (torch.nn.Parameter or None required)
model.register_parameter('登记注册attribute4cjhRP属性', attribute4cjhRP)
print("打印字典".center(50,'-'))
for k, v in model.state_dict().items():
print(k, v)
控制台输出:
Windows PowerShell
版权所有 (C) Microsoft Corporation。保留所有权利。
尝试新的跨平台 PowerShell https://aka.ms/pscore6
加载个人及系统配置文件用了 892 毫秒。
(base) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> conda activate ssd4pytorch1_2_0
(ssd4pytorch1_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq> & 'D:\Anaconda3\envs\ssd4pytorch1_2_0\python.exe' 'c:\Users\chenxuqi\.vscode\extensions\ms-python.python-2020.12.424452561\pythonFiles\lib\python\debugpy\launcher' '51788' '--' 'c:\Users\chenxuqi\Desktop\News4cxq\test4cxq\test22.py'
-----------------------打印参数-----------------------
Parameter containing:
tensor(20200910., requires_grad=True)
Parameter containing:
tensor(20200., requires_grad=True)
Parameter containing:
tensor([1.4038e+13, 8.6419e+00], requires_grad=True)
-----------------------打印字典-----------------------
attribute4cxq tensor(20200910.)
attribute4lzq tensor(20200.)
attribute4zh tensor([1.4038e+13, 8.6419e+00])
-----------------------增加属性-----------------------
-----------------------打印字典-----------------------
attribute4cxq tensor(20200910.)
attribute4lzq tensor(20200.)
attribute4zh tensor([1.4038e+13, 8.6419e+00])
attribute4cjhP tensor([0., 0.])
-----------------------注册属性-----------------------
-----------------------打印字典-----------------------
attribute4cxq tensor(20200910.)
attribute4lzq tensor(20200.)
attribute4zh tensor([1.4038e+13, 8.6419e+00])
attribute4cjhP tensor([0., 0.])
登记注册attribute4cjhRP属性 tensor([0., 0.])
(ssd4pytorch1_2_0) PS C:\Users\chenxuqi\Desktop\News4cxq\test4cxq>