Neural style之torch学习笔记4:学习神经网络包的用法(2)

转自:http://blog.csdn.net/hungryof/article/details/52027049

总说

上篇博客已经初步介绍了Module类。这里将更加仔细的介绍。并且还将介绍Container, Transfer Functions Layers和 Simple Layers模块。

Module

主要有4个函数。 
1. [output]forward(input) 
2. [gradInput] backward (input.gradOutput) 
以及在后面自动挡训练方式中重载的两个函数 
3. [output] updateOutput (input) 
4. [gradInput] updateGradInput (input, gradOutput) 
注意点: 
1. forward函数的input必须和backward的函数的input一致!否则梯度更新会有问题。 
2. Module的forward会调用updateOutput(input), 而backward会调用[gradInput] updateGradInput (input, gradOutput)和accGradParameters(input, gradOutput) 
3. 高级训练方式只要重载updateOutput和updateGradInput这两个函数,内部参数会自动改变。 
4. 对于第3点,需要更加深入的探讨

Container

复杂的神经网络可以用container类进行构建。 
container的子类主要有三个:Sequential, Parallel, Concat 。他重新实现了Module类的方法。此外还增加了很多方法。 
主要函数: 
1. add(module) 
2. get(index) 
3. size()   return the number of contained modules. 
4. remove(index) 
5. insert (module, [index] ) 注意这里的index是插入后,其排到的index

用法就是,一般是用一个Sequential,然后不断add(module),而module有simple layers和卷积层。在后面进行说明。

Transfer Functions Layers

就是激活函数,在第上一篇博客已经说明了,torch中讲神经网络看成是module(或是container)的组合。你可以加入层模块,或是激活函数的层模块,最后还可以加上criterion层模块。如此一来,整个网络就构建好了。 
这个没啥好说的。 
里面有挺多激活函数的. SoftMax, SoftMin, SoftPlus, LogSigmoid, LogSoftMax, Sigmoid, Tanh, ReLU, PReLU, ELU, LeakyReLU等等。 
这里就拿Tanh举例吧。

