Theano scan

This post is my summary of the official Theano tutorial on scan.

The scan function provides looping (iteration) in Theano.
Its signature is:
theano.scan(fn, sequences=None, outputs_info=None, non_sequences=None, n_steps=None, truncate_gradient=-1, go_backwards=False, mode=None, name=None, profile=False, allow_gc=None, strict=False)

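Here fn is the step function applied at each iteration; it can be a lambda expression or a named function. sequences are the sequences to iterate over. outputs_info gives the initial values for the outputs that each iteration reads from the previous iteration. non_sequences are the arguments that are passed unchanged to every iteration. With strict=True, fn can no longer silently use previously defined shared variables; they have to be listed in non_sequences, which helps optimize the code. Setting allow_gc to False disables garbage collection inside scan, which can speed the code up at the cost of higher memory usage.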

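The arguments of the function given as fn must come in a fixed order: sequences, prior results, non_sequences.
If several sequences have different lengths, the number of iterations is the length of the shortest one.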
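To make the two rules above concrete, here is a minimal plain-Python sketch. It is not Theano's implementation: scan_reference and step are hypothetical names of my own, and the sketch assumes a single output that only looks one step back and at least one sequence.

```python
def scan_reference(fn, sequences=(), outputs_info=None, non_sequences=()):
    """Plain-Python illustration of scan's calling convention (hypothetical helper)."""
    # Like scan, iterate only as far as the shortest sequence allows.
    n_steps = min(len(s) for s in sequences)
    prev = outputs_info
    results = []
    for t in range(n_steps):
        seq_args = [s[t] for s in sequences]
        prior_args = [] if outputs_info is None else [prev]
        # Argument order: sequences, then prior results, then non_sequences.
        prev = fn(*seq_args, *prior_args, *non_sequences)
        results.append(prev)
    return results


def step(a_t, b_t, acc_tm1, scale):
    # a_t and b_t come from the sequences, acc_tm1 is the previous result,
    # scale is a non-sequence argument.
    return acc_tm1 + scale * a_t * b_t


out = scan_reference(step, sequences=[[1, 2, 3, 4], [10, 20, 30]],
                     outputs_info=0, non_sequences=[2])
print(out)  # [20, 100, 280]
```

The first sequence has four elements and the second has three, so only three steps run.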
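The initial value passed in outputs_info must have the same dtype as the value returned by each iteration; a mismatch is not accepted even where an implicit cast would seem possible. The dtype can be matched like this: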

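# Not this: the constant 0 gets its own default dtype, which may differ from seq.dtype
# outputs_info = T.as_tensor_variable(0)

outputs_info = T.as_tensor_variable(np.asarray(0, seq.dtype))   # seq has the same dtype as the iteration result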
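The dtype-matching part of that line can be tried with NumPy alone. This is only a sketch: here seq is a NumPy array standing in for the symbolic sequence, and the float32 dtype is an arbitrary choice for illustration.

```python
import numpy as np

seq = np.arange(5, dtype="float32")      # stand-in for the sequence being scanned

default_init = np.asarray(0)             # dtype picked by NumPy: an integer type
matched_init = np.asarray(0, seq.dtype)  # dtype copied from the sequence

print(default_init.dtype == seq.dtype)   # False
print(matched_init.dtype == seq.dtype)   # True
```

Building the initial value from seq.dtype keeps it in step with the sequence if the sequence's dtype changes later.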

When scan uses taps, the naming convention is that a_tm2 stands for a(t-2) and b_tp3 stands for b(t+3).

Note the order in which the parameters are given, and in which the result is returned. Try to respect chronological order among the taps (time slices of sequences or outputs) used. For scan, it is crucial only that the variables representing the different time taps are in the same order as the one in which these taps are given. Also, not only taps should respect an order, but also variables, since this is how scan figures out what should be represented by what.
An example:

import theano
import theano.tensor as T

def oneStep(u_tm4, u_t, x_tm3, x_tm1, y_tm1, W, W_in_1, W_in_2,  W_feedback, W_out):

  x_t = T.tanh(theano.dot(x_tm1, W) + \
               theano.dot(u_t,   W_in_1) + \
               theano.dot(u_tm4, W_in_2) + \
               theano.dot(y_tm1, W_feedback))
  y_t = theano.dot(x_tm3, W_out)

  return [x_t, y_t]


W = T.matrix()
W_in_1 = T.matrix()
W_in_2 = T.matrix()
W_feedback = T.matrix()
W_out = T.matrix()

u = T.matrix() # it is a sequence of vectors
x0 = T.matrix() # initial state of x has to be a matrix, since
                # it has to cover x[-3]
y0 = T.vector() # y0 is just a vector since scan has only to provide
                # y[-1]


([x_vals, y_vals], updates) = theano.scan(fn=oneStep,
                                          sequences=dict(input=u, taps=[-4,-0]),
                                          outputs_info=[dict(initial=x0, taps=[-3,-1]), y0],
                                          non_sequences=[W, W_in_1, W_in_2, W_feedback, W_out],
                                          strict=True)
# for the second output y, scan adds -1 in output_taps by default
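To see which time slices the taps pick out, the same recurrence can be written as an ordinary loop. This is a rough sketch with plain floats instead of tensors; one_step and run_with_taps are hypothetical names, and it assumes that iteration starts at the first step for which every requested tap exists.

```python
import math


def one_step(u_tm4, u_t, x_tm3, x_tm1, y_tm1, w, w_in_1, w_in_2, w_feedback, w_out):
    # Scalar version of oneStep: same argument order, floats instead of tensors.
    x_t = math.tanh(x_tm1 * w + u_t * w_in_1 + u_tm4 * w_in_2 + y_tm1 * w_feedback)
    y_t = x_tm3 * w_out
    return x_t, y_t


def run_with_taps(u, x0, y0, params):
    x_hist = list(x0)  # x[-3], x[-2], x[-1]: three initial values cover tap -3
    y_hist = [y0]      # y[-1]: one initial value covers tap -1
    for t in range(4, len(u)):  # first t for which u[t-4] exists
        x_t, y_t = one_step(u[t - 4], u[t],
                            x_hist[-3], x_hist[-1], y_hist[-1], *params)
        x_hist.append(x_t)
        y_hist.append(y_t)
    # Drop the initial values so only the computed steps are returned.
    return x_hist[3:], y_hist[1:]


u = [0.1 * i for i in range(7)]
params = (0.5, 0.3, 0.2, 0.1, 2.0)  # w, w_in_1, w_in_2, w_feedback, w_out
x_vals, y_vals = run_with_taps(u, [0.1, 0.2, 0.3], 0.0, params)
print(len(x_vals), len(y_vals))  # 3 3
```

A sequence of length 7 with taps [-4, 0] leaves three usable steps, and the first three y values are built from the three initial x values.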
Ways to optimize scan code:
  1. Minimizing Scan usage
  2. Explicitly passing inputs of the inner function to scan
  3. Deactivating garbage collecting in Scan
  4. Graph optimizations
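The first item can be illustrated without Theano. In the NumPy sketch below, the part of a recurrent step that does not depend on the previous state is computed once outside the loop, so the loop body does less work. All names and sizes here (loop_everything, loop_minimal, the matrix shapes) are assumptions for illustration, and a Python for loop stands in for scan.

```python
import numpy as np

rng = np.random.default_rng(0)
n_steps, n_in, n_hid = 6, 3, 4
x = rng.standard_normal((n_steps, n_in))
U = rng.standard_normal((n_in, n_hid))
W = rng.standard_normal((n_hid, n_hid))
h0 = np.zeros(n_hid)


def loop_everything(x, U, W, h0):
    # Every step recomputes the input projection inside the loop.
    h, out = h0, []
    for x_t in x:
        h = np.tanh(x_t @ U + h @ W)
        out.append(h)
    return np.stack(out)


def loop_minimal(x, U, W, h0):
    # The input projection does not depend on h, so do it once for all steps.
    xu = x @ U
    h, out = h0, []
    for xu_t in xu:
        h = np.tanh(xu_t + h @ W)
        out.append(h)
    return np.stack(out)


same = np.allclose(loop_everything(x, U, W, h0), loop_minimal(x, U, W, h0))
print(same)  # True
```

Only the h @ W term has to stay inside the loop, because it is the only part that needs the previous state.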