PyTorch: the differences between parameters(), children(), modules(), and named_*

nn.Module vs nn.functional

The former stores state such as weights; the latter is stateless and only performs the computation.
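A minimal sketch of the contrast (the random input and weights here are just for illustration): the module form owns its weight and bias, while the functional form expects them to be passed in explicitly.

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.rand(1, 3, 8, 8)

# nn.Module form: the layer creates and stores its own weight and bias.
conv = nn.Conv2d(3, 6, 3)
y1 = conv(x)

# nn.functional form: stateless computation; weight and bias are supplied by the caller.
weight = torch.rand(6, 3, 3, 3)
bias = torch.rand(6)
y2 = F.conv2d(x, weight, bias)

print(y1.shape, y2.shape)  # both torch.Size([1, 6, 6, 6])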

parameters()

Returns the module's trainable parameters.
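For example (a minimal sketch with a standalone Linear layer), the usual pattern is to hand parameters() straight to an optimizer:

import torch.nn as nn
import torch.optim as optim

net = nn.Linear(10, 2)
for p in net.parameters():
    print(p.shape, p.requires_grad)  # weight torch.Size([2, 10]) and bias torch.Size([2]), both True

optimizer = optim.SGD(net.parameters(), lr=0.01)  # the optimizer updates exactly these tensors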

nn.ModuleList vs. nn.ParameterList vs. nn.Sequential

import torch
import torch.nn as nn

layer_list = [nn.Conv2d(5, 5, 3), nn.BatchNorm2d(5), nn.Linear(5, 2)]

class myNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = layer_list  # plain Python list: these modules are NOT registered

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

net = myNet()
print(list(net.parameters()))  # Parameters of modules in the layer_list don't show up.

nn.ModuleList exists to wrap a plain Python list so that the contained modules are registered as submodules, and their parameters therefore show up in parameters(). nn.ParameterList plays the same role for a list of raw nn.Parameter objects.
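A minimal fix for the example above, assuming the same layer_list: wrap it in nn.ModuleList and the parameters are registered.

class myNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(layer_list)  # now each module is a registered submodule

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

net = myNet()
print(len(list(net.parameters())))  # 6: weight and bias for the conv, BN, and linear layers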

nn.Sequential behaves as follows:

class myNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Linear(10, 10),
        )

    def forward(self, x):
        return self.layers(x)  # runs the submodules in the declared order

x = torch.rand(10)
net = myNet()
print(net(x).shape)  # torch.Size([10])

As the example shows, nn.Sequential chains the given modules in the specified order into one complete module, whereas nn.ModuleList merely collects modules the way a list does.
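One consequence worth noting, as a small sketch: a Sequential is directly callable because it defines forward(), while a ModuleList is not.

seq = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
ml = nn.ModuleList([nn.Linear(4, 4), nn.ReLU()])

x = torch.rand(4)
print(seq(x).shape)  # torch.Size([4]); Sequential runs its children in order
# ml(x) would raise an error: ModuleList defines no forward() of its own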

modules() vs. children()

class myNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.convBN1 = nn.Sequential(nn.Conv2d(10, 10, 3), nn.BatchNorm2d(10))
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        pass

net = myNet()
print("Printing children\n------------------------------")
print(list(net.children()))
print("\n\nPrinting Modules\n------------------------------")
print(list(net.modules()))

The output is:

Printing children
------------------------------
[Sequential(
  (0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  (1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), Linear(in_features=10, out_features=2, bias=True)]


Printing Modules
------------------------------
[myNet(
  (convBN1): Sequential(
    (0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (linear): Linear(in_features=10, out_features=2, bias=True)
), Sequential(
  (0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  (1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
), Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1)), BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), Linear(in_features=10, out_features=2, bias=True)]

As you can see, children() returns only the direct children; a child may be a single operation such as Linear, or a container such as Sequential. modules() returns more detailed information: besides everything children() yields, it also recurses, returning the network itself first and then, for example, each submodule contained inside a Sequential.
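This recursion is what makes modules() the right tool for type-based weight initialisation; a minimal sketch over the net above (the init scheme is just an illustrative choice):

for m in net.modules():
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight)  # also reaches the conv nested in the Sequential
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)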


named_*

named_parameters: returns an iterator; each step yields a tuple containing the parameter's name together with the parameter itself.

In [27]: x = torch.nn.Linear(2,3)

In [28]: x_name_params = x.named_parameters()

In [29]: next(x_name_params)
Out[29]:
('weight', Parameter containing:
tensor([[-0.5262,  0.3480],
        [-0.6416, -0.1956],
        [ 0.5042,  0.6732]], requires_grad=True))

In [30]: next(x_name_params)
Out[30]:
('bias', Parameter containing:
tensor([ 0.0595, -0.0386,  0.0975], requires_grad=True))
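The names make it easy to act on a subset of parameters, for example freezing all biases; a minimal sketch continuing with the Linear layer x from the session above:

for name, p in x.named_parameters():
    if name.endswith('bias'):
        p.requires_grad = False  # exclude biases from gradient updates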

named_modules

This is essentially the modules() result from above, exposed as an iterator of (name, module) tuples; as before, each element is read with next(). Example:

In [46]: class myNet(nn.Module):
    ...:     def __init__(self):
    ...:         super().__init__()
    ...:         self.convBN1 = nn.Sequential(nn.Conv2d(10,10,3), nn.BatchNorm2d(10))
    ...:         self.linear = nn.Linear(10,2)
    ...:
    ...:     def forward(self, x):
    ...:         pass
    ...:

In [47]: net = myNet()

In [48]: net_named_modules = net.named_modules()

In [49]: next(net_named_modules)
Out[49]:
('', myNet(
  (convBN1): Sequential(
    (0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (linear): Linear(in_features=10, out_features=2, bias=True)
))

In [50]: next(net_named_modules)
Out[50]:
('convBN1', Sequential(
  (0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  (1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
))

In [51]: next(net_named_modules)
Out[51]: ('convBN1.0', Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1)))

In [52]: next(net_named_modules)
Out[52]:
('convBN1.1',
 BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))

In [53]: next(net_named_modules)
Out[53]: ('linear', Linear(in_features=10, out_features=2, bias=True))

In [54]: next(net_named_modules)
---------------------------------------------------------------------------
StopIteration                             Traceback (most recent call last)
----> 1 next(net_named_modules)

StopIteration:

named_children

Same idea as named_modules, except that, like children(), it yields only the direct children, each paired with its attribute name.
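A minimal sketch with the same net as above:

for name, child in net.named_children():
    print(name, type(child).__name__)
# convBN1 Sequential
# linear Linear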
