【lightGBM】在java中使用lightGBM进行预测,实例使用lightGBM对股票进行预测

在java中使用lightGBM进行预测,实例使用lightGBM对股票进行预测

整体流程

1.读取数据,读取csv的数据并转化为JSONArray
2.对数据进行训练
3.对测试组数据进行预测,并分析预测准确率

读取数据(CVS转JSONArry)

  1. csv的数据例子:
    lable:是否上涨(1代表上涨,0代表下跌)
    overMa5:是否大于5天线
    overMa7:是否大于7天线
    overMin5:是否大于近5天最低价位
//读取csv转为JSONArry
private static JSONArray getJsonArray(String file) {
        JSONArray jsonArray = new JSONArray();
        try {
            InputStream is = new FileInputStream(PlayLgbAI.class.getClassLoader().getResource("").getPath() + file);
            List<String> list = IOUtils.readLines(is, "utf-8");
            String header = list.get(0);
            String[] headerColumn = header.split(",");

            for (int i = 1; i < list.size(); i++) {
                String line = list.get(i);
                String[] lineData = line.split(",");
                JSONObject jsonObject = new JSONObject();
                for (int j = 0; j < lineData.length; j++) {
                    jsonObject.put(headerColumn[j], lineData[j]);
                }
                jsonArray.add(jsonObject);
            }
        } catch (FileNotFoundException e) {
            throw new RuntimeException(e);
        }
        return jsonArray;
    }

读取数据(JSONArray转为模型的对像)

2.定义一个清洗后的结构

  public static class CleanData {
  		//数据
        private JSONArray jsonArray;
        //分析的列
        private Set<String> columns = Sets.newLinkedHashSet();

		//值
        List<Double> valuesList = Lists.newArrayList();
        //训练的正向结果,即csv中lable对应的值
        List<Double> lables = Lists.newArrayList();

		//预计的结果
        List<Double> preds = Lists.newArrayList();
		//getter and setter
        }
  1. 把读取的jsonArry转为CleanData,并设置预测用的因子
 public static CleanData cleanData(JSONArray jsonArray) {
        CleanData cleanData = new CleanData();
        Set<String> columns = cleanData.getColumns();
        columns.add("overMa5");
        columns.add("overMa7");
        columns.add("overMa14");
        columns.add("overMa20");
        columns.add("overMin5");
        columns.add("overMin7");
        columns.add("overMin14");
        columns.add("overMin20");
        columns.add("overMin60");
        columns.add("ret10");
        columns.add("ret20");
        columns.add("ret60");
        columns.add("volume");
        for (int i = 0; i < jsonArray.size(); i++) {
            JSONObject jsonObject = jsonArray.getJSONObject(i);
            for (String column : columns) {
                Double aDouble = jsonObject.getDouble(column);
                cleanData.getValuesList().add(aDouble != null ? aDouble : 0);
            }
            cleanData.getLables().add(jsonObject.getDouble("lable"));
        }
        return cleanData;
    }

进行训练并返回模型

  1. 进行训练并返回模型

LGBMDataset.createFromMat(values, labels.length, cleanData.getColumns().size(), true, “”, null) 方法中values代表所有值,labels.length=行数, cleanData.getColumns().size()=列
例如:
values=ma5,ma7,min7,min14,ma5,ma7,min7,min14
行数应该为2
列数应该为4

 public static LGBMBooster train(JSONArray jsonArray) throws LGBMException {
 		//jsonArray转为CleanData
        CleanData cleanData = cleanData(jsonArray);
        double[] values = cleanData.getValuesList().stream().mapToDouble(Double::doubleValue).toArray();

        double[] labels = cleanData.getLables().stream().mapToDouble(Double::doubleValue).toArray();
        float[] lablelsFloat = new float[labels.length];
        for (int i = 0; i < labels.length; i++) {
            lablelsFloat[i] = new Float(labels[i]);
        }
        LGBMDataset dataset = LGBMDataset.createFromMat(values, labels.length, cleanData.getColumns().size(), true, "", null);
        dataset.setFeatureNames(cleanData.getColumns().toArray(new String[0]));
        dataset.setField("label", lablelsFloat);

        String parameters = "objective=regression label=name:Classification metric=mae";

        LGBMBooster booster = LGBMBooster.create(dataset, parameters);
        for (int i = 0; i < 100; i++) {
            System.out.println("第" + i + "次");
            booster.updateOneIter();

            String[] names = booster.getFeatureNames();
            double[] weights = booster.featureImportance(0, LGBMBooster.FeatureImportanceType.GAIN);
            assertTrue(names.length > 0);
            assertTrue(weights.length > 0);
            System.out.println(JSONObject.toJSONString(names));
            System.out.println(JSONObject.toJSONString(weights));
        }
        return booster;
    }

输出:第99次训练以后,输出得到因子的重要性排序,由上面例子可得知
大于最近5天最低价是最大因子
量能是第二大的因子
ret20:20天上涨比例是第三大的因子
第99次
[“overMa5”,“overMa7”,“overMa14”,“overMa20”,“overMin5”,“overMin7”,“overMin14”,“overMin20”,“overMin60”,“ret10”,“ret20”,“ret60”,“volume”]
[79.4154430180788,38.42058273404837,69.61820960044861,6.021978914737701,1345.2684633247554,2.600044012069702,55.968476973474026,0.5242173969745636,13.071800041943789,274.15874949097633,232.41601605527103,198.7524663899094,265.0022159293294]

  1. 使用模型对测试组数据进行预测并显示成功率
public static void run(JSONArray jsonArray) throws LGBMException {

        LGBMBooster lgbmBooster = train(jsonArray);

        JSONArray testJsonArray = getJsonArray("/lightgbm/st_etf_ai_test.csv");
        CleanData cleanData = cleanData(testJsonArray);
        int n = 0;
        int index = (cleanData.getValuesList().size() - n * cleanData.getColumns().size());
        int row = cleanData.getValuesList().size() / cleanData.getColumns().size();
        List<Double> inputList = cleanData.getValuesList();
        float[] input = new float[inputList.size()];
        for (int i = 0; i < inputList.size(); i++) {
            input[i] = new Float(inputList.get(i));
        }

        double[] pred = lgbmBooster.predictForMat(input, row, cleanData.getColumns().size(), true, PredictionType.C_API_PREDICT_RAW_SCORE);

        for (int i = 0; i < pred.length; i++) {
            cleanData.getPreds().add(pred[i]);
        }
        int sum = 0;
        int successCount = 0;
        for (int i = 0; i < pred.length; i++) {
    		//匹配率达到7成以上的认为上涨
            if (cleanData.getPreds().get(i) > 0.7) {
                sum = sum + 1;
                //预测值与结果值都为上涨时,成功次数加1
                if (cleanData.getLables().get(i) > 0) {
                    successCount = successCount + 1;
                }
            }
        }
        System.out.println("总量" + sum + ",成功:" + successCount);
        System.out.println("准确率:" + successCount / Double.valueOf(sum) * 100 + "%");
    }

总量44,成功:32
准确率:72.72727272727273%

  • 19
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值