torch container && Table Containers

1
container直接对输入的Tensor进行操作,而Table Containers则是对输入的table进行操作。

2

1) 最常见的container是nn.Sequential,他的目的是将模块串联起来,是串联!

mlp = nn.Sequential()
mlp:add(nn.Linear(10, 25)) -- Linear module (10 inputs, 25 hidden units)
mlp:add(nn.Tanh())         -- apply hyperbolic tangent transfer function on each hidden units
mlp:add(nn.Linear(25, 1))  -- Linear module (25 inputs, 1 output)

2) 第二个常见的container是Parallel(input Dimension,outputDimension),他的意思是将输入沿着input Dimension切开,将他的第i个child应用在切开的第i份数据上,然后最后沿着outputDimension concat在一起

mlp = nn.Parallel(2,1);   -- Parallel container will associate a module to each slice of dimension 2
                           -- (column space), and concatenate the outputs over the 1st dimension.

mlp:add(nn.Linear(10,3)); -- Linear module (input 10, output 3), applied on 1st slice of dimension 2
mlp:add(nn.Linear(10,2))  -- Linear module (input 10, output 2), applied on 2nd slice of dimension 2

                                  -- After going through the Linear module the outputs are
                                  -- concatenated along the unique dimension, to form 1D Tensor
> mlp:forward(torch.randn(10,2)) -- of size 5.
-0.5300
-1.1015
 0.7764
 0.2819
-0.6026

将输入randn(10,2)沿着第二维切开,产生两个randn(10,1),然后将对应的(10,1)应用在对应的子模块上,也即第一个(10,1)应用在nn.Linear(10,3)上面,第二个(10,1)应用在nn.Linear(10,2)上面,最后产生一个大小为(3,1)的结果,一个大小为(2,1)的结果,然后沿着第一维串联在一起
3)concat之前说过了

container的性质:
1)container是从module继承而来,module有的性质它都有
2)get(index),获得container中index处的模块
3)size(),获得container中module的数量

对于nngraph构建的网络的self.model:listModules()和self.model:get(iii)的区别,nngraph是一个container,
self.model:get()是获得nngraph中的Node
self.model:listModules()则是获得整个model的所有元素,对于一个Node可以继续肢解,直到肢解成最小的单位container或者单一的节点,所以总体来说第二种方式会获得更多的肢解信息,因为他的目的是将网络肢解成一个个最小的模块,而get()函数相对于nngraph来讲则只是获得相应的nngraph Node,对于Node它不再继续分解,如果一个nn.Sequential包含了很多内部子模块,但是由于它是一个Node,所以将不会进行肢解。

        threshold_nodes, container_nodes = self.model:findModules('cudnn.SpatialConvolution')
        for i = 1,#threshold_nodes do
          print(threshold_nodes[i])
          print(container_nodes[i]) 
        end

self.model:findModules则会返回查找的module,并且找出对应的container,对于没有container作为父节点的,自己就可以看作是container了,这个函数是查找网络中所有的module,所以必须对网络进行完全的肢解才可以,与self.model:listModules()是对应的,都是container父类module类拥有的性质,它自己继承过来了

3 一般打印网络输出就是用

        for iii = 1,self.model:size(),1 do 
            print(iii,self.model:get(iii))
        end

找到对应模块的标号也即iii,打印相应内容即可

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值