# coding: utf-8
import numpy as np
# ---------------------------激活函数定义---------------------------------------------------
def softmax(a):
"""解决softmax函数的溢出问题,利用c(输入的最大值),softmax函数减去这个最大值保证数据不溢出,softmax函数运算时加上或者
减去某个常数并不会改变运算的结果"""
c = np.max(a)
exp_a = np.exp(a - c) # 溢出对策
sum_exp_a = np.sum(exp_a)
y = exp_a / sum_exp_a
return y
# ---------------------------定义损失函数---------------------------------------------------
def cross_entropy_error(y, t):
delta = 1e-07 # 设置一个微小值,避免对数的参数为0导致无穷大
return - np.sum(t * np.log(y + delta)) # 注意这个log对应的是ln
# -----------------------------------------------------------------------
class SoftmaxWitgLoss:
def __init__(self):
self.loss = None # 损失
self.y = None # softmax的输出
self.t = None # 监督数据(one-hot )
def forward(self, x, t):
self.t = t
self.y = softmax(x)
self.loss = cross_entropy_error(self.y, self.t)
return self.loss
def backward(self, dout=1):
batch_size = self.t.shape[0]
dx = (self.y - self.t) / batch_size # 得到单个数据的误差
return dx