题主没有对符号作必要的说明,先按我的理解对符号进行定义
z是softmaxwithloss层的输入,f(z)是softmax的输出,即
y是输入样本z对应的类别,y=0,1,...,N
对于z,其损失函数定义为
展开上式:
对上式求导,有
梯度下降方向即为
====================================
增加关于softmax的反向传播说明
设softmax的输出为a,输入为z,损失函数为l
则
其中
在caffe中是top_diff,a为caffe中得top_data,需要计算的是
if i!=k
if i==k
于是
整理一下得到
其中
表示将标量扩展为n维向量,
表示向量按元素相乘
对照caffe源码
caffe_copy(top[0]->count(), top_diff, bottom_diff);
for (int i = 0; i < outer_num_; ++i)
{
// compute dot(top_diff, top_data) and subtract them from the bottom diff
for (int k = 0; k < inner_num_; ++k)
{
scale_data[k] = caffe_cpu_strided_dot(channels,
bottom_diff + i * dim + k, inner_num_,
top_data + i * dim + k, inner_num_);
}
//此处计算点积,注意到top_diff已经拷贝到bottom_diff
// subtraction
caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, channels, inner_num_, 1,
-1., sum_multiplier_.cpu_data(), scale_data, 1., bottom_diff + i * dim);
//此处计算大括号内的减法
}
// elementwise multiplication
caffe_mul(top[0]->count(), bottom_diff, top_data, bottom_diff);
//此处计算大括号外和
的乘法