def cross_entropy_error(y, t):
if y.ndim == 1:
t = t.reshape(1, t.size)
y = y.reshape(1, y.size)
# 监督数据是one-hot-vector的情况下,转换为正确解标签的索引
if t.size == y.size:
t = t.argmax(axis=1)
batch_size = y.shape[0]
return -np.sum(np.log(y[np.arange(batch_size), t] + 1e-7)) / batch_size
看到这段代码可能不熟悉python的同学可能会犯迷糊,这里我谈一下我的理解。
1.为什么有以下代码?
if y.ndim == 1: t = t.reshape(1, t.size) y = y.reshape(1, y.size)
答:因为现在是对min_batch个数据进行操作,通常数据是大于一个的。此时batch_size等于数据个数。但如果只有一个数据呢?此时如果不加上述代码,batch_size就会等于一维数组的元素个数,产生错误,所以加上 上述代码。
2.t=t.argmax(axis=1)是什么意思?
答:当t是一维数组,也就是只有一个数据时,t返回最大数的下标。
当t是二维数组时,axis等于1,以数组形式返回每一行最大数的下标。
3.np.sum(np.log(y[np.arange(batch_size), t] + 1e-7))是什么意思?
3.1 首先要知道np.arange(batch_size), t] 是什么意思?
它返回的是一个一维数组,数组中的元素是每个数据中最大的元素。
3.2 那么这段代码的意思就是,把每个数据中最大的元素的对数值进行求和。