关于导数链式法则的理解，我参考了 Andrej Karpathy 的文章《Hacker's guide to Neural Networks》。
// Gradient ascent on f(x, y, z) = (x + y) * z: ten steps that move x and y
// along their derivatives while z is held fixed.
// h plays two roles: derivative() uses it as the finite-difference step, and
// this loop uses it as the step size. Because f is linear in x and in y, the
// numerical slopes come out the same for any h (up to rounding); only the
// size of each update changes when h changes.
var x = -2, y = 5, z = -4, h = 0.1;
var i, out0, der;
var step_size = h;
for (i = 1; i <= 10; i++) {
  out0 = forwardCircuit(x, y, z);
  console.log("第" + i + "次输出为:" + out0);
  der = derivative(x, y, z);
  x += step_size * der.x_derivative;
  y += step_size * der.y_derivative;
}
/**
 * Gradient of f(x, y, z) = (x + y) * z with respect to x and y.
 * It is built from the numerical slopes of the add and multiply gates,
 * combined with the chain rule.
 * Uses the global finite-difference step `h`, and logs each gate's slopes
 * plus the two resulting derivatives.
 * @param {number} x
 * @param {number} y
 * @param {number} z
 * @returns {{x_derivative: number, y_derivative: number}}
 */
function derivative(x,y,z){
  var q = forwardAddGate(x, y); // q = x + y feeds the multiply gate
  // Evaluate each gate's slopes once and reuse them (previously recomputed).
  var mulSlopes = MultiplyGateDerivative(q, z, h); // f = q * z
  var addSlopes = AddGateDerivative(x, y, h);      // q = x + y
  var derivative_f_wrt_q = mulSlopes.x_derivative; // df/dq ≈ z
  var derivative_q_wrt_x = addSlopes.x_derivative; // dq/dx ≈ 1
  var derivative_q_wrt_y = addSlopes.y_derivative; // dq/dy ≈ 1
  // Chain rule: df/dx = dq/dx * df/dq and df/dy = dq/dy * df/dq.
  var derivative_f_wrt_x = derivative_q_wrt_x * derivative_f_wrt_q;
  var derivative_f_wrt_y = derivative_q_wrt_y * derivative_f_wrt_q;
  // Each label now matches the value it prints (the y-slope line used to
  // print the x-slope, and the last two lines were mislabeled/duplicated).
  console.log("乘法门的对x的斜率为" + mulSlopes.x_derivative);
  console.log("乘法门的对y的斜率为" + mulSlopes.y_derivative);
  console.log("加法门的对x的斜率为" + derivative_q_wrt_x);
  console.log("加法门的对y的斜率为" + derivative_q_wrt_y);
  console.log("f对x的导数为" + derivative_f_wrt_x);
  console.log("f对y的导数为" + derivative_f_wrt_y);
  return {x_derivative: derivative_f_wrt_x, y_derivative: derivative_f_wrt_y};
}
/**
 * Numerical slopes of the multiply gate x * y, using forward differences
 * with step h. (Analytically d(xy)/dx = y and d(xy)/dy = x.)
 * @returns {{x_derivative: number, y_derivative: number}}
 */
function MultiplyGateDerivative(x,y,h){
  var base = forwardMultiplyGate(x, y);
  var nudgedX = forwardMultiplyGate(x + h, y);
  var nudgedY = forwardMultiplyGate(x, y + h);
  return {
    x_derivative: (nudgedX - base) / h,
    y_derivative: (nudgedY - base) / h
  };
}
/**
 * Numerical slopes of the add gate x + y, using forward differences
 * with step h. (Analytically both slopes are 1.)
 * @returns {{x_derivative: number, y_derivative: number}}
 */
function AddGateDerivative(x,y,h){
  var base = forwardAddGate(x, y);
  var nudgedX = forwardAddGate(x + h, y);
  var nudgedY = forwardAddGate(x, y + h);
  return {
    x_derivative: (nudgedX - base) / h,
    y_derivative: (nudgedY - base) / h
  };
}
// Forward pass of the multiply gate: returns x * y.
function forwardMultiplyGate(x, y) {
  var product = x * y;
  return product;
}
// Forward pass of the add gate: returns a + b.
function forwardAddGate(a, b) {
  var sum = a + b;
  return sum;
}
/**
 * Forward pass of the whole circuit f(x, y, z) = (x + y) * z:
 * the add gate's output feeds the multiply gate.
 */
