The preOutput method is the centerpiece of the network model's forward pass.
public INDArray preOutput(boolean training) {
    applyDropOutIfNecessary(training);
    INDArray b = getParam(DefaultParamInitializer.BIAS_KEY);
    INDArray W = getParam(DefaultParamInitializer.WEIGHT_KEY);

    //Input validation:
    if (input.rank() != 2 || input.columns() != W.rows()) {
        if (input.rank() != 2) {
            throw new DL4JInvalidInputException("Input that is not a matrix; expected matrix (rank 2), got rank "
                            + input.rank() + " array with shape " + Arrays.toString(input.shape()));
        }
        throw new DL4JInvalidInputException("Input size (" + input.columns() + " columns; shape = "
                        + Arrays.toString(input.shape())
                        + ") is invalid: does not match layer input size (layer # inputs = " + W.size(0) + ")");
    }

    if (conf.isUseDropConnect() && training && conf.getLayer().getDropOut() > 0) {
        W = Dropout.applyDropConnect(this, DefaultParamInitializer.WEIGHT_KEY);
    }

    INDArray ret = input.mmul(W).addiRowVector(b);

    if (maskArray != null) {
        applyMask(ret);
    }

    return ret;
}
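The core arithmetic of this method is simply ret = input.mmul(W).addiRowVector(b). As a minimal plain-Java sketch of that computation (using primitive arrays instead of ND4J's INDArray; the class and method here are hypothetical, for illustration only):

```java
public class DenseForwardSketch {
    // Sketch of the dense-layer pre-output: ret = input * W + b (b broadcast over rows).
    // input is [batch x nIn], W is [nIn x nOut], b has length nOut.
    static double[][] preOutput(double[][] input, double[][] W, double[] b) {
        int batch = input.length, nIn = W.length, nOut = W[0].length;
        double[][] ret = new double[batch][nOut];
        for (int r = 0; r < batch; r++) {
            for (int c = 0; c < nOut; c++) {
                double sum = b[c];                 // addiRowVector(b): bias added to every row
                for (int k = 0; k < nIn; k++) {
                    sum += input[r][k] * W[k][c];  // input.mmul(W)
                }
                ret[r][c] = sum;
            }
        }
        return ret;
    }

    public static void main(String[] args) {
        double[][] z = preOutput(new double[][]{{1, 2}},
                                 new double[][]{{1, 0}, {0, 1}},
                                 new double[]{0.5, -0.5});
        System.out.println(z[0][0] + " " + z[0][1]); // prints "1.5 1.5"
    }
}
```

This also makes the input-validation check above concrete: the matrix multiply only works when input has rank 2 and input.columns() equals W.rows().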
The method first calls applyDropOutIfNecessary(training) to decide whether dropout should be applied to the input.
protected void applyDropOutIfNecessary(boolean training) {
    if (conf.getLayer().getDropOut() > 0 && !conf.isUseDropConnect() && training && !dropoutApplied) {
        input = input.dup();
        Dropout.applyDropout(input, conf.getLayer().getDropOut());
        dropoutApplied = true;
    }
}
Dropout is applied only when all of the following conditions hold:
- the current layer is configured with dropout > 0;
- drop-connect is not enabled in the configuration (drop-connect is a related regularizer that drops weights rather than activations);
- we are in the training phase, i.e. training is true; dropout is never applied at prediction time;
- dropout has not already been applied to this input.
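The four conditions above can be restated as a single boolean predicate. This standalone class is hypothetical (the parameter names mirror the DL4J configuration fields, but it is illustrative only):

```java
public class DropoutGateSketch {
    // Restatement of the gate in applyDropOutIfNecessary: dropout runs only when
    // it is configured (> 0), drop-connect is off, we are training, and it has
    // not already been applied to this input.
    static boolean shouldApplyDropout(double dropOut, boolean useDropConnect,
                                      boolean training, boolean dropoutApplied) {
        return dropOut > 0 && !useDropConnect && training && !dropoutApplied;
    }

    public static void main(String[] args) {
        System.out.println(shouldApplyDropout(0.5, false, true, false));  // true
        System.out.println(shouldApplyDropout(0.5, false, false, false)); // false: not training
    }
}
```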
If all of these conditions are met, the current input is first copied with dup() (note: dup is short for "duplicate"), and the copy is then passed to the next function.
/**
 * Apply dropout to the given input
 * and return the drop out mask used
 *
 * @param input the input to do drop out on
 * @param dropout the drop out probability
 */
public static void applyDropout(INDArray input, double dropout) {
    if (Nd4j.getRandom().getStatePointer() != null) {
        Nd4j.getExecutioner().exec(new DropOutInverted(input, dropout));
    } else {
        Nd4j.getExecutioner().exec(new LegacyDropOutInverted(input, dropout));
    }
}
There are many ways to implement dropout; from reading this source, DL4J implements it by masking the current layer's input in place, using the "inverted" variant that also rescales the surviving activations.
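Inverted dropout can be sketched in plain Java as follows. This assumes p is the retain probability (the actual DropOutInverted op runs on the native backend; this class and method are hypothetical, for illustration only):

```java
import java.util.Random;

public class InvertedDropoutSketch {
    // Inverted dropout: keep each element with probability p and scale the
    // survivors by 1/p, so the expected value of each activation is unchanged
    // and no rescaling is needed at inference time.
    static double[] applyInvertedDropout(double[] input, double p, Random rng) {
        double[] out = new double[input.length];
        for (int i = 0; i < input.length; i++) {
            out[i] = (rng.nextDouble() < p) ? input[i] / p : 0.0;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] x = {1.0, 2.0, 3.0, 4.0};
        double[] y = applyInvertedDropout(x, 0.5, new Random(42));
        for (double v : y) System.out.print(v + " "); // each element is 0.0 or 2x the original
        System.out.println();
    }
}
```

Because the scaling happens at training time, the forward pass at prediction time needs no special handling, which matches the fact that applyDropOutIfNecessary only fires when training is true.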
/**
 * This method returns pointer to RNG state structure.
 *
 * Please note: DefaultRandom implementation returns NULL here, making it impossible to use with RandomOps
 *
 * @return
 */
@Override
public Pointer getStatePointer() {
    return statePointer;
}
The purpose of getStatePointer() is not entirely clear from the comment alone. Next, let's look at the two implementations:
- DropOutInverted
/**
* Inverted DropOut implementation as Op
*
* @author raver119@gmail.com
*/
public class DropOutInverted extends BaseRandomOp {

    private double p;

    public DropOutInverted() {