我们来看一下caffe中具体是怎样实现的,代码位于include/caffe/filler.hpp文件中。
template <typename Dtype>
class XavierFiller : public Filler<Dtype> {
public:
explicit XavierFiller(const FillerParameter& param)
: Filler<Dtype>(param) {}
virtual void Fill(Blob<Dtype>* blob) {
CHECK(blob->count());
int fan_in = blob->count() / blob->num();
int fan_out = blob->count() / blob->channels();
Dtype n = fan_in; // default to fan_in
if (this->filler_param_.variance_norm() ==
FillerParameter_VarianceNorm_AVERAGE) {
n = (fan_in + fan_out) / Dtype(2);
} else if (this->filler_param_.variance_norm() ==
FillerParameter_VarianceNorm_FAN_OUT) {
n = fan_out;
}
Dtype scale = sqrt(Dtype(3) / n);
caffe_rng_uniform<Dtype>(blob->count(), -scale, scale,
blob->mutable_cpu_data());
CHECK_EQ(this->filler_param_.sparse(), -1)
<< "Sparsity not supported by this Filler.";
}
};
由上面可以看出,caffe的Xavier实现有三种选择:
(1) 默认情况,方差只考虑输入个数:
(2) FillerParameter_VarianceNorm_FAN_OUT,方差只考虑输出个数:
(3) FillerParameter_VarianceNorm_AVERAGE,方差同时考虑输入和输出个数: