4360人阅读 评论(0)

郑重声明：股市有风险，投资需谨慎，利用本模型实盘请自行承担风险

package implementation;

import java.util.List;

/**
 * Base class for the weak classifiers combined by the boosting trainer.
 * Extend it to plug in your own weak classifier; only {@link #predict(double[])} and
 * {@link #train(double[][], double[], double[])} need to be implemented.
 *
 * <p>The helpers below compare {@link #predict(double[])} with the sample label using
 * {@code ==}, so predictions and labels are expected to use the same values
 * ({@code 1} for positive, {@code -1} for negative).
 *
 * @author zhangshiming
 */
public abstract class WeakClassifier {
    /** Flag stored in {@code rightOrWrong} for a correctly classified sample. */
    public static final int RIGHT = 1;
    /** Flag stored in {@code rightOrWrong} for a misclassified sample. */
    public static final int WRONG = 0;

    /** Voting weight (alpha) of this classifier inside the boosted ensemble. */
    public double weight;

    /**
     * Error rate on the positive samples ({@code inputY[i] == 1}).
     *
     * @param inputX       feature rows, one per sample
     * @param inputY       labels, parallel to {@code inputX}
     * @param rightOrWrong output: {@link #RIGHT}/{@link #WRONG} is written only at the indices
     *                     of positive samples; all other entries are left untouched
     * @return misclassified positives divided by the number of positives, or {@code 0.0} when
     *         there are no positive samples (previously {@code NaN} from 0/0)
     */
    public final double calculateErrorPositive(double[][] inputX, double[] inputY, int[] rightOrWrong) {
        return errorRateForLabel(inputX, inputY, rightOrWrong, 1);
    }

    /**
     * Error rate on the negative samples ({@code inputY[i] == -1}).
     *
     * @param inputX       feature rows, one per sample
     * @param inputY       labels, parallel to {@code inputX}
     * @param rightOrWrong output: {@link #RIGHT}/{@link #WRONG} is written only at the indices
     *                     of negative samples; all other entries are left untouched
     * @return misclassified negatives divided by the number of negatives, or {@code 0.0} when
     *         there are no negative samples (previously {@code NaN} from 0/0)
     */
    public final double calculateErrorNegative(double[][] inputX, double[] inputY, int[] rightOrWrong) {
        return errorRateForLabel(inputX, inputY, rightOrWrong, -1);
    }

    /**
     * Overall error rate across every sample.
     *
     * @param inputX       feature rows, one per sample
     * @param inputY       labels, parallel to {@code inputX}
     * @param rightOrWrong output: {@link #RIGHT}/{@link #WRONG} is written for every sample
     * @return number of misclassified samples divided by {@code inputY.length}
     */
    public final double calculateError(double[][] inputX, double[] inputY, int[] rightOrWrong) {
        double errorTimes = 0; // number of misclassified samples
        for (int i = 0; i < inputX.length; i++) {
            int res = predict(inputX[i], inputY[i]);
            if (res == WRONG) {
                errorTimes++;
            }
            rightOrWrong[i] = res;
        }
        return errorTimes / inputY.length;
    }

    /**
     * Sums {@code irlist.get(i)[0]} over the samples this classifier labels as positive, i.e.
     * positives classified {@link #RIGHT} plus negatives classified {@link #WRONG} (assuming
     * predictions are +1/-1). {@code rightOrWrong} must already be filled for every sample,
     * e.g. by {@link #calculateError}.
     *
     * @param inputX       feature rows (only the length is used)
     * @param inputY       labels, parallel to {@code inputX}
     * @param rightOrWrong per-sample RIGHT/WRONG flags
     * @param irlist       per-sample rate rows; element {@code [0]} is summed
     * @return the summed rate
     */
    public final double calculateIR(double[][] inputX, double[] inputY, int[] rightOrWrong, List<double[]> irlist) {
        double sumIr = 0;
        for (int i = 0; i < inputX.length; i++) {
            boolean predictedPositive =
                    (inputY[i] > 0 && rightOrWrong[i] == RIGHT)
                            || (inputY[i] < 0 && rightOrWrong[i] == WRONG);
            if (predictedPositive) {
                sumIr += irlist.get(i)[0];
            }
        }
        return sumIr;
    }

    /**
     * Classifies {@code x} and compares the result with the expected label {@code y}.
     *
     * @return {@link #RIGHT} if {@code predict(x) == y}, otherwise {@link #WRONG}
     */
    public final int predict(double[] x, double y) {
        return predict(x) == y ? RIGHT : WRONG;
    }

