项目实训(二十)torch的常用模块学习

一、nn模块

nn.Identity()

这个函数建立一个输入模块,什么都不做,通常用在神经网络的输入层。用法如下:

mlp = nn.Identity()

print(mlp:forward(torch.ones(5, 2)))

这个可以用在残差学习中。

如果输入需要有多个保存,也可以使用多个nn.Identity()

mlp = nn.Identity()

nlp = nn.Identity()

print(mlp:forward(torch.ones(5, 2)))

多个输入可以在神经网络搭建中起到很好的作用,相当于一个容器,把输入都保留下来了。

可以看一下LSTM中的例子,因为LSTM是循环网络,需要保存上一次的信息,nn.Identity()能够很好的保留信息。

local inputs = {}

table.insert(inputs, nn.Identity()()) -- network input

table.insert(inputs, nn.Identity()()) -- c at time t-1

table.insert(inputs, nn.Identity()()) -- h at time t-1

local input = inputs[1]

local prev_c = inputs[2]

local prev_h = inputs[3]
 
th>LSTM = require'LSTM.lua' [0.0224s]

th> layer = LSTM.create(3, 2)

[0.0019s]

th> layer:forward({torch.randn(1,3), torch.randn(1,2), torch.randn(1,2)})

{

1 : DoubleTensor - size: 1x2

2 : DoubleTensor - size: 1x2

}

[0.0005s]

nn.Squeeze()

可以把输入中的一维的那一层去除。可以直接来看一下官网上的例子:


x=torch.rand(2,1,2,1,2)

> x

(1,1,1,.,.) =

0.6020 0.8897


(2,1,1,.,.) =

0.4713 0.2645


(1,1,2,.,.) =

0.4441 0.9792


(2,1,2,.,.) =

0.5467 0.8648

[torch.DoubleTensor of dimension 2x1x2x1x2]

 

其具体形状是这样的:

 
+-------------------------------+

| +---------------------------+ |

| | +-----------------------+ | |

| | | 0.6020 0.8897 | | |

| | +-----------------------+ | |

| | +-----------------------+ | |

| | | 0.4441 0.9792 | | |

| | +-----------------------+ | |

| +---------------------------+ |

| |

| +---------------------------+ |

| | +-----------------------+ | |

| | | 0.4713 0.2645 | | |

| | +-----------------------+ | |

| | +-----------------------+ | |

| | | 0.5467 0.8648 | | |

| | +-----------------------+ | |

| +---------------------------+ |

+-------------------------------+

 

进行nn.squeeze()操作


> torch.squeeze(x)

(1,.,.) =

0.6020 0.8897

0.4441 0.9792


(2,.,.) =

0.4713 0.2645

0.5467 0.8648

[torch.DoubleTensor of dimension 2x2x2]

 
+-------------------------------+

| 0.6020 0.8897 |

| 0.4441 0.9792 |

+-------------------------------+

+-------------------------------+

| 0.4713 0.2645 |

| 0.5467 0.8648 |

+-------------------------------+

nn.JoinTable()

这个相当于tensorflow的concat操作,但是个人觉得没有concat的操作好用。整体来说torch的代码都没有tensorflow简洁,但是效率比较高。

module = JoinTable(dimension, nInputDims)
 
+----------+ +-----------+

| {input1, +-------------> output[1] |

| | +-----------+-+

| input2, +-----------> output[2] |

| | +-----------+-+

| input3} +---------> output[3] |

+----------+ +-----------+

例子如下:

 
x = torch.randn(5, 1)

y = torch.randn(5, 1)

z = torch.randn(2, 1)



print(nn.JoinTable(1):forward{x, y})

print(nn.JoinTable(2):forward{x, y})

print(nn.JoinTable(1):forward{x, z})



>1.3965

0.5146

-1.5244

-0.9540

0.4256

0.1575

0.4491

0.6580

0.1784

-1.7362

[torch.DoubleTensor of dimension 10x1]



1.3965 0.1575

0.5146 0.4491

-1.5244 0.6580

-0.9540 0.1784

0.4256 -1.7362

[torch.DoubleTensor of dimension 5x2]



1.3965

0.5146

-1.5244

-0.9540

0.4256

-1.2660

1.0869

[torch.Tensor of dimension 7x1]

nn.gModel()

nngraph(nn) 是一个基于有向无环图的模块,所有的节点建立完后,需要使用nn.gModel()组成一个图。

module=nn.gModule(input,output)

这里的input 和output既可以是元素,也可以是列表。这个函数会生成一个从input到output的图。其中此前的每一个模块后面加上该模块输入,成为这个图中的节点。 
给出一个简单的例子:

 
x1 = nn.Identity()()

x2 = nn.Identity()()

a = nn.CAddTable()({x1, x2})

m = nn.gModule({x1, x2}, {a})

图示下:

 
_|__ __|__

| | | |

|____| |____|

| x1 | x2

\ /

\z /

_\ /_

| |

|____|

|a

nn.SpatialConvolution()

 
module = nn.SpatialConvolution(nInputPlane, nOutputPlane, kW, kH, [dW], [dH], [padW], [padH])

or cudnn.SpatialConvolution(nInputPlane, nOutputPlane, width, height, [dW = 1], [dH = 1], [padW = 0], [padH = 0],[groups=1])
  • nInputPlane: The number of expected input planes in the image given into forward().
  • nOutputPlane: The number of output planes the convolution layer will produce.
  • kW: The kernel width of the convolution
  • kH: The kernel height of the convolution
  • dW: The step of the convolution in the width dimension. Default is 1.
  • dH: The step of the convolution in the height dimension. Default is 1.
  • padW: The additional zeros added per width to the input planes. Default is 0, a good number is (kW-1)/2.
  • padH: The additional zeros added per height to the input planes. Default is padW, a good number is (kH-1)/2.
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值