我可能刚看到theano 的scan()的时候大家也是一脸懵逼,这么多参数,而且每个参数那么复杂, 这是都啥啊?
theano.scan(fn, sequences=None, outputs_info=None, non_sequences=None, n_steps=None, truncate_gradient=-1, go_backwards=False, mode=None, name=None, profile=False, allow_gc=None, strict=False, return_list=False)
没事,我们先来总体了解一下。
scan
- looping in theano
看到这就话的时候我们就知道这个函数是用来做迭代计算的了,这是一种适应GPU的高效计算方式,具体为什么这样设计应该都知道。
输入
- fn - 迭代函数
- sequences - fn输入, 迭代变量
- output-info - fn输入, 初始化fn的输出格式
- non-sequence - fn 输入, 但是迭代过程不变, 相当于static 变量
- n-steps 迭代次数
输出
- result: 每次迭代的输出结果
- updates: 确定再scan中的共享变量的更新规则
后面还有很多参数,但是我们就不详细叙述了,了解这些变量就足够了,当然这样介绍可能大家还是不懂,那么我们直接通过代码来讲解。
Code 1
计算 A k A^k Ak
import theano
import theano.tensor as T
k = T.iscalar("k")
A = T.vector("A")
# Symbolic description of the result
result, updates = theano.scan(fn=lambda prior_result, A: prior_result * A,
outputs_info=T.ones_like(A),
non_sequences=A,
n_steps=k)
# We only care about A**k, but scan has provided us with A**1 through A**k.
# Discard the values that we don't care about. Scan is smart enough to
# notice this and not waste memory saving them.
final_result = result[-1]
# compiled function that returns A**k
power = theano.function(inputs=[A,k], outputs=final_result, updates=updates)
print(power(range(10),2))
print(power(range(10),4))
迭代函数:
fn=lambda prior_result, A: prior_result * A
fn 函数是计算两个矩阵相乘的一个lambda 表达式, prior_result是上一次这个函数的输出, 即 r e s u l t k − 1 result_{k-1} resultk−1。但是问题是凭啥拟定一个piror_result就是前一次计算的输出呢?那是因为在compile的时候我们调用了updates,这个记录的就是共享变量的更新规则。
outputs_info = T.ones_like(A)
输出初始化,就是如果这个函数什么都不干,他的输出格式就是one_like(A)
Code2
import numpy
coefficients = theano.tensor.vector("coefficients")
x = T.scalar("x")
# sequences最多迭代多少次
max_coefficients_supported = 10000
# Generate the components of the polynomial
components, updates = theano.scan(fn=lambda coefficient, power, free_variable: coefficient * (free_variable ** power),
outputs_info=None,
sequences=[coefficients, T.arange(max_coefficients_supported)],
non_sequences=x)
# Sum them up
polynomial = components.sum()
# Compile a function
calculate_polynomial = theano.function(inputs=[coefficients, x], outputs=polynomial)
# Test
test_coefficients = numpy.asarray([1, 0, 2], dtype=numpy.float32)
test_value = 3
print(calculate_polynomial(test_coefficients, test_value))
print(1.0 * (3 ** 0) + 0.0 * (3 ** 1) + 2.0 * (3 ** 2))
在这个例子中,fn
函数按顺序有三输入(coeff, power, free_var)
,那么他们分别由那些值提供呢?
theano.function(inputs=[coefficients, x], outputs=polynomial)
由上面这行代码可以知道,函数的输入值有两个,即coeff
和x
。而coeff 是一个一维数组,与之对应的就是sequences
; x是一个固定的, 与之对应的则是non-sequences
。
所以这三个值是这么产生的:
coeff中取第一个数为:1
arange 中取第一个数为:0
x 始终不变为:3
计算结果result[0] = 1 * (3 ** 0)
coffe 中取第二个数 为:0
arange 中取第二个数 为:1
x 始终不变为:3
计算结果result[1] = 0 * (3 **1)
。。。。
想必大家应该都比较了解前几个参数的代表的意义了。
Code 3
这是一个替换矩阵中的某个数的示例程序,我们知道至少需要提供两个值,替换值的位置和替换的值。程序如下:
location = T.imatrix("location")
values = T.vector("values")
output_model = T.matrix("output_model")
def set_value_at_position(a_location, a_value, output_model):
zeros = T.zeros_like(output_model)
zeros_subtensor = zeros[a_location[0], a_location[1]]
return T.set_subtensor(zeros_subtensor, a_value)
result, updates = theano.scan(fn=set_value_at_position,
outputs_info=None,
sequences=[location, values],
non_sequences=output_model)
assign_values_at_positions = theano.function(inputs=[location, values, output_model], outputs=result)
# test
test_locations = numpy.asarray([[1, 1], [2, 3]], dtype=numpy.int32)
test_values = numpy.asarray([42, 50], dtype=numpy.float32)
test_output_model = numpy.zeros((5, 5), dtype=numpy.float32)
print(assign_values_at_positions(test_locations, test_values, test_output_model))
迭代函数有三个输入(a_location, a_value, output_model)
我们看scan 的定义:
scan(fn=set_value_at_position, outputs_info=None, sequences=[location, values],non_sequences=output_model)
因为fn
的输入就是
s
e
q
u
n
c
e
1
+
s
e
q
u
e
n
c
e
2
+
.
.
.
.
+
s
e
q
u
e
n
c
e
n
+
m
o
d
e
l
_
i
n
f
o
+
n
o
n
_
s
e
q
u
e
n
c
e
s
sequnce_1 + sequence_2 + ....+ sequence_n + model\_info + non\_sequences
sequnce1+sequence2+....+sequencen+model_info+non_sequences。
所以我们的函数第一个输入就是
([1, 1], 42, zeros((5, 5))
输出为:
[
0.0.0.0.0.
]
[
0.42.0.0.0.
]
[
0.0.0.0.0.
]
[
0.0.0.0.0.
]
[
0.0.0.0.0.
]
[ 0. 0. 0. 0. 0.] \\ [ 0. 42. 0. 0. 0.] \\ [ 0. 0. 0. 0. 0.] \\ [ 0. 0. 0. 0. 0.] \\ [ 0. 0. 0. 0. 0.]
[0.0.0.0.0.][0.42.0.0.0.][0.0.0.0.0.][0.0.0.0.0.][0.0.0.0.0.]
以此类推。
看到这里大该怎么操作的应该都懂了吧!