非常费解的一行代码,python语言真的很妖

笔者近期学习深度学习,遇到一个识别手写字的代码,用theano写的,遇到了下面一行代码,先上代码。

cost = -T.mean(T.log(model)[T.arange(y.shape[0]), y]) 

一行代码让我费解了很久,首先T.mean()、T.log()、T.arange()都能看懂,全部拿起来就看不太懂。费解的地方在于中括号后面那个y, 一般一个中括号中的内容代表一维数组的下标,没搞懂这里为什么会两个元素。y.shape[0]是y的长度,当y.shape[0]为length时,T.arange(y.shape[0])为0,1,~,length-1。接着想上面的中括号里面为什么有两个元素,再回到代码上下文表示T.log(model)是一个矩阵(二维张量,也可以理解成数组),这么一想[T.arange(y.shape[0]),y]是定位二维数组的某一个元素,这个思路是正确的,笔者是查了很多资料证实的,并且写了一个小例子证实。

接下来有另一个问题,[T.arange(y.shape[0]), y]中T.arange(y.shape[0])是一个0~length-1的数组,y也是一个数组,怎么定位,在python中range一般用于一个循环,上面的式子就是T.log(model)[0][y[0]],T.log(model)[1][y[1]],…,T.log(model)[length-1][y[length-1]]。最后再求平均值。

为了验证以上想法,写了一个验证代码如下:

import theano.tensor as T
import numpy as np

model =np.array([[0.1,0.2,0.3,0.25,0.15],[0.23,0.18,0.26,0.1,0.13]])
y = np.array([[3],[4]])
LP = T.log(model)
print(LP.eval())
'''
[[-2.30258509 -1.60943791 -1.2039728  -1.38629436 -1.89711998]
 [-1.46967597 -1.71479843 -1.34707365 -2.30258509 -2.04022083]]
'''
cost = -T.mean(T.log(model)[T.arange(y.shape[0]), y])
print( cost.eval() )
'''
1.844439727056968
'''

结论:

对于一个写了差不多8年的java的人,真的感觉到了python的强大,语法简直不要太妖。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

东心十

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值