import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.add_module("conv", nn.Conv2d(10, 20, 4))
self.add_module("conv1", nn.Conv2d(20 ,10, 4))
model = Model()
for module in model.modules():
print(module)
Model (
(conv): Conv2d(10, 20, kernel_size=(4, 4), stride=(1, 1))
(conv1): Conv2d(20, 10, kernel_size=(4, 4), stride=(1, 1))
)
Conv2d(10, 20, kernel_size=(4, 4), stride=(1, 1))
Conv2d(20, 10, kernel_size=(4, 4), stride=(1, 1))
可以看出,modules()返回的iterator不止包含子模块。这是和childern()的不同。NOTE:重复的模块只被返回一次(children()也是)。在下面的例子中submodule只会被返回一次。
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
submodule = nn.Conv2d(10, 20, 4)
self.add_module("conv", submodule)
self.add_module("conv1", submodule)
model = Model()
for module in model.modules():
print(module)
Model (
(conv): Conv2d(10, 20, kernel_size=(4, 4), stride=(1, 1)) , →
(conv1): Conv2d(10, 20, kernel_size=(4, 4), stride=(1, 1)) , →
)
Conv2d(10, 20, kernel_size=(4, 4), stride=(1, 1))
named_children()
返回包含模型当前子模块的迭代器,yield模块名字和模块本身。 例子:
for name, module in model.named_children():
if name in ['conv4', 'conv5']:
print(module)
named_modules(memo=None, prefix=”)
返回包含网络中所有模块的迭代器
, yield
ing
模块名和模块本身。
重复的模块只被返回一次
(children()
也是
)
。在下面的例子中
, submodule 只会被返回一次。
--
parameters(memo=None)
返回一个 包含模型所有参数 的迭代器。
一般用来当作optinizer的参数。
例子:
for param in model.parameters():
print(type(param.data), param.size())
<class 'torch.FloatTensor'> (20L,)
<class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
register_backward_hook(hook)
在
module
上注册一个
bachward hook
。
每次计算
module
的
inputs
的梯度的时候,这个
hook
会被调用。
hook
应该拥有
下面的
signature
。
hook(module, grad_input, grad_output) -> Variable or None
如果
module
有多个输入输出的话,那么
grad_input grad_output
将会是个
tuple
。
hook
不应该修改它的
arguments
,但是它可以选择性的返回关于输入的梯度,这
个返回的梯度在后续的计算中会替代
grad_input
。
这个函数返回一个 句柄
(handle)
。它有一个方法
handle.remove()
,可以用这个
方法将
hook
从
module
移除。
–
register_buffer(name, tensor)
给
module
添加一个
persistent buffer
。
persistent buffer
通常被用在这么一种情况:我们需要保存一个状态,但是这个
状态不能看作成为模型参数。例如:
, BatchNorm
’
s running_mean
不是一个
parameter,
但是它也是需要保存的状态之一。
Buffers
可以通过注册时候的
name
获取。
NOTE:
我们可以用
buffer
保存
moving average
例子:
self.register_buffer('running_mean',
torch.zeros(num_features))
self.running_mean