    /** Returns the predicted label for the feature row {@code x}. */
    public abstract double predict(double[] x);

    /**
     * Fits this weak classifier to the samples.
     *
     * @param inputX  feature rows
     * @param inputY  labels, parallel to {@code inputX}
     * @param weights current per-sample weights from the boosting loop
     */
    public abstract void train(double[][] inputX, double[] inputY, double[] weights);

    /**
     * Shared implementation of the per-class error rates. It evaluates only the samples whose
     * label equals {@code label}, records RIGHT/WRONG for them, and returns their error rate
     * (0.0 when no sample carries that label).
     */
    private double errorRateForLabel(double[][] inputX, double[] inputY, int[] rightOrWrong, double label) {
        double errorTimes = 0; // misclassified samples with this label
        double count = 0;      // samples with this label
        for (int i = 0; i < inputX.length; i++) {
            if (inputY[i] != label) {
                continue;
            }
            count++;
            int res = predict(inputX[i], inputY[i]);
            if (res == WRONG) {
                errorTimes++;
            }
            rightOrWrong[i] = res;
        }
        return count == 0 ? 0.0 : errorTimes / count;
    }
}


package implementation;

// NOTE(review): the class header of this weak classifier (and the declaration of mThreadHold,
// which predict() reads) was lost when the source was scraped.
// Learned bias of the decision rule mThreadHold * x[0] + mBias >= 0.
public double mBias;
// Bounds of the range swept by the bias search in train().
private static final int maxBias = 20;
private static final int minBias = -maxBias;
// -10000 marks "not initialized yet" (train() tests lastBias == -10000).
// These fields are static, so the sweep position is shared by all instances of this class.
// lastThreadHold is not referenced anywhere in the visible code.
private static double lastThreadHold = -10000;
private static double lastBias = -10000;
/**
 * Linear decision on the first feature: returns +1 when the point lies on or above the line
 * mThreadHold * x[0] + mBias = 0, and -1 otherwise.
 */
@Override
public double predict(double[] x) {
    final double score = mThreadHold * x[0] + mBias;
    return (score >= 0) ? 1 : -1;
}

/**
 * Describes the learned stump parameters. The trainer appends this to its per-epoch log line.
 * The scraped body was empty, which does not compile (missing return statement).
 */
@Override
public String toString() {
    return "threshold=" + mThreadHold + " bias=" + mBias;
}

/**
 * Picks the next candidate stump parameters (mBias, and presumably mThreadHold).
 *
 * NOTE(review): this method was damaged when the source was scraped. The "}else{" below has
 * no matching "if", and mThreadHold is never assigned in the visible lines, so code between
 * the min/max scan and "lastBias = minBias;" is missing (presumably a threshold sweep over
 * [min, max] using step). The inputY and weights arguments are not read in the visible lines.
 */
@Override
public void train(double[][] inputX, double[] inputY, double[] weights) {
// Range of feature 0. Both start at 0, so max >= 0 and min <= 0 always.
double max = 0, min = 0;
double step = 0.1; // not used in the visible lines
// find max min
for(int i = 0; i < inputX.length; i++){
double val = inputX[i][0];
if(val > max){
max = val;
}
if(val < min){
min = val;
}
}
lastBias = minBias;
}else{
// Bias sweep: start at minBias on the first call (-10000 is the "unset" sentinel), then
// advance by 0.1 per call, wrapping back to minBias once it exceeds maxBias.
if(lastBias == -10000){
lastBias = minBias;
}else{
lastBias += 0.1;
if(lastBias > maxBias){
lastBias = minBias;
}
}

}
mBias = lastBias;
//System.out.println("bias = " + mBias + " threashHold = " + mThreadHold);
}

}

package implementation;

import java.util.ArrayList;
import java.util.List;

// NOTE(review): the class header of this boosting trainer was lost when the source was scraped.
private double[][] mInputX = null;// training samples (feature rows)
private double[] mInputY = null;// sample labels
private double[] mWeights = null;// per-sample weights, renormalized by updateWeights()
private int mSampleNum = -1;// number of samples; -1 until setInput() is called
private List<WeakClassifier> mWeakClassifierSet = new ArrayList<WeakClassifier>();// weak classifiers summed by the ensemble prediction

// NOTE(review): this constructor's signature was lost when the source was scraped; it
// presumably takes the feature matrix X and the label vector Y.
setInput(X, Y);// constructor: initialize the training samples and their labels (1 / -1)
}

