笔者近期学习深度学习,遇到一个识别手写字的代码,用theano写的,遇到了下面一行代码,先上代码。
cost = -T.mean(T.log(model)[T.arange(y.shape[0]), y])
一行代码让我费解了很久,首先T.mean()、T.log()、T.arange()都能看懂,全部拿起来就看不太懂。费解的地方在于中括号后面那个y, 一般一个中括号中的内容代表一维数组的下标,没搞懂这里为什么会两个元素。y.shape[0]是y的长度,当y.shape[0]为length时,T.arange(y.shape[0])为0,1,~,length-1。接着想上面的中括号里面为什么有两个元素,再回到代码上下文表示T.log(model)是一个矩阵(二维张量,也可以理解成数组),这么一想[T.arange(y.shape[0]),y]是定位二维数组的某一个元素,这个思路是正确的,笔者是查了很多资料证实的,并且写了一个小例子证实。
接下来有另一个问题,[T.arange(y.shape[0]), y]中T.arange(y.shape[0])是一个0~length-1的数组,y也是一个数组,怎么定位,在python中range一般用于一个循环,上面的式子就是T.log(model)[0][y[0]],T.log(model)[1][y[1]],…,T.log(model)[length-1][y[length-1]]。最后再求平均值。
为了验证以上想法,写了一个验证代码如下:
import theano.tensor as T
import numpy as np
model =np.array([[0.1,0.2,0.3,0.25,0.15],[0.23,0.18,0.26,0.1,0.13]])
y = np.array([[3],[4]])
LP = T.log(model)
print(LP.eval())
'''
[[-2.30258509 -1.60943791 -1.2039728 -1.38629436 -1.89711998]
[-1.46967597 -1.71479843 -1.34707365 -2.30258509 -2.04022083]]
'''
cost = -T.mean(T.log(model)[T.arange(y.shape[0]), y])
print( cost.eval() )
'''
1.844439727056968
'''
结论:
对于一个写了差不多8年的java的人,真的感觉到了python的强大,语法简直不要太妖。