Using LightGBM for prediction in Java: an example that predicts stock moves with LightGBM
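Overall workflow
1. Read the data: load the CSV file and convert it into a JSONArray
2. Train the model on the data
3. Predict on the test data and analyze the prediction accuracy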
Reading data (CSV to JSONArray)
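- Example CSV columns:
lable: whether the price went up (1 = up, 0 = down); the column is spelled "lable" in both the data and the code
overMa5: whether the price is above the 5-day moving average
overMa7: whether the price is above the 7-day moving average
overMin5: whether the price is above the lowest price of the last 5 days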
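The article only lists what the columns mean. As an illustration of how such 0/1 flags could be derived, here is a minimal JDK-only sketch that computes overMaN and overMinN from a series of closing prices. The definitions are assumptions, not taken from the article: overMaN compares today's close with the simple moving average of the last N closes (today included), and overMinN compares today's close with the lowest close of the previous N days (today excluded). The class and method names are hypothetical.

```java
import java.util.Arrays;

public class FeatureFlags {
    // Simple moving average of the n closes ending at index t (inclusive).
    // Assumes t >= n - 1, i.e. enough history is available.
    static double sma(double[] close, int t, int n) {
        double sum = 0.0;
        for (int i = t - n + 1; i <= t; i++) {
            sum += close[i];
        }
        return sum / n;
    }

    // Assumed definition of overMaN: 1 if close[t] is above its n-day SMA, else 0.
    static int overMa(double[] close, int t, int n) {
        return close[t] > sma(close, t, n) ? 1 : 0;
    }

    // Assumed definition of overMinN: 1 if close[t] is above the lowest close of the
    // previous n days (indices t-n .. t-1, today excluded), else 0. Assumes t >= n.
    static int overMin(double[] close, int t, int n) {
        double min = Arrays.stream(close, t - n, t).min().getAsDouble();
        return close[t] > min ? 1 : 0;
    }

    public static void main(String[] args) {
        double[] close = {10.0, 10.2, 9.8, 10.1, 10.4, 10.3, 10.6};
        int t = close.length - 1; // the most recent day
        System.out.println("overMa5=" + overMa(close, t, 5));   // overMa5=1
        System.out.println("overMin5=" + overMin(close, t, 5)); // overMin5=1
    }
}
```

Whatever the real definitions are, each flag must be computed only from data available on that day, otherwise the model learns from the future.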
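//Read a CSV file from the classpath and convert it into a JSONArray
//JSONArray/JSONObject are assumed to be fastjson; IOUtils is Apache Commons IO
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.List;
import org.apache.commons.io.IOUtils;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;

private static JSONArray getJsonArray(String file) {
    JSONArray jsonArray = new JSONArray();
    //try-with-resources closes the stream; readLines can throw any IOException, not only FileNotFoundException
    try (InputStream is = new FileInputStream(PlayLgbAI.class.getClassLoader().getResource("").getPath() + file)) {
        List<String> list = IOUtils.readLines(is, StandardCharsets.UTF_8);
        //the first line is the header with the column names
        String[] headerColumn = list.get(0).split(",");
        for (int i = 1; i < list.size(); i++) {
            String line = list.get(i);
            if (line.trim().isEmpty()) {
                continue; //skip blank lines, e.g. a trailing newline
            }
            String[] lineData = line.split(",");
            JSONObject jsonObject = new JSONObject();
            //map each cell to its header name; extra cells beyond the header are ignored
            for (int j = 0; j < lineData.length && j < headerColumn.length; j++) {
                jsonObject.put(headerColumn[j], lineData[j]);
            }
            jsonArray.add(jsonObject);
        }
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    return jsonArray;
}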
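getJsonArray depends on fastjson and Commons IO. The same header-plus-rows parsing can be sketched with the JDK alone. This hypothetical CsvRows class returns one Map per row instead of a JSONObject, reads from a StringReader so it runs on its own, and, like the original, assumes plain comma-separated fields without quoting.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class CsvRows {
    // Parses simple comma-separated text (no quoted fields) into one map per data row,
    // keyed by header name. Blank lines are skipped; cells beyond the header are ignored,
    // and columns missing from a short row are simply absent from its map.
    static List<Map<String, String>> parse(BufferedReader reader) throws IOException {
        List<Map<String, String>> rows = new ArrayList<>();
        String headerLine = reader.readLine();
        if (headerLine == null) {
            return rows; // empty input
        }
        String[] header = headerLine.split(",");
        String line;
        while ((line = reader.readLine()) != null) {
            if (line.trim().isEmpty()) {
                continue;
            }
            String[] cells = line.split(",", -1); // -1 keeps trailing empty cells
            Map<String, String> row = new LinkedHashMap<>();
            for (int j = 0; j < Math.min(cells.length, header.length); j++) {
                row.put(header[j].trim(), cells[j].trim());
            }
            rows.add(row);
        }
        return rows;
    }

    public static void main(String[] args) throws IOException {
        String csv = "lable,overMa5,overMin5\n1,1,0\n0,0,1\n";
        List<Map<String, String>> rows = parse(new BufferedReader(new StringReader(csv)));
        System.out.println(rows); // [{lable=1, overMa5=1, overMin5=0}, {lable=0, overMa5=0, overMin5=1}]
    }
}
```

To read a real file, pass Files.newBufferedReader(path, StandardCharsets.UTF_8) instead of the StringReader.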
Reading data (converting the JSONArray into the model's input object)
- Define a structure for the cleaned data
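//Sets and Lists are Guava helpers (com.google.common.collect)
public static class CleanData {
    //raw data (not used by the code below)
    private JSONArray jsonArray;
    //feature columns in a fixed order (a LinkedHashSet keeps insertion order)
    private Set<String> columns = Sets.newLinkedHashSet();
    //feature values of all rows, flattened row by row (row-major)
    private List<Double> valuesList = Lists.newArrayList();
    //training targets, i.e. the values of the lable column in the CSV
    private List<Double> lables = Lists.newArrayList();
    //predicted values
    private List<Double> preds = Lists.newArrayList();

    public Set<String> getColumns() { return columns; }
    public List<Double> getValuesList() { return valuesList; }
    public List<Double> getLables() { return lables; }
    public List<Double> getPreds() { return preds; }
    //other getters and setters omitted
}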
- Convert the JSONArray into CleanData and choose the factors (features) used for prediction
public static CleanData cleanData(JSONArray jsonArray) {
    CleanData cleanData = new CleanData();
    //feature columns; this order is the column order of the matrix passed to LightGBM
    Set<String> columns = cleanData.getColumns();
    columns.add("overMa5");
    columns.add("overMa7");
    columns.add("overMa14");
    columns.add("overMa20");
    columns.add("overMin5");
    columns.add("overMin7");
    columns.add("overMin14");
    columns.add("overMin20");
    columns.add("overMin60");
    columns.add("ret10");
    columns.add("ret20");
    columns.add("ret60");
    columns.add("volume");
    for (int i = 0; i < jsonArray.size(); i++) {
        JSONObject jsonObject = jsonArray.getJSONObject(i);
        //append this row's features in column order; missing values are filled with 0
        for (String column : columns) {
            Double aDouble = jsonObject.getDouble(column);
            cleanData.getValuesList().add(aDouble != null ? aDouble : 0.0);
        }
        //the target column is spelled "lable" in the CSV
        cleanData.getLables().add(jsonObject.getDouble("lable"));
    }
    return cleanData;
}
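cleanData produces one flat list of values, row after row in column order, with missing values replaced by 0. A JDK-only sketch of that step follows; it works on rows represented as Map<String, String> rather than fastjson objects, and the class and method names are hypothetical. It also converts the lable column into the float[] form that is later handed to setField.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;

public class RowMajor {
    // Flattens records into a row-major double[]: row i occupies indices
    // [i * cols, (i + 1) * cols). Missing or blank cells become 0.0,
    // mirroring the "null -> 0" rule in cleanData.
    static double[] flatten(List<Map<String, String>> records, List<String> columns) {
        int cols = columns.size();
        double[] values = new double[records.size() * cols];
        for (int i = 0; i < records.size(); i++) {
            Map<String, String> record = records.get(i);
            for (int j = 0; j < cols; j++) {
                String cell = record.get(columns.get(j));
                values[i * cols + j] = (cell == null || cell.isEmpty()) ? 0.0 : Double.parseDouble(cell);
            }
        }
        return values;
    }

    // Labels as float[]; assumes every record has the label column
    // (Float.parseFloat(null) would throw a NullPointerException).
    static float[] labels(List<Map<String, String>> records, String labelColumn) {
        float[] out = new float[records.size()];
        for (int i = 0; i < records.size(); i++) {
            out[i] = Float.parseFloat(records.get(i).get(labelColumn));
        }
        return out;
    }

    public static void main(String[] args) {
        List<Map<String, String>> records = List.of(
                Map.of("lable", "1", "overMa5", "1", "volume", "2.5"),
                Map.of("lable", "0", "overMa5", "0"));
        List<String> columns = List.of("overMa5", "volume");
        System.out.println(Arrays.toString(flatten(records, columns))); // [1.0, 2.5, 0.0, 0.0]
        System.out.println(Arrays.toString(labels(records, "lable")));  // [1.0, 0.0]
    }
}
```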
Training and returning the model
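In LGBMDataset.createFromMat(values, labels.length, cleanData.getColumns().size(), true, "", null), values holds all feature values flattened row by row, labels.length is the number of rows and cleanData.getColumns().size() is the number of columns. true means the data is laid out row-major, "" is an empty parameter string and null means there is no reference dataset.
For example:
values = ma5,ma7,min7,min14,ma5,ma7,min7,min14 (two samples with four features each)
the number of rows is 2
the number of columns is 4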
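In a row-major buffer, the value for row r and column c sits at index r * cols + c, and the row count is values.length / cols. This small JDK-only sketch (hypothetical helper names, made-up 0/1 values) checks the 2 x 4 example and rejects a buffer whose length is not a multiple of the column count, which is the mismatch to watch for when labels.length is passed as the row count.

```java
public class MatLayout {
    // Number of rows implied by a row-major buffer; rejects lengths that do not divide evenly.
    static int rows(double[] values, int cols) {
        if (cols <= 0 || values.length % cols != 0) {
            throw new IllegalArgumentException("values.length must be a multiple of cols");
        }
        return values.length / cols;
    }

    // Value at (row, col) in a row-major buffer.
    static double at(double[] values, int cols, int row, int col) {
        return values[row * cols + col];
    }

    public static void main(String[] args) {
        // Two samples, each with the features ma5, ma7, min7, min14 (made-up values).
        double[] values = {1, 0, 1, 1,   0, 0, 1, 0};
        int cols = 4;
        System.out.println(rows(values, cols));     // 2
        System.out.println(at(values, cols, 1, 2)); // 1.0 (second sample, min7)
    }
}
```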
//LGBMDataset/LGBMBooster come from lightgbm4j (io.github.metarank.lightgbm4j)
public static LGBMBooster train(JSONArray jsonArray) throws LGBMException {
    //convert the jsonArray into CleanData
    CleanData cleanData = cleanData(jsonArray);
    int cols = cleanData.getColumns().size();
    double[] values = cleanData.getValuesList().stream().mapToDouble(Double::doubleValue).toArray();
    double[] labels = cleanData.getLables().stream().mapToDouble(Double::doubleValue).toArray();
    //the label field is set as float[]
    float[] labelsFloat = new float[labels.length];
    for (int i = 0; i < labels.length; i++) {
        labelsFloat[i] = (float) labels[i];
    }
    //rows = labels.length, cols = number of feature columns, row-major, no parameters, no reference dataset
    LGBMDataset dataset = LGBMDataset.createFromMat(values, labels.length, cols, true, "", null);
    dataset.setFeatureNames(cleanData.getColumns().toArray(new String[0]));
    dataset.setField("label", labelsFloat);
    //regression on the 0/1 label, evaluated with MAE; per the LightGBM parameter docs,
    //label (label_column) only applies when loading data from a text file, so it has no effect here
    String parameters = "objective=regression label=name:Classification metric=mae";
    LGBMBooster booster = LGBMBooster.create(dataset, parameters);
    for (int i = 0; i < 100; i++) {
        System.out.println("Iteration " + i);
        booster.updateOneIter();
        //feature importance by total gain; 0 means all iterations trained so far
        String[] names = booster.getFeatureNames();
        double[] weights = booster.featureImportance(0, LGBMBooster.FeatureImportanceType.GAIN);
        System.out.println(JSONObject.toJSONString(names));
        System.out.println(JSONObject.toJSONString(weights));
    }
    return booster;
}
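Output: the feature importance is printed after every iteration; the block below is the last one (index 99, i.e. the 100th iteration). Importance here is total gain, and ranking the values gives:
overMin5 (above the lowest price of the last 5 days) is the most important factor
ret10 is the second most important factor
volume is the third
ret20 (the 20-day rise ratio) is the fourth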
Iteration 99
["overMa5","overMa7","overMa14","overMa20","overMin5","overMin7","overMin14","overMin20","overMin60","ret10","ret20","ret60","volume"]
[79.4154430180788,38.42058273404837,69.61820960044861,6.021978914737701,1345.2684633247554,2.600044012069702,55.968476973474026,0.5242173969745636,13.071800041943789,274.15874949097633,232.41601605527103,198.7524663899094,265.0022159293294]
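The ranking can be read off the printed arrays by sorting the names by their gain values. This sketch hard-codes the two arrays from the output above and sorts the indices by descending weight; it is plain JDK code, independent of lightgbm4j, and the class name is hypothetical.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.stream.IntStream;

public class ImportanceRank {
    // Feature names and GAIN importances copied from the iteration-99 output above.
    static final String[] NAMES = {"overMa5", "overMa7", "overMa14", "overMa20", "overMin5", "overMin7",
            "overMin14", "overMin20", "overMin60", "ret10", "ret20", "ret60", "volume"};
    static final double[] WEIGHTS = {79.4154430180788, 38.42058273404837, 69.61820960044861,
            6.021978914737701, 1345.2684633247554, 2.600044012069702, 55.968476973474026,
            0.5242173969745636, 13.071800041943789, 274.15874949097633, 232.41601605527103,
            198.7524663899094, 265.0022159293294};

    // Returns the feature names ordered by descending importance.
    static String[] rank(String[] names, double[] weights) {
        return IntStream.range(0, names.length)
                .boxed()
                .sorted(Comparator.comparingDouble((Integer i) -> weights[i]).reversed())
                .map(i -> names[i])
                .toArray(String[]::new);
    }

    public static void main(String[] args) {
        String[] ranked = rank(NAMES, WEIGHTS);
        // Print the three most important features.
        System.out.println(Arrays.toString(Arrays.copyOf(ranked, 3))); // [overMin5, ret10, volume]
    }
}
```

ret10 is only narrowly ahead of volume (274.16 vs 265.00), so their order may swap with a different data split or seed.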
- Use the model to predict on the test data and report the success rate
public static void run(JSONArray jsonArray) throws LGBMException {
    LGBMBooster lgbmBooster = train(jsonArray);
    JSONArray testJsonArray = getJsonArray("/lightgbm/st_etf_ai_test.csv");
    CleanData cleanData = cleanData(testJsonArray);
    int cols = cleanData.getColumns().size();
    //number of test rows = number of values / number of columns
    int row = cleanData.getValuesList().size() / cols;
    List<Double> inputList = cleanData.getValuesList();
    float[] input = new float[inputList.size()];
    for (int i = 0; i < inputList.size(); i++) {
        input[i] = inputList.get(i).floatValue();
    }
    //with objective=regression the raw score is the regression output itself
    double[] pred = lgbmBooster.predictForMat(input, row, cols, true, PredictionType.C_API_PREDICT_RAW_SCORE);
    for (double p : pred) {
        cleanData.getPreds().add(p);
    }
    int sum = 0;          //days predicted as "up"
    int successCount = 0; //of those, days that actually went up
    for (int i = 0; i < pred.length; i++) {
        //treat a predicted value above 0.7 as an "up" signal
        if (cleanData.getPreds().get(i) > 0.7) {
            sum = sum + 1;
            //a success is a day where both the prediction and the actual label are "up"
            if (cleanData.getLables().get(i) > 0) {
                successCount = successCount + 1;
            }
        }
    }
    System.out.println("Total: " + sum + ", successful: " + successCount);
    //this is precision: the share of "up" predictions that were correct (NaN if sum == 0)
    System.out.println("Precision: " + successCount / Double.valueOf(sum) * 100 + "%");
}
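The metric computed in run counts only the days whose prediction is above 0.7 and asks how many of them really went up, so it is a precision rather than an accuracy over all test rows. A JDK-only sketch of that computation, with hypothetical names and toy data, and an explicit NaN when no day crosses the threshold:

```java
public class SignalPrecision {
    // Counts days whose prediction exceeds the threshold (signals) and how many of those
    // actually had label > 0 (hits). Returns {signals, hits}.
    static int[] count(double[] preds, double[] labels, double threshold) {
        if (preds.length != labels.length) {
            throw new IllegalArgumentException("preds and labels must have the same length");
        }
        int signals = 0;
        int hits = 0;
        for (int i = 0; i < preds.length; i++) {
            if (preds[i] > threshold) {
                signals++;
                if (labels[i] > 0) {
                    hits++;
                }
            }
        }
        return new int[] {signals, hits};
    }

    // Precision in percent; returns NaN explicitly when there were no signals.
    static double precisionPercent(int signals, int hits) {
        return signals == 0 ? Double.NaN : 100.0 * hits / signals;
    }

    public static void main(String[] args) {
        double[] preds = {0.9, 0.2, 0.75, 0.8, 0.1};
        double[] labels = {1, 0, 0, 1, 1};
        int[] c = count(preds, labels, 0.7);
        System.out.println("signals=" + c[0] + ", hits=" + c[1]); // signals=3, hits=2
        System.out.println("precision=" + precisionPercent(c[0], c[1]) + "%");
    }
}
```

The threshold trades coverage for precision: raising it usually yields fewer signals, so the 0.7 used in the article is worth tuning on data that is not the final test set.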
Total: 44, successful: 32
Precision: 72.72727272727273%
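With the numbers printed above, the metric works out as:

```latex
\text{precision} = \frac{\text{successful}}{\text{total}} = \frac{32}{44} \approx 0.7273 = 72.73\%
```

The remaining 12 signals (44 - 32) are days flagged as "up" that did not rise; days the model never flagged do not enter this figure at all.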