function forwardCircuit(x,y,z) {
  return forwardMultiplyGate(forwardAddGate(x, y), z);
}
/**
 * Numerical gradient check for forwardCircuit at the current global point
 * (x, y, z), using forward differences with the global step h.
 * Compare x_derivative / y_derivative with derivative(x, y, z) to validate
 * the chain-rule result. Analytically df/dx = df/dy = z and df/dz = x + y.
 * Previously the three values were computed and discarded; they are now
 * returned.
 * @returns {{x_derivative: number, y_derivative: number, z_derivative: number}}
 */
function GradientCheck(){// 梯度检查
  var base = forwardCircuit(x, y, z); // evaluate the unperturbed circuit once
  var x_derivative = (forwardCircuit(x + h, y, z) - base) / h; // ≈ z
  var y_derivative = (forwardCircuit(x, y + h, z) - base) / h; // ≈ z
  var z_derivative = (forwardCircuit(x, y, z + h) - base) / h; // ≈ x + y
  return {x_derivative: x_derivative, y_derivative: y_derivative, z_derivative: z_derivative};
}
我对以上代码做了如下优化：用解析导数代替数值差分（乘法门对一个输入的导数就是另一个输入，加法门的导数恒为 1），并缩短了函数名。
// Gradient ascent on f(x, y, z) = (x + y) * z using the analytic derivative():
// ten steps that move x and y along df/dx and df/dy while z is held fixed.
// Here h is the step size of each update; GradientCheck() also uses it as
// its finite-difference step.
var x = -7, y = 6, z = -4, h = 0.1;
var i, out0, der;
var step_size = h;
for (i = 1; i <= 10; i++) {
  out0 = forward(x, y, z);
  console.log("第" + i + "次输出为:" + out0);
  der = derivative(x, y, z);
  x += step_size * der.x_derivative;
  y += step_size * der.y_derivative;
}
/**
 * Analytic gradient of f(x, y, z) = (x + y) * z via the chain rule.
 * With q = x + y, the multiply gate gives df/dq = z and df/dz = q, and the
 * add gate gives dq/dx = dq/dy = 1.
 * @param {number} x
 * @param {number} y
 * @param {number} z
 * @returns {{x_derivative: number, y_derivative: number, z_derivative: number}}
 *   z_derivative is additional; callers that only read x/y are unaffected.
 */
function derivative(x,y,z){
  var q = forwardAdd(x, y);
  var mulSlopes = DerMul(q, z);   // computed once (previously twice)
  var addSlopes = DerAdd(x, y);
  var der_q = mulSlopes.x_derivative; // df/dq = z
  var der_z = mulSlopes.y_derivative; // df/dz = q = x + y
  // Chain rule: df/dx = dq/dx * df/dq, df/dy = dq/dy * df/dq.
  var der_x = addSlopes.x_derivative * der_q;
  var der_y = addSlopes.y_derivative * der_q;
  return {x_derivative: der_x, y_derivative: der_y, z_derivative: der_z};
}
/**
 * Analytic slopes of the multiply gate x * y: d(xy)/dx = y, d(xy)/dy = x.
 * @returns {{x_derivative: number, y_derivative: number}}
 */
function DerMul(x,y){
  var slopeX = y;
  var slopeY = x;
  return {x_derivative: slopeX, y_derivative: slopeY};
}
/**
 * Analytic slopes of the add gate x + y: both partial derivatives are 1,
 * whatever the inputs. A new object is returned on every call.
 */
function DerAdd(x,y){
  var one = 1;
  return {x_derivative: one, y_derivative: one};
}
// Forward pass of the multiply gate: returns x * y.
function forwardMul(x, y) {
  var product = x * y;
  return product;
}
// Forward pass of the add gate: returns a + b.
function forwardAdd(a, b) {
  var sum = a + b;
  return sum;
}
/**
 * Forward pass of the circuit f(x, y, z) = (x + y) * z:
 * the add gate's output feeds the multiply gate.
 */
function forward(x,y,z) {
  return forwardMul(forwardAdd(x, y), z);
}
/**
 * Numerical gradient check for forward() at the current global point
 * (x, y, z), using forward differences with the global step h.
 * Compare with derivative(x, y, z) to validate the analytic chain-rule
 * result. Analytically df/dx = df/dy = z and df/dz = x + y.
 * Previously the three values were computed and discarded; they are now
 * returned.
 * @returns {{x_derivative: number, y_derivative: number, z_derivative: number}}
 */
function GradientCheck(){// 梯度检查
  var base = forward(x, y, z); // evaluate the unperturbed circuit once
  var x_derivative = (forward(x + h, y, z) - base) / h; // ≈ z
  var y_derivative = (forward(x, y + h, z) - base) / h; // ≈ z
  var z_derivative = (forward(x, y, z + h) - base) / h; // ≈ x + y
  return {x_derivative: x_derivative, y_derivative: y_derivative, z_derivative: z_derivative};
}