深度学习-用户自定义层

这篇我们看下如何自定义层,如何做梯度测试,老吴还是很推崇梯度测试的,由于已经写过很多篇了,重复的内容不再解释,可以参考前面的文章

public class CustomLayerExample {

    static{// Static initializer: runs once when the class is first loaded, before main()
        // Use double precision for the gradient checks. See comments in the doGradientCheck() method.
        // See also http://nd4j.org/userguide.html#miscdatatype
        DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE);// set ND4J's global data type to double precision (gradient checks are unreliable in float32)

    }

    public static void main(String[] args) throws IOException {
        runInitialTests();// run the two demos in sequence: sanity tests first, then the gradient check
        doGradientCheck();
    }

    private static void runInitialTests() throws IOException {// basic sanity tests for the custom layer
        /*
        This method shows the configuration and use of the custom layer.
        It also shows some basic sanity checks and tests for the layer.
        In practice, these tests should be implemented as unit tests; for simplicity, we are just printing the results
         */

        System.out.println("----- Starting Initial Tests -----");

        int nIn = 5;
        int nOut = 8;

        //Let's create a network with our custom layer

        MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
            .updater(Updater.RMSPROP).rmsDecay(0.95)// RMSProp decay rate: smooths the squared-gradient average so parameter updates do not swing too wildly
            .weightInit(WeightInit.XAVIER)
            .regularization(true).l2(0.03)
            .list()
            .layer(0, new DenseLayer.Builder().activation("tanh").nIn(nIn).nOut(6).build())     //Standard DenseLayer
            .layer(1, new CustomLayer.Builder()// the user-defined custom layer (its implementation is discussed later in the article)
                .activation("tanh")                                                             //Property inherited from FeedForwardLayer
                .secondActivationFunction("sigmoid")                                            //Custom property we defined for our layer
                .nIn(6).nOut(7)                                                                 //nIn and nOut also inherited from FeedForwardLayer
                .build())
            .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)// multi-class cross-entropy loss               //Standard OutputLayer
                .activation("softmax").nIn(7).nOut(nOut).build())
            .pretrain(false).backprop(true).build();


        //First: run some basic sanity checks on the configuration
        double customLayerL2 = config.getConf(1).getLayer().getL2();// getConf(1) returns the configuration of layer index 1 (the custom layer); getL2() its L2 regularization coefficient

        System.out.println("l2 coefficient for custom layer: " + customLayerL2);                //As expected: custom layer inherits the global L2 parameter configuration
        Updater customLayerUpdater = config.getConf(1).getLayer().getUpdater();// the updater configured for the custom layer (inherited from the global builder settings)


        System.out.println("Updater for custom layer: " + customLayerUpdater);                  //As expected: custom layer inherits the global Updater configuration

        //Second: We need to ensure that the JSON and YAML configuration (de)serialization works with the custom layer
        // If there were problems with serialization, you'd get an exception during deserialization ("No suitable constructor found..." for example)
        String configAsJson = config.toJson();// serialize the configuration to JSON and YAML strings
        String configAsYaml = config.toYaml();
        MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(configAsJson);// deserialize them back; round-trip equality below proves serialization works
        MultiLayerConfiguration fromYaml = MultiLayerConfiguration.fromYaml(configAsYaml);

        System.out.println("JSON configuration works: " + config.equals(fromJson));
        System.out.println("YAML configuration works: " + config.equals(fromYaml));

        MultiLayerNetwork net = new MultiLayerNetwork(config);
        net.init();


        //Third: Let's run some more basic tests. First, check that the forward and backward pass methods don't throw any exceptions
        // To do this: we'll create some simple test data
        int minibatchSize = 5;
        INDArray testFeatures = Nd4j.rand(minibatchSize, nIn);// random feature matrix, values uniform in [0,1)
        // NOTE(review): the listing is truncated here in the article; the statement below continues beyond this excerpt
        INDArray testLabels = Nd4j.zeros(minibatchSize, nOut)
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值