0阶张量叫标量(scarlar);1阶张量叫向量(vector);2阶张量叫矩阵(matrix)
本文主要内容:如何用python中的theano包实现最基础的分类器–LR(Logistic Regression)。
一、模型
由概率论知识总结出模型,二分类用公式(1),多分类用公式(2);为了求解公式(2)中的最优参数(
W
和
b
),推导出目标函数公式(3)
逻辑回归是一种线性分类器。它的参数包括权重矩阵
W
和偏置
b
。通过将输入向量映射到一组超平面进行多分类(一个超平面可以分两类,多个超平面可以进行多分类)。输入向量到超平面的距离作为输入样本属于对应类别的概率。
输入向量
x
属于第
Y=i
类的概率模型表示如下:
其中, W 和 b 是参数, P(Y=i|x) 是条件概率,意思是在变量 x 的条件下, Y=i 的概率。 举例: P(Y=0|x) 表示输入的样本x被识别为数字0的概率。 这个并不难理解,只要学过概率论的话,不是问题。
softmaxi(Wx+b) 可以理解为 x 属于 i 的概率,具体含义看下边内容。这个表达式更清楚地说明了 x 的运算过程,即: x 与 W 点乘,再与 b 矩阵/向量 相加,然后把结果传入 softmax 分类器,得到分类结果为 i 。那么, softmax 是如何工作的呢(具体含义是什么呢)?就是公式(1)中最终的结果表达式。下面分析这个表达式。
eWix+bi 可以理解为表示样本 x 属于第 i 类的概率。那么 ∑Nj=0eWjx+bj 显然是表示 x 属于每一类的概率之和。为什么两者要做除法呢?答案是: 归一化 。这样,最终 P(Y=i) 的累加和就是1。
首先看看 eWix+bi 是怎么来的。因为逻辑回归的假设函数是 sigmoid 函数,即 h(x)=11+e−(Wx+b) ,而 softmax 是多分类,所以其假设函数就是
从公式(2)很容易看出是如何归一化的了。找到这个列向量中最大值的下标 k ,就代表样本x属于第 k 类的概率最大。因此,模型的最终预测公式为:
argmax 函数的作用是返回矩阵中每一行或每一列最大数的下标。在这里,矩阵 Pm∗n 的元素构成是每一行代表一个样本(即共有 m 个样本),每一列代表当前样本属于该列下标(0-n)类别的概率。 举例:第0行第0列元素是第0行中最大数下标。则表示:第0行所代表的样本被分类为0的概率是最大的。 那么以此类推,最终的结果就表示每一个样本所分 类别 的最大概率。
至此,概率模型已经分析完毕。下边就是如何用python实现计算输入样本的概率值了。代码如下(代码中会涉及到Theano的使用,因此如果有解释不到的函数,请参考我的Theano官方教程翻译及学习笔记系列博文,以上博文因正在编写,临时可参考Theano实现卷积运算代码详解):
<code class="language-python hljs has-numbering" style="display: block; padding: 0px; background: transparent; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal;"><span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 初始化权值W,shared是用来为了GPU运算的。</span> self.W = theano.shared( value=numpy.zeros( (n_in, n_out), dtype=theano.config.floatX), name=<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'W'</span>, borrow=<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">True</span> ) <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 初始化偏置b</span> self.b = theano.shared( value=numpy.zeros( (n_out,), dtype=theano.config.floatX), name=<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'b'</span>, borrow=<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">True</span> ) <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 下面就是概率公式(1)的代码实现,其中dot是点乘运算;input是输入向量x;</span> self.p_y_given_x = T.nnet.softmax(T.dot(input, self.W) + self.b) <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 对输入样本的预测,是公式(3)的代码实现,axis表示函数argmax要按照行返回最大数</span> self.y_pred = T.argmax(self.p_y_given_x, axis=<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>)</code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; background-color: rgb(238, 238, 238); top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right: 1px solid rgb(221, 221, 221); list-style: none; text-align: right;"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li><li style="box-sizing: border-box; padding: 0px 5px;">10</li><li style="box-sizing: border-box; padding: 0px 5px;">11</li><li style="box-sizing: border-box; padding: 0px 5px;">12</li><li style="box-sizing: border-box; padding: 0px 5px;">13</li><li style="box-sizing: border-box; padding: 0px 5px;">14</li><li style="box-sizing: border-box; padding: 0px 5px;">15</li><li style="box-sizing: border-box; padding: 0px 5px;">16</li><li style="box-sizing: border-box; padding: 0px 5px;">17</li><li style="box-sizing: border-box; padding: 0px 5px;">18</li><li style="box-sizing: border-box; padding: 0px 5px;">19</li><li style="box-sizing: border-box; padding: 0px 5px;">20</li><li style="box-sizing: border-box; padding: 0px 5px;">21</li><li style="box-sizing: border-box; padding: 0px 5px;">22</li><li style="box-sizing: border-box; padding: 0px 5px;">23</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; background-color: rgb(238, 238, 238); top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right: 1px solid rgb(221, 221, 221); list-style: none; text-align: right;"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li><li style="box-sizing: border-box; padding: 0px 5px;">10</li><li style="box-sizing: border-box; padding: 0px 5px;">11</li><li style="box-sizing: border-box; padding: 0px 5px;">12</li><li style="box-sizing: border-box; padding: 0px 5px;">13</li><li style="box-sizing: border-box; padding: 0px 5px;">14</li><li style="box-sizing: border-box; padding: 0px 5px;">15</li><li style="box-sizing: border-box; padding: 0px 5px;">16</li><li style="box-sizing: border-box; padding: 0px 5px;">17</li><li style="box-sizing: border-box; padding: 0px 5px;">18</li><li style="box-sizing: border-box; padding: 0px 5px;">19</li><li style="box-sizing: border-box; padding: 0px 5px;">20</li><li style="box-sizing: border-box; padding: 0px 5px;">21</li><li style="box-sizing: border-box; padding: 0px 5px;">22</li><li style="box-sizing: border-box; padding: 0px 5px;">23</li></ul>
由于水平有限,思路可能会有一些乱,但是对照公式(1-3)慢慢整理一下还是可以理解的。现在所写的代码只是实现了公式的计算,相当于定义了一个函数,输入样本矩阵x,输出预测矩阵y。但是,如果能预测的准呢?下面且看第二部分。
二、定义代价函数
要最优化公式(3)的参数,那么根据经验编写出代价函数,即公式(4)
由机器学习的知识知道,要想使得模型预测结果最佳,即
W
和
b
取得最佳参数,那么就要根据假设函数定义一个代价函数,当代价函数最小化时,预测结果最优。(为什么要根据假设函数?因为模型公式就是基于假设函数编写的)
在多分类的逻辑回归中,经常用负对数似然函数作为代价函数,记为A。最小化函数A就等价于最大化A中的似然函数。似然函数
L
和代价函数
ℓ
定义如下:
D 是数据集(输入样本集); |D| 表示样本总数; θ 是模型参数(由 W 和 b 构成);
公式 L 表示,先对每一个输入样本进行公式(1)操作,然后对结果取log对数,最后将所有样本的概率对数求和。
公式 ℓ 表示取 L 的负数。
那么,怎么最小化那一堆的非线性函数呢? 梯度下降法 。梯度下降法是到目前为止最简单的用来最小化任意非线性函数的方法。因此,这里也同样采取该方法,不过是经过改进的,即 批量随机梯度下降法(MSGD-stochastic gradient method with mini-batches) 听起来很炫,其实很简单。梯度下降是更新整体样本;随机梯度下降是更新一个样本;而批量随机梯度则是介于两者之间,更新一部分样本。具体解释 看这里 。
现在基本介绍完了代价函数了,下面来看一下代价函数的代码,注意是传入一块(minibatch)样本数据,而不是一个或整个样本。
<code class="language-python hljs has-numbering" style="display: block; padding: 0px; background: transparent; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal;"><span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># y.shape返回y的行数和列数,则y.shape[0]返回y的行数,即样本的总个数,因为一行是一个样本。</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># T.arange(n),则是产生一组包含[0,1,...,n-1]的向量。</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># T.log(x),则是对x求对数。记为LP</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># LP[T.arange(y.shape[0]),y]是一组向量,其元素是[ LP[0,y[0]], LP[1,y[1]], </span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># LP[2,y[2]], ...,LP[n-1,y[n-1]] ]</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># T.mean(x),则是求向量x中元素的均值。</span> <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">return</span> -T.mean(T.log(self.p_y_given_x)[T.arange(y.shape[<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>]), y])</code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; background-color: rgb(238, 238, 238); top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right: 1px solid rgb(221, 221, 221); list-style: none; text-align: right;"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; background-color: rgb(238, 238, 238); top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right: 1px solid rgb(221, 221, 221); list-style: none; text-align: right;"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li></ul>
对照公式(4)看一下,
P(Y=yi|xi)
就是传入的self.p_y_given_x;
log(P)
就是T.log(p);
∑
求和就是代码中的 LP[T.arange(y.shape[0]),y];
ℓ=−L
就是return中的取反。
还有一点就是公式中没有求均值,而代码中加入了T.mean的均值运算。这是因为公式(4)还欠缺一部分就是除以
|D|
,因此其完整的公式为:
为什么要用均值,若不是和值呢?这里是因为用的批量随机梯度下降法来最小化代价函数,不同的样本输入块可能对学习速率产生不同的影响。因此用均值是为了降低学习速率对输入样本块的依赖性。
三、创建LR的python类
1. 创建类。
<code class="language-python hljs has-numbering" style="display: block; padding: 0px; background: transparent; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal;"><span class="hljs-class" style="box-sizing: border-box;"><span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">class</span> <span class="hljs-title" style="box-sizing: border-box; color: rgb(102, 0, 102);">LR</span><span class="hljs-params" style="color: rgb(102, 0, 102); box-sizing: border-box;">(object)</span>:</span> <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">""" 逻辑归回的实现是基于博文中给出的公式,需要预先设定好参数W和b。最小化方法用的批量随机梯度下降法MSGD。 因此传入数据是一块一块的。 """</span> <span class="hljs-function" style="box-sizing: border-box;"><span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">def</span> <span class="hljs-title" style="box-sizing: border-box;">__init__</span><span class="hljs-params" style="color: rgb(102, 0, 102); box-sizing: border-box;">(self, input, n_in, n_out)</span>:</span> <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">""" 初始化函数!此类实例化时调用该函数 按照Python定义类的格式给出如下定义,需要传入的参数分别为: input的类型为 TensorType,类似于形参,起象征性的作用,并不包含真实的数据; input传入值为 minibatch样本数据,该数据是一个m*n的矩阵。m表示此minibatch块共有m个样本;n表示每一个样本的实际数据。 在mnist实验中,n=784=28*28,因为每一张图片是28*28像素的。 n_in 的类型为 int; n_in 传入值为 每个输入样本的单元数(应该是图片的高*宽(28*28=784),但是在我们的实验数据中, 已经把图片数据矩阵存储为了行向量(784*1),因此这个地方传入的就是数据域中的data列的长度, 即n_in=784,具体的样本数据是传入input里面) n_out的类型为 int n_out传入值为 输出结果的类别数,就是数据域中的标签的范围。此处就是0-9共10个数字。所以n_out=10。就是10分类。 """</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 初始化权值矩阵</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># numpy.zeros((m,n),dtype='float32') 是产生一组 m行n列的全0矩阵,每个矩阵元素存储为float32类型。</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># shared()函数是将生成的矩阵封装为shared类型,该类型可以用于GPU加速运算,没有其他用途。</span> self.W = theano.shared( value = numpy.zeros( (n_in, n_out), dtype = <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'float32'</span> ), name = <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'W'</span>, borrow = <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">True</span> ) <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 初始化偏置值</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># b是一个向量,长度为n_out,就是每一种分类都有一个偏置值</span> self.b = theano.shared( value = numpy.zeros( (n_out,), dtype = <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'float32'</span> ), name = <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'b'</span>, borrow = <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">True</span> ) <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 计算公式(1),具体解释见博文 http://blog.csdn.net/niuwei22007/article/details/47705081</span> self.p_y_given_x = T.nnet.softmax(T.dot(input, self.W) + self.b) <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 计算公式(3)</span> self.y_pred = T.argmax(self.p_y_given_x, aixs = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>) <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 组织模型用到的参数,即把W和b组装成list,便于在类外引用。</span> self.param = [self.W, self.b] <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 记录模型的具体输入数据,便于在类外引用</span> self.input = input <span class="hljs-function" style="box-sizing: border-box;"><span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">def</span> <span class="hljs-title" style="box-sizing: border-box;">negative_log_likelihood</span><span class="hljs-params" style="color: rgb(102, 0, 102); box-sizing: border-box;">(self, y)</span>:</span> <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">""" 负对数似然函数,即代价函数。 需要传入的参数为: y 的类型为 TensorType,类似于形参,起象征性的作用,并不包含真实的数据; y 传入值为 input对应的标签向量,如果input的样本数为m,则input的行数就是m,那么y就是一个m行的列向量。 """</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 计算完整的公式(4),具体解释见博文 http://blog.csdn.net/niuwei22007/article/details/47705081 </span> <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">return</span> -T.mean(T.log(self.p_y_given_x)[T.arange(y.shape[<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span>]), y]) <span class="hljs-function" style="box-sizing: border-box;"><span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">def</span> <span class="hljs-title" style="box-sizing: border-box;">errors</span><span class="hljs-params" style="color: rgb(102, 0, 102); box-sizing: border-box;">(self, y)</span>:</span> <span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">""" 误差计算函数。传入的参数参考negative_log_likelihood. 其作用就是统计预测正确的样本数占本批次总样本数的比例。 """</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 检查 传入正确标签向量y和前面做出的预测向量y_pred是否是具有相同的维度。如果不相同怎么去判断某个样本预测的对还是不对?</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># y.ndim返回y的维数</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># raise是抛出异常</span> <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">if</span> y.ndim != self.y_pred.ndim: <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">raise</span> TypeError(<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">"y doesn't have the same shape as self.y_pred"</span>) <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 继续检查y是否是有效数据。依据就是本实验中正确标签数据的存储类型是int</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 如果数据有效,则计算:</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># T.neq(y1, y2)是计算y1与y2对应元素是否相同,如果相同便是0,否则是1。</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 举例:如果y1=[1,2,3,4,5,6,7,8,9,0] y2=[1,1,3,3,5,6,7,8,9,0]</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 则,err = T.neq(y1,y2) = [0,1,0,1,0,0,0,0,0,0],其中有3个1,即3个元素不同</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># T.mean()的作用就是求均值。那么T.mean(err) = (0+1+0+1+0+0+0+0+0+0)/10 = 0.3,即误差率为30%</span> <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">if</span> y.dtype.startswith(<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'int'</span>): <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">return</span> T.mean(T.neq(self.y_pred, y)) <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">else</span>: <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">raise</span> NotImplementedError()</code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; background-color: rgb(238, 238, 238); top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right: 1px solid rgb(221, 221, 221); list-style: none; text-align: right;"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li><li style="box-sizing: border-box; padding: 0px 5px;">10</li><li style="box-sizing: border-box; padding: 0px 5px;">11</li><li style="box-sizing: border-box; padding: 0px 5px;">12</li><li style="box-sizing: border-box; padding: 0px 5px;">13</li><li style="box-sizing: border-box; padding: 0px 5px;">14</li><li style="box-sizing: border-box; padding: 0px 5px;">15</li><li style="box-sizing: border-box; padding: 0px 5px;">16</li><li style="box-sizing: border-box; padding: 0px 5px;">17</li><li style="box-sizing: border-box; padding: 0px 5px;">18</li><li style="box-sizing: border-box; padding: 0px 5px;">19</li><li style="box-sizing: border-box; padding: 0px 5px;">20</li><li style="box-sizing: border-box; padding: 0px 5px;">21</li><li style="box-sizing: border-box; padding: 0px 5px;">22</li><li style="box-sizing: border-box; padding: 0px 5px;">23</li><li style="box-sizing: border-box; padding: 0px 5px;">24</li><li style="box-sizing: border-box; padding: 0px 5px;">25</li><li style="box-sizing: border-box; padding: 0px 5px;">26</li><li style="box-sizing: border-box; padding: 0px 5px;">27</li><li style="box-sizing: border-box; padding: 0px 5px;">28</li><li style="box-sizing: border-box; padding: 0px 5px;">29</li><li style="box-sizing: border-box; padding: 0px 5px;">30</li><li style="box-sizing: border-box; padding: 0px 5px;">31</li><li style="box-sizing: border-box; padding: 0px 5px;">32</li><li style="box-sizing: border-box; padding: 0px 5px;">33</li><li style="box-sizing: border-box; padding: 0px 5px;">34</li><li style="box-sizing: border-box; padding: 0px 5px;">35</li><li style="box-sizing: border-box; padding: 0px 5px;">36</li><li style="box-sizing: border-box; padding: 0px 5px;">37</li><li style="box-sizing: border-box; padding: 0px 5px;">38</li><li style="box-sizing: border-box; padding: 0px 5px;">39</li><li style="box-sizing: border-box; padding: 0px 5px;">40</li><li style="box-sizing: border-box; padding: 0px 5px;">41</li><li style="box-sizing: border-box; padding: 0px 5px;">42</li><li style="box-sizing: border-box; padding: 0px 5px;">43</li><li style="box-sizing: border-box; padding: 0px 5px;">44</li><li style="box-sizing: border-box; padding: 0px 5px;">45</li><li style="box-sizing: border-box; padding: 0px 5px;">46</li><li style="box-sizing: border-box; padding: 0px 5px;">47</li><li style="box-sizing: border-box; padding: 0px 5px;">48</li><li style="box-sizing: border-box; padding: 0px 5px;">49</li><li style="box-sizing: border-box; padding: 0px 5px;">50</li><li style="box-sizing: border-box; padding: 0px 5px;">51</li><li style="box-sizing: border-box; padding: 0px 5px;">52</li><li style="box-sizing: border-box; padding: 0px 5px;">53</li><li style="box-sizing: border-box; padding: 0px 5px;">54</li><li style="box-sizing: border-box; padding: 0px 5px;">55</li><li style="box-sizing: border-box; padding: 0px 5px;">56</li><li style="box-sizing: border-box; padding: 0px 5px;">57</li><li style="box-sizing: border-box; padding: 0px 5px;">58</li><li style="box-sizing: border-box; padding: 0px 5px;">59</li><li style="box-sizing: border-box; padding: 0px 5px;">60</li><li style="box-sizing: border-box; padding: 0px 5px;">61</li><li style="box-sizing: border-box; padding: 0px 5px;">62</li><li style="box-sizing: border-box; padding: 0px 5px;">63</li><li style="box-sizing: border-box; padding: 0px 5px;">64</li><li style="box-sizing: border-box; padding: 0px 5px;">65</li><li style="box-sizing: border-box; padding: 0px 5px;">66</li><li style="box-sizing: border-box; padding: 0px 5px;">67</li><li style="box-sizing: border-box; padding: 0px 5px;">68</li><li style="box-sizing: border-box; padding: 0px 5px;">69</li><li style="box-sizing: border-box; padding: 0px 5px;">70</li><li style="box-sizing: border-box; padding: 0px 5px;">71</li><li style="box-sizing: border-box; padding: 0px 5px;">72</li><li style="box-sizing: border-box; padding: 0px 5px;">73</li><li style="box-sizing: border-box; padding: 0px 5px;">74</li><li style="box-sizing: border-box; padding: 0px 5px;">75</li><li style="box-sizing: border-box; padding: 0px 5px;">76</li><li style="box-sizing: border-box; padding: 0px 5px;">77</li><li style="box-sizing: border-box; padding: 0px 5px;">78</li><li style="box-sizing: border-box; padding: 0px 5px;">79</li><li style="box-sizing: border-box; padding: 0px 5px;">80</li><li style="box-sizing: border-box; padding: 0px 5px;">81</li><li style="box-sizing: border-box; padding: 0px 5px;">82</li><li style="box-sizing: border-box; padding: 0px 5px;">83</li><li style="box-sizing: border-box; padding: 0px 5px;">84</li><li style="box-sizing: border-box; padding: 0px 5px;">85</li><li style="box-sizing: border-box; padding: 0px 5px;">86</li><li style="box-sizing: border-box; padding: 0px 5px;">87</li><li style="box-sizing: border-box; padding: 0px 5px;">88</li><li style="box-sizing: border-box; padding: 0px 5px;">89</li><li style="box-sizing: border-box; padding: 0px 5px;">90</li><li style="box-sizing: border-box; padding: 0px 5px;">91</li><li style="box-sizing: border-box; padding: 0px 5px;">92</li><li style="box-sizing: border-box; padding: 0px 5px;">93</li><li style="box-sizing: border-box; padding: 0px 5px;">94</li><li style="box-sizing: border-box; padding: 0px 5px;">95</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; background-color: rgb(238, 238, 238); top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right: 1px solid rgb(221, 221, 221); list-style: none; text-align: right;"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li><li style="box-sizing: border-box; padding: 0px 5px;">10</li><li style="box-sizing: border-box; padding: 0px 5px;">11</li><li style="box-sizing: border-box; padding: 0px 5px;">12</li><li style="box-sizing: border-box; padding: 0px 5px;">13</li><li style="box-sizing: border-box; padding: 0px 5px;">14</li><li style="box-sizing: border-box; padding: 0px 5px;">15</li><li style="box-sizing: border-box; padding: 0px 5px;">16</li><li style="box-sizing: border-box; padding: 0px 5px;">17</li><li style="box-sizing: border-box; padding: 0px 5px;">18</li><li style="box-sizing: border-box; padding: 0px 5px;">19</li><li style="box-sizing: border-box; padding: 0px 5px;">20</li><li style="box-sizing: border-box; padding: 0px 5px;">21</li><li style="box-sizing: border-box; padding: 0px 5px;">22</li><li style="box-sizing: border-box; padding: 0px 5px;">23</li><li style="box-sizing: border-box; padding: 0px 5px;">24</li><li style="box-sizing: border-box; padding: 0px 5px;">25</li><li style="box-sizing: border-box; padding: 0px 5px;">26</li><li style="box-sizing: border-box; padding: 0px 5px;">27</li><li style="box-sizing: border-box; padding: 0px 5px;">28</li><li style="box-sizing: border-box; padding: 0px 5px;">29</li><li style="box-sizing: border-box; padding: 0px 5px;">30</li><li style="box-sizing: border-box; padding: 0px 5px;">31</li><li style="box-sizing: border-box; padding: 0px 5px;">32</li><li style="box-sizing: border-box; padding: 0px 5px;">33</li><li style="box-sizing: border-box; padding: 0px 5px;">34</li><li style="box-sizing: border-box; padding: 0px 5px;">35</li><li style="box-sizing: border-box; padding: 0px 5px;">36</li><li style="box-sizing: border-box; padding: 0px 5px;">37</li><li style="box-sizing: border-box; padding: 0px 5px;">38</li><li style="box-sizing: border-box; padding: 0px 5px;">39</li><li style="box-sizing: border-box; padding: 0px 5px;">40</li><li style="box-sizing: border-box; padding: 0px 5px;">41</li><li style="box-sizing: border-box; padding: 0px 5px;">42</li><li style="box-sizing: border-box; padding: 0px 5px;">43</li><li style="box-sizing: border-box; padding: 0px 5px;">44</li><li style="box-sizing: border-box; padding: 0px 5px;">45</li><li style="box-sizing: border-box; padding: 0px 5px;">46</li><li style="box-sizing: border-box; padding: 0px 5px;">47</li><li style="box-sizing: border-box; padding: 0px 5px;">48</li><li style="box-sizing: border-box; padding: 0px 5px;">49</li><li style="box-sizing: border-box; padding: 0px 5px;">50</li><li style="box-sizing: border-box; padding: 0px 5px;">51</li><li style="box-sizing: border-box; padding: 0px 5px;">52</li><li style="box-sizing: border-box; padding: 0px 5px;">53</li><li style="box-sizing: border-box; padding: 0px 5px;">54</li><li style="box-sizing: border-box; padding: 0px 5px;">55</li><li style="box-sizing: border-box; padding: 0px 5px;">56</li><li style="box-sizing: border-box; padding: 0px 5px;">57</li><li style="box-sizing: border-box; padding: 0px 5px;">58</li><li style="box-sizing: border-box; padding: 0px 5px;">59</li><li style="box-sizing: border-box; padding: 0px 5px;">60</li><li style="box-sizing: border-box; padding: 0px 5px;">61</li><li style="box-sizing: border-box; padding: 0px 5px;">62</li><li style="box-sizing: border-box; padding: 0px 5px;">63</li><li style="box-sizing: border-box; padding: 0px 5px;">64</li><li style="box-sizing: border-box; padding: 0px 5px;">65</li><li style="box-sizing: border-box; padding: 0px 5px;">66</li><li style="box-sizing: border-box; padding: 0px 5px;">67</li><li style="box-sizing: border-box; padding: 0px 5px;">68</li><li style="box-sizing: border-box; padding: 0px 5px;">69</li><li style="box-sizing: border-box; padding: 0px 5px;">70</li><li style="box-sizing: border-box; padding: 0px 5px;">71</li><li style="box-sizing: border-box; padding: 0px 5px;">72</li><li style="box-sizing: border-box; padding: 0px 5px;">73</li><li style="box-sizing: border-box; padding: 0px 5px;">74</li><li style="box-sizing: border-box; padding: 0px 5px;">75</li><li style="box-sizing: border-box; padding: 0px 5px;">76</li><li style="box-sizing: border-box; padding: 0px 5px;">77</li><li style="box-sizing: border-box; padding: 0px 5px;">78</li><li style="box-sizing: border-box; padding: 0px 5px;">79</li><li style="box-sizing: border-box; padding: 0px 5px;">80</li><li style="box-sizing: border-box; padding: 0px 5px;">81</li><li style="box-sizing: border-box; padding: 0px 5px;">82</li><li style="box-sizing: border-box; padding: 0px 5px;">83</li><li style="box-sizing: border-box; padding: 0px 5px;">84</li><li style="box-sizing: border-box; padding: 0px 5px;">85</li><li style="box-sizing: border-box; padding: 0px 5px;">86</li><li style="box-sizing: border-box; padding: 0px 5px;">87</li><li style="box-sizing: border-box; padding: 0px 5px;">88</li><li style="box-sizing: border-box; padding: 0px 5px;">89</li><li style="box-sizing: border-box; padding: 0px 5px;">90</li><li style="box-sizing: border-box; padding: 0px 5px;">91</li><li style="box-sizing: border-box; padding: 0px 5px;">92</li><li style="box-sizing: border-box; padding: 0px 5px;">93</li><li style="box-sizing: border-box; padding: 0px 5px;">94</li><li style="box-sizing: border-box; padding: 0px 5px;">95</li></ul>
2.实例化LR类的方式:
<code class="language-python hljs has-numbering" style="display: block; padding: 0px; background: transparent; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal;"><span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 因为LR中的input是TensorType类型,因此引用时,也需要定义一个TensorType类型</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># x表示样本的具体数据</span> x = T.matrix(<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'x'</span>) <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 同样y也应该是一个TensorType类型,是一个向量,而且数据类型还是int,因此定义一个T.ivector。</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 其中i表示int,vector表示向量。详细可以参考Theano教程。</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># y表示样本的标签。</span> y = T.ivector(<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'y'</span>) <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># x就是input样本,是一个矩阵,因此定义一个T.matrix</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># n_in,n_out的取值在此不再赘述,可以翻看上边的博文。</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 在实例化时,会自动调用LR中的__init__函数</span> classifier = LR(input=x, n_in=<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">28</span>*<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">28</span>, n_out=<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">10</span>) <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 代价函数,依据公式(3)计算生成。这是一个符号变量,cost并不是一个具体的数值。当传入具体的数据后,</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 其才会有具体的数据产生。</span> cost = classifier.negative_log_likelihood(y)</code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; background-color: rgb(238, 238, 238); top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right: 1px solid rgb(221, 221, 221); list-style: none; text-align: right;"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li><li style="box-sizing: border-box; padding: 0px 5px;">10</li><li style="box-sizing: border-box; padding: 0px 5px;">11</li><li style="box-sizing: border-box; padding: 0px 5px;">12</li><li style="box-sizing: border-box; padding: 0px 5px;">13</li><li style="box-sizing: border-box; padding: 0px 5px;">14</li><li style="box-sizing: border-box; padding: 0px 5px;">15</li><li style="box-sizing: border-box; padding: 0px 5px;">16</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; background-color: rgb(238, 238, 238); top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right: 1px solid rgb(221, 221, 221); list-style: none; text-align: right;"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li><li style="box-sizing: border-box; padding: 0px 5px;">10</li><li style="box-sizing: border-box; padding: 0px 5px;">11</li><li style="box-sizing: border-box; padding: 0px 5px;">12</li><li style="box-sizing: border-box; padding: 0px 5px;">13</li><li style="box-sizing: border-box; padding: 0px 5px;">14</li><li style="box-sizing: border-box; padding: 0px 5px;">15</li><li style="box-sizing: border-box; padding: 0px 5px;">16</li></ul>
四、训练模型
我们先回顾一下如何进行模型训练。
【1】根据概率论的知识总结出一个概率模型,二分类用公式(1);多分类用公式(2);
【2】求解出模型中的参数(即公式(2)中的
W
和
b
),因为参数没有确定值,所以我们的目标是使概率模型产出的概率(即公式(3)的结果)距离正确结果越接近越好。
【3】根据经验以及步骤2中的目标编写出代价函数,即公式(4);通过最小化公式(4)获得最优参数
W
和
b
。
【4】采用批量随机梯度下降法最小化公式(4)。本节主要讲述如何用梯度下降法最小化代价函数。先说一下思路,就是先根据上边计算出的cost(代价函数)对
W
和
b
分别求偏导,然后根据梯度更新
W
和
b
的值。计算误差;再根据新的
W
和
b
求偏导,如此迭代下去,直到误差符合要求或者迭代达到一定次数结束循环,此时的
W
和
b
即可以认为是目前最优的。
若要在大多数的编程语言中实现梯度下降算法,需要手动的推导出梯度表达式
∂ℓ∂W
和
∂ℓ∂b
(
ℓ
就是公式(4)),这是一个非常麻烦的推导,而且最终结果也很复杂,特别是考虑到数值稳定性的问题的时候。
然而,在Theano这个工具中,这个变得异常简单。因为它已经把求梯度这种运算给封装好了,不需要手动推导公式,只需要按照格式传入数据即可。下面来看一下代码。
<code class="language-python hljs has-numbering" style="display: block; padding: 0px; background: transparent; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal;"><span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 对W求导,只需要调用函数T.grad,把用代码计算出的公式(4)的结果作为cost传入(就是前边已经计算出来的cost),</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 指定求(偏)导对象为classifier.W(classifier就是前边自己定义的LR类)</span> g_W = T.grad(cost=cost, wrt=classifier.W) <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 对b求偏导,原理一样。</span> g_b = T.grad(cost=cost, wrt=classifier.b)</code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; background-color: rgb(238, 238, 238); top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right: 1px solid rgb(221, 221, 221); list-style: none; text-align: right;"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; background-color: rgb(238, 238, 238); top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right: 1px solid rgb(221, 221, 221); list-style: none; text-align: right;"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li></ul>
计算完了梯度,就要根据梯度进行权值偏置值的更新。操作如下:
<code class="language-python hljs has-numbering" style="display: block; padding: 0px; background: transparent; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal;"><span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># updates相当于一个更新器,说明了哪个参数需要更新,以及更新公式</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 下面代码指明更新需要参数W,更新公式是(原值-学习速率*梯度值)</span> updates = [(classifier.W, classifier.W - learning_rate * g_W), <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 参数b的更新类似于W </span> (classifier.b, classifier.b - learning_rate * g_b)]</code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; background-color: rgb(238, 238, 238); top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right: 1px solid rgb(221, 221, 221); list-style: none; text-align: right;"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; background-color: rgb(238, 238, 238); top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right: 1px solid rgb(221, 221, 221); list-style: none; text-align: right;"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li></ul>
现在我们就可以编写模型训练函数了。代码就两句话,但是解释一大堆,希望能帮助初学者了解function的工作原理。
<code class="language-python hljs has-numbering" style="display: block; padding: 0px; background: transparent; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal;"><span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 上边所提到的TensorType都是符号变量,符号变量只有传入具体数值时才会生成新的数据。</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># theano.function也是一个特色函数。在本实验中,它会生成一个叫train_model的函数。</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 该函数的参数传递入口是inputs,就是将需要传递的参数index赋值给inputs</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 该函数的返回值是通过outputs指定的,也就是返回经过计算后的cost变量。</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 更新器updates是用刚刚定义的update</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># givens是一个很实用的功能。它的作用是:在计算cost时会用到符号变量x和y(x并没有显示的表达出来,</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 函数negative_log_likehood用到了p_y_given_x,而计算p_y_given_x时用到了input,input就是x)。</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 符号变量经过计算之后始终会有一个自身值,而此处计算cost不用x和y的自身值,那就可以通过givens里边的表达式</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 重新指定计算cost表达式中的x和y所用的值,而且不会改变x和y原来的值。</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">##举个简单的例子:</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># state = shared(0)</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># inc = T.iscalar('inc')</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># accumulator = function([inc], state, updates=[(state, state+inc)])</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># state.get_value() #结果是array(0),因为初始值就是0</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># accumulator(1) #会输出结果array(0),即原来的state是0,但是继续往下看</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># state.get_value() #结果是array(1),根据updates得知,state=state+inc=0+1=1</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># accumulator(300) #会输出结果array(1),即原来的state是1,但是继续往下看</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># state.get_value() #结果是array(301),根据updates得知,state=state+inc=1+300=301</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">##此时state=301,继续做实验</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># fn_of_state = state * 2 + inc</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">##foo用来代替更新表达式中的state,即不用state原来的值,而用新的foo值,但是fn_of_state表达式不变</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># foo = T.scalar(dtype=state.dtype)</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">##skip_shared函数是输入inc和foo,输出fn_of_state,通过givens修改foo代替fn_of_state表达式中的state</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># skip_shared = function([inc, foo], fn_of_state, givens=[(state, foo)]) </span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># skip_shared(1, 3) #会输出结果array(7),即fn_of_state=foo * 2 + inc = 3*2+1 = 7</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">##再来看看state的原值是多少呢?</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># state.get_value() #会输出结果array(301),而不是foo的值3</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">##因为每一次都需要用新的x和y去计算cost值,而不是用原来的上一次的x和y去计算,因此需要用到givens</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">##希望通过这个小例子能说清楚givens的作用。</span> train_model = theano.function( inputs = [index], outputs = cost, updates = update, givens = { <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># x:仅仅是表示第一个数据用来代替x,而不去重新声明一个和x结构类型相同的符号变量了;y同理</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># trian_set_x是训练数据集中的x分量,就是样本的数据部分,trian_set_x[a:b]代表取数组中下标从a开始,到下标b之前的数据。</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># train_set_y是训练数据集中的y分量,就是样本的标签部分。</span> x: trian_set_x[index * batch_size:(index + <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>) * batch_size], y: trian_set_y[index * batch_size:(index + <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>) * batch_size] } ) </code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; background-color: rgb(238, 238, 238); top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right: 1px solid rgb(221, 221, 221); list-style: none; text-align: right;"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li><li style="box-sizing: border-box; padding: 0px 5px;">10</li><li style="box-sizing: border-box; padding: 0px 5px;">11</li><li style="box-sizing: border-box; padding: 0px 5px;">12</li><li style="box-sizing: border-box; padding: 0px 5px;">13</li><li style="box-sizing: border-box; padding: 0px 5px;">14</li><li style="box-sizing: border-box; padding: 0px 5px;">15</li><li style="box-sizing: border-box; padding: 0px 5px;">16</li><li style="box-sizing: border-box; padding: 0px 5px;">17</li><li style="box-sizing: border-box; padding: 0px 5px;">18</li><li style="box-sizing: border-box; padding: 0px 5px;">19</li><li style="box-sizing: border-box; padding: 0px 5px;">20</li><li style="box-sizing: border-box; padding: 0px 5px;">21</li><li style="box-sizing: border-box; padding: 0px 5px;">22</li><li style="box-sizing: border-box; padding: 0px 5px;">23</li><li style="box-sizing: border-box; padding: 0px 5px;">24</li><li style="box-sizing: border-box; padding: 0px 5px;">25</li><li style="box-sizing: border-box; padding: 0px 5px;">26</li><li style="box-sizing: border-box; padding: 0px 5px;">27</li><li style="box-sizing: border-box; padding: 0px 5px;">28</li><li style="box-sizing: border-box; padding: 0px 5px;">29</li><li style="box-sizing: border-box; padding: 0px 5px;">30</li><li style="box-sizing: border-box; padding: 0px 5px;">31</li><li style="box-sizing: border-box; padding: 0px 5px;">32</li><li style="box-sizing: border-box; padding: 0px 5px;">33</li><li style="box-sizing: border-box; padding: 0px 5px;">34</li><li style="box-sizing: border-box; padding: 0px 5px;">35</li><li style="box-sizing: border-box; padding: 0px 5px;">36</li><li style="box-sizing: border-box; padding: 0px 5px;">37</li><li style="box-sizing: border-box; padding: 0px 5px;">38</li><li style="box-sizing: border-box; padding: 0px 5px;">39</li><li style="box-sizing: border-box; padding: 0px 5px;">40</li><li style="box-sizing: border-box; padding: 0px 5px;">41</li><li style="box-sizing: border-box; padding: 0px 5px;">42</li><li style="box-sizing: border-box; padding: 0px 5px;">43</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; background-color: rgb(238, 238, 238); top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right: 1px solid rgb(221, 221, 221); list-style: none; text-align: right;"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li><li style="box-sizing: border-box; padding: 0px 5px;">10</li><li style="box-sizing: border-box; padding: 0px 5px;">11</li><li style="box-sizing: border-box; padding: 0px 5px;">12</li><li style="box-sizing: border-box; padding: 0px 5px;">13</li><li style="box-sizing: border-box; padding: 0px 5px;">14</li><li style="box-sizing: border-box; padding: 0px 5px;">15</li><li style="box-sizing: border-box; padding: 0px 5px;">16</li><li style="box-sizing: border-box; padding: 0px 5px;">17</li><li style="box-sizing: border-box; padding: 0px 5px;">18</li><li style="box-sizing: border-box; padding: 0px 5px;">19</li><li style="box-sizing: border-box; padding: 0px 5px;">20</li><li style="box-sizing: border-box; padding: 0px 5px;">21</li><li style="box-sizing: border-box; padding: 0px 5px;">22</li><li style="box-sizing: border-box; padding: 0px 5px;">23</li><li style="box-sizing: border-box; padding: 0px 5px;">24</li><li style="box-sizing: border-box; padding: 0px 5px;">25</li><li style="box-sizing: border-box; padding: 0px 5px;">26</li><li style="box-sizing: border-box; padding: 0px 5px;">27</li><li style="box-sizing: border-box; padding: 0px 5px;">28</li><li style="box-sizing: border-box; padding: 0px 5px;">29</li><li style="box-sizing: border-box; padding: 0px 5px;">30</li><li style="box-sizing: border-box; padding: 0px 5px;">31</li><li style="box-sizing: border-box; padding: 0px 5px;">32</li><li style="box-sizing: border-box; padding: 0px 5px;">33</li><li style="box-sizing: border-box; padding: 0px 5px;">34</li><li style="box-sizing: border-box; padding: 0px 5px;">35</li><li style="box-sizing: border-box; padding: 0px 5px;">36</li><li style="box-sizing: border-box; padding: 0px 5px;">37</li><li style="box-sizing: border-box; padding: 0px 5px;">38</li><li style="box-sizing: border-box; padding: 0px 5px;">39</li><li style="box-sizing: border-box; padding: 0px 5px;">40</li><li style="box-sizing: border-box; padding: 0px 5px;">41</li><li style="box-sizing: border-box; padding: 0px 5px;">42</li><li style="box-sizing: border-box; padding: 0px 5px;">43</li></ul>
每一次调用train_model(index),都会计算并返回输入样本块的cost,然后执行一次MSGD,并更新
W
和
b
。整个学习算法的一次迭代这样循环调用train_model (总样本数/样本块数)次。假设总样本60000个,一个样本块600个,那么一次迭代就需要调用100次train_model。而模型的训练又需要进行多次迭代,直到达到迭代次数或者误差率达到要求。
五、测试模型
模型测试需要用到LR中的errors函数。下面来看一下测试模型函数test_model和验证模型函数validate_model。有了上面训练模型的基础,相信这个测试模型会很容易理解。
<code class="language-python hljs has-numbering" style="display: block; padding: 0px; background: transparent; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal;"><span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 测试模型基本不需要说太多了,首先测试不需要更新数据,因此没有updates,但是测试需要用到givens来代替cost计算公式中x和y的数值。</span> <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 测试模型采用的数据集是测试数据集test_set_x和test_set_y</span> test_model = thenao.function( inputs = [index], outputs = classifier.errors(y), givens = { x: test_set_x[index * batch_size: (index + <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>) * batch_size], y: test_set_y[index * batch_size: (index + <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>) * batch_size] } ) <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;"># 验证模型和测试模型的不同之处在于计算所用的数据不一样,验证模型用的是验证数据集。</span> validate_model = theano.function( inputs=[index], outputs=classifier.errors(y), givens={ x: valid_set_x[index * batch_size: (index + <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>) * batch_size], y: valid_set_y[index * batch_size: (index + <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>) * batch_size] } ) </code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; background-color: rgb(238, 238, 238); top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right: 1px solid rgb(221, 221, 221); list-style: none; text-align: right;"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li><li style="box-sizing: border-box; padding: 0px 5px;">10</li><li style="box-sizing: border-box; padding: 0px 5px;">11</li><li style="box-sizing: border-box; padding: 0px 5px;">12</li><li style="box-sizing: border-box; padding: 0px 5px;">13</li><li style="box-sizing: border-box; padding: 0px 5px;">14</li><li style="box-sizing: border-box; padding: 0px 5px;">15</li><li style="box-sizing: border-box; padding: 0px 5px;">16</li><li style="box-sizing: border-box; padding: 0px 5px;">17</li><li style="box-sizing: border-box; padding: 0px 5px;">18</li><li style="box-sizing: border-box; padding: 0px 5px;">19</li><li style="box-sizing: border-box; padding: 0px 5px;">20</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; background-color: rgb(238, 238, 238); top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right: 1px solid rgb(221, 221, 221); list-style: none; text-align: right;"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li><li style="box-sizing: border-box; padding: 0px 5px;">10</li><li style="box-sizing: border-box; padding: 0px 5px;">11</li><li style="box-sizing: border-box; padding: 0px 5px;">12</li><li style="box-sizing: border-box; padding: 0px 5px;">13</li><li style="box-sizing: border-box; padding: 0px 5px;">14</li><li style="box-sizing: border-box; padding: 0px 5px;">15</li><li style="box-sizing: border-box; padding: 0px 5px;">16</li><li style="box-sizing: border-box; padding: 0px 5px;">17</li><li style="box-sizing: border-box; padding: 0px 5px;">18</li><li style="box-sizing: border-box; padding: 0px 5px;">19</li><li style="box-sizing: border-box; padding: 0px 5px;">20</li></ul>
一块一块的讲了这么多代码,或许都看晕了,下面就看一下整合之后的代码会更清晰。
六、整合代码
python源代码带注释下载
只有一个lr.py,直接用python命令执行即可。