scan——Theano中循环的实现

引子

在开始scan函数的设计之前,我们从一个实例出发,首先来看,一个循环需要必备哪些成分:简单的循环累积相乘,计算 Ak ,使用numpy,代码如下。

def power(A, k):
    result = 1
    for i in range(k):
        result *= A
    return result

计算 Ak 的过程中,我们需要对其中三个主要成分进行操纵,

  • 恒定量A:对应于theano.scan函数中的non_sequences参数

  • result设置的初值,由outputs_info指定

  • result的累积相乘值,自动进行

theano.scan函数接口介绍

results, updates = theano.scan(fn=lambda y, p, x_tm2, x_tm1, A: y+p+x_tm2+x_tm1+A,
                    sequences=[Y, P[::-1]],
                    outputs_info=[dict(initial=X, taps=[-2, -1])],
                    non_sequences=A)
  • 参数 fn 表示每次迭代的操作,简单的通过lambda匿名函数对象来定义,注意lambda函数的参数顺序是有要求的

    • 先是sequences提供的(y, p)

    • 然后是outputs_info的参数(x_tm2, x_tm1)

    • 最后是non_sequences的参数(A)

    • 三者默认都是None

  • sequences表示小迭代的序列,序列的第一维(leading dimension)即是需要迭代的次数,如果不显示指定sequences(也即是默认为None),需要通过n_steps显式指定其迭代次数,YP[::-1]的第一维的大小应该相同,如果不同,取两者的较小值,多项式求和的例子中我们会看到这一机制的应用。

  • outputs_info描述的是需要用到前多少次迭代输出的结果,dict(initial=X, taps=[-2, -1])表示使用前一次和倒数第二次输出的结果。如果当前迭代输出为x(t),则计算中会使用x(t-1), x(t-2)

  • sequences相对,non_sequences描述了非序列的输入,即A是一个固定的输入,每次迭代相加的A都是固定的。

实例

计算 Ak

所以,等价的theano代码如下:

import theano
import theano.tensor as T

k = T.iscalar('k')
A = T.vector('A')
result, updates = theano.scan(fn=lambda prior_result, A: prior_result*A, 
        outputs_info=T.ones_like(A), non_sequences=A, n_steps=k)
        # 这里的对应关系为,outputs_info=T.ones_like(A) -> prior_result,类比 result *= A 
        # 将prior_result初始化为1,每次输出的结果继续传递给outputs_info
        # non_sequence=A -> A 
power = theano.function(inputs=[A, k], outputs=result[-1], updates=updates)
power(range(10), 3)

输出为:

Out[16]: array([   0.,    1.,    8.,   27.,   64.,  125.,  216.,  343.,  512.,  729.])

计算多项式的和

f(x)=a0+a1x+a2x2++an1xn1+anxn

  • 我们用sequences标识系数部分 [a0,a1,,an] ,也包括该系数所在的项对应的幂。

  • 自然使用non_sequences标识自由变量 x

  • 多项式求和与前一项的输出无关,

import numpy
import theano
import theano.tensor as T

coefs = T.vector('coefs')
x = T.scalar('x')
max_coef_supported = 10000
components, updates = theano.scan(fn=lambda coef, power, free_var: coef*free_var**power, 
        sequences=[coefs, T.arange(max_coef_supported)],
        outputs_info=None,
        non_sequences=x)
                # sequences=[coefs, T.arange(max_coef_supported)] -> coef, power
                # non_sequences=x -> free_variable
polynomial = components.sum()
calc_poly = theano.function(inputs=[coefs, x], outputs=polynomial)
calc_poly([1, 0, 2], 3)

输出为:

array(19.0)
            # 1*3^0+0*3^1+2*3^2=19

计算序列x(t)=tanh(x(t1).dot(W)+y(t).dot(U)+p(T1).dot(V))

x = T.vector('x')
W = T.matrix('W')
Y = T.matrix('Y')
U = T.matrix('U')
P = T.matrix('P')
V = T.matrix('V')

results, updates = theanp.scan(fn=lambda y, p, x_tm1: T.tanh(T.dot(x_tm1, W)+T.dot(y, U)+T.dot(p, v)),
            sequences=[Y, P[::-1]],
            outputs_info=[x])
compute_seq = theano.function(inputs=[x, W, Y, U, P, V], outputs=results)           

计算矩阵 X 的列的范数

X = T.matrix('X')
results, updates = theano.scan(fn=lambda x: T.sqrt((x**2).sum()), sequences=[X.T])
compute_norm_cols = theano.function(inputs=[X], outputs=[results])

计算矩阵X的迹

X = T.matrix('X')
results, updates = theano.scan(fn=lambda i, f, t: T.cast(X[i, j]+t, floatX), 
            sequences=[T.arange(X.shape[0]), T.arange(X.shape[1])],
            outputs_info=np.asarray(0., dtype=floatX))
result = results[-1]        
theano.function(inputs=[X], outputs=[result])

计算序列x(t) = x(t - 2).dot(U) + x(t - 1).dot(V) + tanh(x(t - 1).dot(W) + b)

X = T.matrix('X')
theano.scan(fn=lambda x_tm2, x_tm1: T.dot(x_tm2, U)+T.dot(x_tm1, V)+T.tanh(x_tm1, W)+b,
            outputs_info=[dict(initial=X, taps=[-2, -1])], 
            n_steps=n_sym)

References

[1] Theano2.1.10-基础知识之循环

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

五道口纳什

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值