import numpy as np
# Cross-entropy error for a mini-batch whose targets t are one-hot encoded.
def mini_batch_cee(y, t):
    """Return the mean cross-entropy error of a mini-batch with one-hot targets.

    Args:
        y: Predicted probabilities, shape (batch, classes), or (classes,) for
            a single sample.
        t: One-hot targets with the same shape as ``y``.

    Returns:
        ``-sum(t * log(y))`` divided by the number of samples in the batch.
    """
    y = np.asarray(y)
    t = np.asarray(t)
    # Treat a single 1-D sample as a batch of one. Without this, y.shape[0]
    # would be the number of classes and the loss would be divided by it.
    if y.ndim == 1:
        t = t.reshape(1, t.size)
        y = y.reshape(1, y.size)
    batch_size = y.shape[0]
    # 1e-7 keeps log() finite when a predicted probability is exactly 0.
    return -np.sum(t * np.log(y + 1e-7)) / batch_size
# Cross-entropy error for a mini-batch whose targets t are class-index labels
# (not one-hot).
def mini_batch_cee_sec(y, t):
    """Return the mean cross-entropy error of a mini-batch with label targets.

    Args:
        y: Predicted probabilities, shape (batch, classes), or (classes,) for
            a single sample.
        t: Integer index of the correct class for each sample. For a 1-D
            ``y``, a plain int or a size-1 array is accepted.

    Returns:
        ``-sum(log(y[i, t[i]]))`` divided by the number of samples.
    """
    y = np.asarray(y)
    # asarray lets a plain int label be reshaped below; previously
    # int.reshape raised AttributeError.
    t = np.asarray(t)
    if y.ndim == 1:  # treat a single 1-D sample as a batch of one
        t = t.reshape(1, t.size)
        y = y.reshape(1, y.size)
    batch_size = y.shape[0]
    # Fancy indexing pairs row i with column t[i], selecting
    # [y[0, t[0]], y[1, t[1]], ...]: the predicted probability of the
    # correct class for every sample.
    # 1e-7 keeps log() finite when that probability is exactly 0.
    return -np.sum(np.log(y[np.arange(batch_size), t] + 1e-7)) / batch_size
# Randomly draw a mini-batch of samples together with their targets.
def mini_batch(y, t, batch_size=3):
    """Return ``batch_size`` random rows of ``y`` and the matching rows of ``t``.

    Row indices are drawn with ``np.random.choice``, which samples *with*
    replacement by default, so the same row may be selected more than once.

    Args:
        y: Samples, indexed along axis 0.
        t: Targets aligned with ``y`` along axis 0.
        batch_size: Number of rows to draw. Defaults to 3, the value that
            was previously hard-coded.

    Returns:
        Tuple ``(y_batch, t_batch)``.
    """
    y_size = y.shape[0]
    batch_mask = np.random.choice(y_size, batch_size)  # random row indices
    return y[batch_mask], t[batch_mask]
y = np.array([0.1, 0.2, 0.6, 0.1])  # simulated neural-network output
t = np.array([0, 0, 1, 0])  # one-hot target: the correct class is index 2
# The same target as a class-index label. It must be the index of the
# correct class (2), not a list of every class index: [0, 1, 2, 3] would
# make mini_batch_cee_sec sum -log over all classes instead of the correct one.
t_1 = np.array([2])
y_batch, t_batch = mini_batch(y, t)
print("batch select", y_batch, t_batch)
mini_batch_res = mini_batch_cee(y, t)
print("mini batch one hot", mini_batch_res)
mini_batch_res = mini_batch_cee_sec(y, t_1)
print("mini batch lable", mini_batch_res)
# mini batch, cross entropy error
# 最新推荐文章于 2021-08-27 13:19:27 发布