In the previous sections we computed the gradient via backpropagation, normalized it, and applied the L1/L2 regularization adjustments. Execution now returns to the StochasticGradientDescent.optimize()
method, which continues with the following statements:
// First, fetch the model's parameters
INDArray params = model.params();
// Then update the parameters with the StepFunction
stepFunction.step(params, gradient.gradient());
The concrete class invoked here is:
/**
 * Subtract the line
 *
 * @author Adam Gibson
 */
public class NegativeGradientStepFunction implements StepFunction {
    @Override
    public void step(INDArray x, INDArray line, double step) {
        step(x, line);
    }

    @Override
    public void step(INDArray x, INDArray line) {
        x.subi(line);
    }

    @Override
    public void step() {
        // no-op
    }
}
In effect, the parameters are simply decremented in place by the corresponding gradient.
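That in-place update can be sketched with plain Java arrays (a hypothetical stand-in for INDArray, so the example runs without ND4J on the classpath):

```java
import java.util.Arrays;

// A minimal sketch of what NegativeGradientStepFunction.step(x, line) does.
// The class and method names here are illustrative, not part of DL4J.
public class NegativeStepSketch {

    // Mirrors x.subi(line): in-place element-wise subtraction.
    static void step(double[] params, double[] gradient) {
        for (int i = 0; i < params.length; i++) {
            params[i] -= gradient[i];
        }
    }

    public static void main(String[] args) {
        double[] params = {1.0, 2.0, 3.0};
        double[] gradient = {0.1, 0.2, 0.3};
        step(params, gradient);
        System.out.println(Arrays.toString(params)); // [0.9, 1.8, 2.7]
    }
}
```

Note that no learning rate appears here: by this point the gradient has already been scaled by the updater, so the step function only needs the raw subtraction.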
Execution then continues with:
// Update the model parameters
model.setParams(params);
// Get the model's current iteration count
int iterationCount = BaseOptimizer.getIterationCount(model);
// Notify each registered listener
for (IterationListener listener : iterationListeners)
    listener.iterationDone(model, iterationCount);
// Check whether any termination condition has been met;
// if so, a debug log entry is written
checkTerminalConditions(pair.getFirst().gradient(), oldScore, score, i);
// Increment the iteration count
BaseOptimizer.incrementIterationCount(model, 1);
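The bookkeeping in that tail (notify every listener after each update, then bump the iteration counter) can be sketched as follows; the Listener interface and run method are illustrative stand-ins, not DL4J's actual API:

```java
import java.util.ArrayList;
import java.util.List;

public class IterationLoopSketch {

    // Stand-in for DL4J's IterationListener.
    interface Listener {
        void iterationDone(int iteration);
    }

    // Runs n update steps, notifying listeners after each one and then
    // incrementing the counter, as in the optimize() tail above.
    static int run(int n, List<Listener> listeners) {
        int iterationCount = 0;
        for (int i = 0; i < n; i++) {
            // ... the parameter update (stepFunction.step / setParams) goes here ...
            for (Listener l : listeners) {
                l.iterationDone(iterationCount);
            }
            iterationCount++; // stands in for BaseOptimizer.incrementIterationCount(model, 1)
        }
        return iterationCount;
    }

    public static void main(String[] args) {
        List<Listener> listeners = new ArrayList<>();
        listeners.add(it -> System.out.println("iteration " + it + " done"));
        System.out.println("total iterations: " + run(3, listeners));
    }
}
```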
StepFunction
/**
 * Custom step function for line search
 *
 * @author Adam Gibson
 */
public interface StepFunction extends Serializable {

    /**
     * Step with the given parameters
     * @param x the current parameters
     * @param line the line to step
     * @param step the step size
     */
    void step(INDArray x, INDArray line, double step);

    /**
     * Step without an explicit step size
     */
    void step(INDArray x, INDArray line);

    void step();
}
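To illustrate the interface's contract, here is a sketch of a step function that scales the gradient by a step size before subtracting, unlike NegativeGradientStepFunction, which subtracts it raw. It uses plain arrays instead of INDArray so it stays self-contained, and the class name is hypothetical:

```java
import java.util.Arrays;

// Hypothetical step function: x -= step * line (not part of DL4J).
public class ScaledStepSketch {

    // Analogue of step(INDArray x, INDArray line, double step).
    static void step(double[] x, double[] line, double step) {
        for (int i = 0; i < x.length; i++) {
            x[i] -= step * line[i];
        }
    }

    public static void main(String[] args) {
        double[] x = {1.0, 1.0};
        step(x, new double[]{2.0, 4.0}, 0.5);
        System.out.println(Arrays.toString(x)); // [0.0, -1.0]
    }
}
```

The three-argument form is the one line-search optimizers rely on, since they repeatedly probe different step sizes along the same direction.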
checkTerminalConditions
package org.deeplearning4j.optimize.solvers;

/**
 * Base optimizer
 * @author Adam Gibson
 */
@Override
public boolean checkTerminalConditions(INDArray gradient, double oldScore, double score, int i) {
    for (TerminationCondition condition : terminationConditions) {
        if (condition.terminate(score, oldScore, new Object[] {gradient})) {
            log.debug("Hit termination condition on iteration {}: score={}, oldScore={}, condition={}",
                    i, score, oldScore, condition);
            // Under a Score learning-rate policy, an eps-style condition
            // also triggers a learning-rate decay on the model
            if (condition instanceof EpsTermination && conf.getLayer() != null
                    && conf.getLearningRatePolicy() == LearningRatePolicy.Score) {
                model.applyLearningRateScoreDecay();
            }
            return true;
        }
    }
    return false;
}
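The kind of eps-style check that an EpsTermination condition performs can be sketched like this: stop once the score improvement between iterations falls below a tolerance. This is a hypothetical helper for illustration; DL4J's actual TerminationCondition.terminate also receives an Object[] of extra parameters, as seen in the call above:

```java
// Hypothetical eps-style termination check (not DL4J's EpsTermination itself).
public class EpsCheckSketch {

    // Terminate when the score has effectively stopped changing.
    static boolean terminate(double score, double oldScore, double eps) {
        return Math.abs(oldScore - score) < eps;
    }

    public static void main(String[] args) {
        System.out.println(terminate(0.5000001, 0.5, 1e-4)); // true: change below eps
        System.out.println(terminate(0.4, 0.5, 1e-4));       // false: still improving
    }
}
```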