// NOTE(review): the signature of this member was lost when the source was scraped. It receives
// `input` (presumably double[][], with the labels in the last column) and ends by calling
// setInput(X, Y).
// Splits every row of `input` into its features (all columns but the last) and its label
// (the last column), then delegates to setInput(X, Y).
if(input == null || input.length == 0){
    // Bug fix: the exception used to be created but never thrown, so execution fell through
    // and failed on input[0] with a less helpful NullPointerException / index error.
    throw new RuntimeException("no input data, please check !");
}
final int cols = input[0].length - 1;
double[][] X = new double[input.length][cols];
double[] Y = new double[input.length];
for (int i = 0; i < input.length; i++) {
    final int last = input[i].length - 1;
    for (int j = 0; j < input[i].length; j++) {
        if (j < last) {
            X[i][j] = input[i][j];
        } else {
            Y[i] = input[i][j];
        }
    }
}
setInput(X, Y);
}
/**
 * Installs the training set and allocates one weight slot per sample.
 *
 * @param X feature rows, one per sample
 * @param Y labels (1 / -1), parallel to {@code X}
 * @throws IllegalArgumentException if either array is null or their lengths differ. This is a
 *         RuntimeException subclass, so existing catch blocks keep working.
 */
public void setInput(double[][] X, double[] Y){
    if (X == null || Y == null) {
        throw new IllegalArgumentException(
                "input X or input Y can not be null, please check!");
    }
    if (X.length != Y.length) {
        throw new IllegalArgumentException(
                "input X or input Y belongs to different dimension, please check!");
    }
    mInputX = X;
    mInputY = Y;
    mSampleNum = mInputX.length;
    // All zeros here; trainWeakClassifiers() fills it via initWeights() before training.
    mWeights = new double[mSampleNum];
}

/** Resets every sample weight to the uniform value 1 / mSampleNum. */
private void initWeights(){
    final double uniform = 1.0 / mSampleNum;
    int k = 0;
    while (k < mSampleNum) {
        mWeights[k++] = uniform;
    }
}

// NOTE(review): this method's signature was lost when the source was scraped. From the body it
// takes a feature row `x` and returns a double; presumably the ensemble's predict(double[] x).
// Returns the boosted score: the sum over all stored weak classifiers of weight * predict(x).
// The raw sum is returned (not clamped to +/-1); callers appear to use only its sign.
// Throws RuntimeException when no weak classifier has been stored.
double res = 0;
if(mWeakClassifierSet.size() == 0){
throw new RuntimeException(
"no weak classifiers !!");
}
for(int i = 0; i < mWeakClassifierSet.size(); i++){
res += mWeakClassifierSet.get(i).weight *
mWeakClassifierSet.get(i).predict(x);
}
return res;
}

/**
 * Re-weights the samples after a boosting round and renormalizes them to sum to 1. Each
 * correctly classified sample is scaled by e^(-alpha) and each misclassified one by e^(alpha).
 * Misclassified samples therefore keep gaining weight; you can use these weights to give
 * extra attention to the samples that are hard to classify.
 *
 * @param rightOrWrong per-sample flag, {@link WeakClassifier#RIGHT} or {@link WeakClassifier#WRONG}
 * @param alpha        weight of the weak classifier that produced {@code rightOrWrong}
 * @throws RuntimeException on any other flag value
 */
private void updateWeights(int[] rightOrWrong, double alpha){
    final double shrink = Math.exp(-alpha); // factor for correctly classified samples
    final double grow = Math.exp(alpha);    // factor for misclassified samples
    double total = 0;
    int idx = 0;
    for (int flag : rightOrWrong) {
        switch (flag) {
            case WeakClassifier.RIGHT:
                mWeights[idx] *= shrink;
                break;
            case WeakClassifier.WRONG:
                mWeights[idx] *= grow;
                break;
            default:
                throw new RuntimeException(
                        "unknown right or wrong flag, please check!");
        }
        total += mWeights[idx];
        idx++;
    }
    // Normalize so the weights form a distribution again.
    for (int k = 0; k < rightOrWrong.length; k++) {
        mWeights[k] /= total;
    }
}

