Dl4j-fit(DataSetIterator iterator)源码阅读(九) 利用梯度更新参数

在前面我们已经通过反向传播计算出梯度。
并且梯度也经过梯度标准化,已经L1, L2参数的调整,接下来继续返回到StochasticGradientDescent.optimize()方法中继续执行以下语句

//先获取模型的参数
INDArray params = model.params();
//然后使用StepFunction对参数进行更新
stepFunction.step(params, gradient.gradient());

stepFunction.step(params, gradient.gradient());

这里调用的实体类为:

/**
 * Subtract the line
 *
 * @author Adam Gibson
 */
public class NegativeGradientStepFunction implements StepFunction {
    @Override
    public void step(INDArray x, INDArray line, double step) {
        step(x, line);
    }

    @Override
    public void step(INDArray x, INDArray line) {
        x.subi(line);
    }

    @Override
    public void step() {

    }
}

相当于参数直接减去对应的梯度。

然后再执行

//更新模型参数
model.setParams(params);
//获取模型当前的iteration的次数统计
int iterationCount = BaseOptimizer.getIterationCount(model);
//遍历 监听器
for (IterationListener listener : iterationListeners)
    listener.iterationDone(model, iterationCount);

//判断是否达到某种终止条件,
//如果满足则进行log日志的打印
checkTerminalConditions(pair.getFirst().gradient(), oldScore, score, i);
//增加迭代次数
BaseOptimizer.incrementIterationCount(model, 1);

StepFunction

/**
 * Custom step function for line search
 *
 * @author Adam Gibson
 */
public interface StepFunction extends Serializable {

    /**
     * Step with the given parameters
     * @param x the current parameters
     * @param line the line to step
     * @param step
     */
    void step(INDArray x, INDArray line, double step);


    /**
     * Step with no parameters
     */
    void step(INDArray x, INDArray line);


    void step();

}

checkTerminalConditions

package org.deeplearning4j.optimize.solvers;
/**
 * Base optimizer
 * @author Adam Gibson
 */

@Override
public boolean checkTerminalConditions(INDArray gradient, double oldScore, double score, int i) {
    for (TerminationCondition condition : terminationConditions) {
        if (condition.terminate(score, oldScore, new Object[] {gradient})) {
            log.debug("Hit termination condition on iteration {}: score={}, oldScore={}, condition={}", i, score,
                            oldScore, condition);
            if (condition instanceof EpsTermination && conf.getLayer() != null
                            && conf.getLearningRatePolicy() == LearningRatePolicy.Score) {
                model.applyLearningRateScoreDecay();
            }
            return true;
        }
    }
    return false;
}
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值