<code class="language-lua hljs  has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;">ii=torch.linspace(-<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">3</span>,<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">3</span>)
m=nn.Tanh()
oo=m:forward(ii)
go=torch.ones(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">100</span>)
gi=m:backward(ii,go)
gnuplot.plot({<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'f(x)'</span>,ii,oo,<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'+-'</span>},{<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'df/dx'</span>,ii,gi,<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'+-'</span>})
gnuplot.grid(<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">true</span>)</code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li></ul>

这里写图片描述

Simple Layers

简单层有很多,一些是提供仿射变换的,一些是进行Tensor method的。

  • 具有参数的modules有: 
    Linear 
    Add : 对输入增加一个偏置项 
    Mul 
    CMul 
    等等
  • 进行基本Tensor运算的 
    View 
    Transpose 
    等等
  • 进行数学运算的 
    Max, Min, Exp, Mean, Log, Abs, MM:matrix-matrix multiplication. 
    Normalize: normalize the input to have unit L-p norm
  • 其他 
    Identity 
    Dropout

下面调重要常见的几个

Linear

<code class="language-lua hljs  has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;"><span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span> = nn.Linear(inputDim, outputDim, [bias = <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">true</span>])</code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li></ul>

Linear就是全连接呗。

<code class="language-lua hljs  has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;"><span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span> = nn.Linear(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">10</span>,<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">5</span>)
mlp = nn.Sequential()
mlp:add(<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span>)
<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">print</span>(<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span>.weight)
<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">print</span>(<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span>.bias)
<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">print</span>(<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span>.gradWeight)
<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">print</span>(<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span>.gradBias)
x = torch.Tensor(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">10</span>) <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">-- 10 inputs</span>
y = <span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span>:forward(x)</code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li></ul>

运行后发现,只要初始化后网络,里面就有初始权值和偏置。但是gradBias和gradGradient不是很大就是很小,显然这是垃圾数据。现在有个问题,网络权值的初始化对整个网络的训练非常重要,torch是怎样自定义初始化权值的呢?我现在还没发现!先做个标记!

Dropout

<code class="language-lua hljs  has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;"><span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span> = nn.Dropout(p)</code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li></ul>

这个就是Dropout层。简单的说每一个神经元的输入将会以p的概率丢弃。这个方法是一个避免过拟合的正规化方法。可参见Improving neural networks by preventing co-adaptation of feature detectors

<code class="language-lua hljs  has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;"><span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span> = nn.Dropout()

> x = torch.Tensor{{<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span>, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">3</span>, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">4</span>}, {<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">5</span>, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">6</span>, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">7</span>, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">8</span>}}

> <span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span>:forward(x)
  <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span>   <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>   <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>   <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">8</span>
 <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">10</span>   <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>  <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">14</span>   <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>
[torch.DoubleTensor of dimension <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span>x4]

> <span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span>:forward(x)
  <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>   <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>   <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">6</span>   <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>
 <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">10</span>   <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>   <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>   <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>
[torch.DoubleTensor of dimension <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span>x4]</code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li><li style="box-sizing: border-box; padding: 0px 5px;">10</li><li style="box-sizing: border-box; padding: 0px 5px;">11</li><li style="box-sizing: border-box; padding: 0px 5px;">12</li><li style="box-sizing: border-box; padding: 0px 5px;">13</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li><li style="box-sizing: border-box; padding: 0px 5px;">10</li><li style="box-sizing: border-box; padding: 0px 5px;">11</li><li style="box-sizing: border-box; padding: 0px 5px;">12</li><li style="box-sizing: border-box; padding: 0px 5px;">13</li></ul>

可以看出,上面一些值被丢弃。

View

这个主要是改变网络输出的tensor的sizes,就是reshape一下。

<code class="language-lua hljs  has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;"><span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span> = nn.View(sizes)</code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li></ul>

如果是用view(-1)则可特别地用作minibatch的输入。

<code class="language-lua hljs  has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;">x = torch.Tensor(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">4</span>, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">4</span>)
<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">for</span> i = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">4</span> <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">do</span>
    <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">for</span> j = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">4</span> <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">do</span>
       x[i][j] = (i-<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>)*<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">4</span>+j
    <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">end</span>
 <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">end</span>
 <span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">print</span>(nn.View(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span>, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">8</span>):forward(x))
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">--[[
  1   2   3   4   5   6   7   8
  9  10  11  12  13  14  15  16
  ]]</span></code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li><li style="box-sizing: border-box; padding: 0px 5px;">10</li><li style="box-sizing: border-box; padding: 0px 5px;">11</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li><li style="box-sizing: border-box; padding: 0px 5px;">10</li><li style="box-sizing: border-box; padding: 0px 5px;">11</li></ul>

作为minibatch的输入!

<code class="language-lua hljs  has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;">> input = torch.Tensor(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span>, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">3</span>)
> minibatch = torch.Tensor(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">5</span>, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span>, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">3</span>)
> m = nn.View(-<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>):setNumInputDims(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span>)
> <span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">print</span>(#m:forward(minibatch))

 <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">5</span>
 <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">6</span>
[torch.LongStorage of size <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span>] <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">--每一行就是一个example的结果。</span></code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li></ul>

Normalize

<code class="language-lua hljs  has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;"><span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span> = nn.Normalize(p, [eps])</code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li></ul>

L-p范式进行归一化,eps默认是1e-10,防止输入全为0时除0的情况。

MM

<code class="language-lua hljs  has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;"><span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span> = nn.MM(transA,transB)</code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li></ul>

transA/transB 为true时,则对应的矩阵进行转置。 
如果是3维矩阵,则第一维被理解成batches的num,转置只会对后面的两维进行。

<code class="language-lua hljs  has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;"><span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span> = nn.MM(<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">true</span>, <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">false</span>)
A = torch.randn(b,n,m)
B = torch.randn(b,n,p)
c = <span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span>:forward({A, B})
<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">--[[
c是 b * m * p的
]]</span></code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li></ul>

简单层一个常见的Module是Add,放在下篇博客讲。因为会用简单层的Add模块讲手动挡训练演示。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值