/**
 * Core training loop: searches for acceptable weak classifiers. Per the original author's
 * comment, they are meant to be stored in a List (mWeakClassifierSet).
 *
 * Each of the {@code epoch} rounds does the following:
 * - trains the weak classifier on the current sample weights;
 * - measures its error on positives, on negatives and overall, plus its IR sum;
 * - discards the candidate if the error on either class exceeds 0.5;
 * - otherwise stores alpha = ln((1 - error) / error) / 2 as its weight and re-weights the samples.
 *
 * @param epoch  number of rounds; must be greater than 1
 * @param irlist per-sample rate rows ({@code irlist.get(i)[0]}), parallel to the training samples
 * @throws RuntimeException if {@code epoch <= 1}
 */
public void trainWeakClassifiers(int epoch,List<double[]> irlist){
if(epoch <= 1){
throw new RuntimeException(
"training epoch must be greater than 1, please check!");
}

System.out.println("start training......");

initWeights();// initialize the sample weights to 1 / mSampleNum

// NOTE(review): `weakClassifier` is not declared in the visible code, and nothing here adds
// to mWeakClassifierSet. Both lines appear to have been lost when the source was scraped:
// presumably a weak classifier is created at the start of each round and added to the set
// once its weight is assigned. As written, this method does not compile.
for(int i = 0; i < epoch; i++){
weakClassifier.train(mInputX, mInputY, mWeights);
int[] rightOrWrong = new int[mSampleNum];// 1 right, 0 wrong
double errorP = weakClassifier.calculateErrorPositive(mInputX, mInputY,rightOrWrong);// error rate on positive samples
double errorN = weakClassifier.calculateErrorNegative(mInputX, mInputY,rightOrWrong);// error rate on negative samples
double error = weakClassifier.calculateError(mInputX, mInputY,rightOrWrong);// overall error rate; rewrites every rightOrWrong entry
double sumIr = weakClassifier.calculateIR(mInputX, mInputY,rightOrWrong,irlist);// IR (profit) sum; only logged below
//System.out.println("perror = " + errorP + " nerror = " + errorN + " error = " + error);
if(errorP > 0.5 || errorN > 0.5){
continue;// discard classifiers that fail the error-rate requirement
}

//if(sumIr <=0){
//	continue;
//}
//System.out.println("perror = " + errorP + " nerror = " + errorN);

// NOTE(review): when error == 0 this yields alpha = +Infinity (log of 1/0).
double alpha = Math.log((1 - error) / error) / 2;
weakClassifier.weight = alpha;// store the weak classifier's weight
updateWeights(rightOrWrong, alpha);
System.out.println("epoch " + i +
" got one weak classifier, haha " + weakClassifier.toString() + " error=" + error + " ir=" + sumIr);
}
System.out.println("train finish!!  " + mWeakClassifierSet.size() + " weak classifier(s) was trained !!");
//		for(int i = 0; i < mWeakClassifierSet.size(); i++){
//		}
}
}


package implementation;

import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * Demo driver for the boosting model on daily MACD data.
 *
 * <p>Input rows are split on tabs into: date, MACD difference, up/down label (1 / -1), and
 * points changed that day. A line starting with "test" separates the training rows from the
 * test rows. Results are reported for two strategies: "long" (buy when res >= 0) and
 * "contrarian" (buy when res < 0).
 *
 * <p>NOTE(review): this listing was damaged when it was scraped. The following are missing,
 * so the class does not compile as shown:
 * <ul>
 *   <li>the line-reading loop and the declaration of {@code bfr};</li>
 *   <li>the bodies that fill the lists;</li>
 *   <li>presumably the model construction and training;</li>
 *   <li>the per-sample prediction {@code res}.</li>
 * </ul>
 *
 * @author zhangshiming
 * @email 106601549@qq.com
 *
 */
public class Main {

public static void main(String[] args) throws Exception {
int epoch = 110000;// number of training epochs (not used in the visible code)
// Initialize the training samples; the last column is the label.
String linestr = null;
boolean flag = true;// true while reading training rows, false after the "test" marker

List<double[]> trainData = new ArrayList<double[]>();
List<double[]> trainRate = new ArrayList<double[]>();
List<double[]> testData = new ArrayList<double[]>();
List<double[]> testRate = new ArrayList<double[]>();
// NOTE(review): the loop that reads each line into linestr (presumably from a reader named
// bfr) is missing here. As written, linestr is null and `continue` has no enclosing loop.
if(linestr.startsWith("test")){
flag = false;
continue;
}
String[] s = linestr.split("\t");
// darr = {MACD difference, label}; rarr = {points changed that day}
double[] darr = new double[2];
darr[0] = Double.valueOf(s[1]);
darr[1] = Double.valueOf(s[2]);
double[] rarr = new double[1];
rarr[0] = Double.valueOf(s[3]);
if(flag){
// read training data (body lost; presumably adds darr/rarr to trainData/trainRate)
}else{
// read test data (body lost; presumably adds darr/rarr to testData/testRate)

}

}
bfr.close();
final int trainRows = trainData.size();
final int testRows = testData.size();
// Copy the rows into matrices; column 0 = MACD difference, column 1 = label.
// NOTE(review): X is not used afterwards in the visible code.
double[][] X = new double[trainRows][trainData.get(0).length];
double[][] testInput = new double[testRows][testData.get(0).length];
for(int i = 0; i < trainRows; i++){
X[i][0] = trainData.get(i)[0];
X[i][1] = trainData.get(i)[1];

}
for(int i = 0; i < testRows; i++){
testInput[i][0] = testData.get(i)[0];
testInput[i][1] = testData.get(i)[1];
}
trainData.clear();
testData.clear();

// NOTE(review): missing `throw`; the exception is created and discarded. This check is also
// too late: testData.get(0) above already fails when there is no test data.
if(testInput == null || testInput.length == 0){
new RuntimeException("no input data, please check !");
}
// Split the test rows into features (all but the last column) and labels (last column).
final int cols = testInput[0].length - 1;
double testX[][] = new double[testInput.length][cols];
double testY[] = new double[testInput.length];
for(int i = 0; i < testInput.length; i++){
for(int j = 0; j < testInput[i].length; j++){
if(j < testInput[i].length -1){
testX[i][j] = testInput[i][j];
}else{
testY[i] = testInput[i][j];
}
}

}

// ir / total / testErrorTimes: the "long" bucket (res >= 0, buy as the model says).
// ir1 / total1 / testErrorTimes1: the "contrarian" bucket (res < 0, buy against the model).
double testErrorTimes = 0;
double total = 0;
double testErrorTimes1 = 0;
double total1 = 0;
double ir = 0;
double ir1 = 0;
// NOTE(review): res (the model's score for testX[i]) is not declared in the visible code;
// its computation appears to have been lost.
for(int i = 0; i < testX.length; i++){
/****************** long: follow the model (res >= 0) *******************/
if(testY[i] > 0 && res >= 0){
ir += testRate.get(i)[0];
System.out.println("正做:" + testRate.get(i)[0]);
total++;
}

if(testY[i] < 0 && res >= 0){
ir += testRate.get(i)[0];
System.out.println("正做:" + testRate.get(i)[0]);
testErrorTimes++;
total++;
}
/****************** contrarian: against the model (res < 0) *******************/
if(testY[i] > 0 && res < 0){
ir1 += testRate.get(i)[0];
System.out.println("反做:" + testRate.get(i)[0]);
testErrorTimes1++;
total1++;
}

if(testY[i] < 0 && res < 0){
ir1 += testRate.get(i)[0];
System.out.println("反做:" + testRate.get(i)[0]);
total1++;
}

System.out.println("在测试数据上的IR=" + ir + "    IR1=" + ir1);
System.out.println();
}
// error = percentage of model misclassifications in each bucket (NaN when a bucket is empty).
System.out.println();
System.out.println("在测试数据上的IR=" + ir + " error=" + (testErrorTimes/total*100));
System.out.println("在测试数据上的IR1=" + ir1 + " error=" + (testErrorTimes1/total1*100));

}
}



2015/06/11    -16.51    -1    -3.088
2015/06/12    -22.534    1    6.735
2015/06/15    2.576    -1    -7.579
2015/06/16    -21.514    -1    -5.374
2015/06/17    -28.798    1    2.545
2015/06/18    -11.445    -1    -0.939
2015/06/19    -8.888    -1    -2.741
2015/06/23    -11.124    1    0.404
2015/06/24    -3.842    1    6.566
testData
2015/06/25    17.84    -1    -8.508
2015/06/26    -8.278    -1    -11.105
2015/06/29    -27.741    -1    -11.139

testData 这一行之后的数据会被读入测试样本，之前的数据用于训练。每行各列以制表符分隔：第一列是日期，第二列是昨天与前天的 MACD 之差，第三列是当天涨跌（1 表示涨，-1 表示跌），第四列是当天涨跌点数。

IR 是正做时累计的涨跌点数，IR1 是反做时累计的涨跌点数。正做：模型输出 res>=0 时按模型信号买入；反做：模型输出 res<0 时本应卖出，我们却反向买入。

个人资料
等级：
访问量： 11万+
积分： 1306
排名： 3万+
最新评论