/**
 * Least-squares regression tree (CART).
 *
 * <p>Every internal node splits on a single feature at a threshold chosen to minimise
 * the summed squared error of the two resulting children. Samples whose split feature
 * equals {@link #MISSINGDATA} are routed to a third child.
 *
 * <p>NOTE(review): missing values still take part in the threshold search as ordinary
 * numbers; the original author left a principled treatment of missing attributes as an
 * open item.
 *
 * @author ysh 1208706282
 */
public class Cart {
    /** Sentinel feature value that marks a missing attribute. */
    static double MISSINGDATA = -111111111;

    /** Adjacent sorted feature values closer than this are treated as equal (no threshold between them). */
    private static final double FEATURE_EPS = 1E-4;
    /** Labels differing by less than this are treated as identical. */
    private static final double LABEL_EPS = 1E-5;

    int mMaxDepth;        // depth at which a node is forced to be a leaf
    int mMinLeaf;         // nodes holding fewer samples than this are not split
    double mFeatureRate;  // fraction of the features examined at each split
    List<Sample> mSamples;
    Random mRandom;
    Node mParent;         // root of the regression tree

    /** One observation: a regression target plus its feature vector. */
    static class Sample {
        Double label;
        List<Double> feature;
    }

    /** One tree node; {@code childs} is {left, right, missing} once the node has been split. */
    static class Node {
        List<Sample> samples;
        int depth;
        int featureId;
        double splitValue;
        double fitness;
        double predict;
        Node childs[];
        boolean leaf;
    }

    /**
     * Loads the training set into {@link #mSamples}. Column 0 of every line is the
     * label, the remaining columns are the features.
     *
     * @param path  file to read
     * @param regex column separator, passed to {@link String#split(String)}
     * @throws Exception if the file cannot be read or a value is not a number
     */
    public void loadData(String path, String regex) throws Exception {
        // Same format as a labelled validation file, so share the parser.
        mSamples = loadTestData(path, true, regex);
    }

    /**
     * Replaces the training set.
     *
     * @param samples samples to train on
     */
    public void setData(List<Sample> samples) {
        this.mSamples = samples;
    }

    /**
     * Loads a validation/test set. Features are always read from column 1 onwards;
     * column 0 is parsed as the label only when {@code hasLabel} is true. Blank lines
     * are skipped.
     *
     * @param path     file to read
     * @param hasLabel whether column 0 should be parsed into {@link Sample#label}
     * @param regex    column separator, passed to {@link String#split(String)}
     * @return the parsed samples, in file order
     * @throws Exception if the file cannot be read or a value is not a number
     */
    public static List<Sample> loadTestData(String path, boolean hasLabel, String regex) throws Exception {
        List<Sample> samples = new ArrayList<Sample>();
        BufferedReader reader = new BufferedReader(new FileReader(path));
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue; // tolerate blank / trailing empty lines
                }
                String[] splits = line.split(regex);
                Sample sample = new Sample();
                if (hasLabel) {
                    sample.label = Double.valueOf(splits[0]);
                }
                sample.feature = new ArrayList<Double>(Math.max(0, splits.length - 1));
                for (int i = 1; i < splits.length; i++) {
                    sample.feature.add(Double.valueOf(splits[i]));
                }
                samples.add(sample);
            }
        } finally {
            // Close even when parsing fails so the file handle is not leaked.
            reader.close();
        }
        return samples;
    }

    /**
     * Computes the mean label of the given samples.
     *
     * @param samples samples of one node
     * @return arithmetic mean of the labels (NaN for an empty array)
     */
    double getAverage(Sample[] samples) {
        double sum = 0;
        for (Sample sample : samples) {
            sum += sample.label;
        }
        return sum / samples.length;
    }

    /**
     * Tells whether all labels are (numerically) identical, i.e. the node is pure.
     *
     * @param samples non-empty samples of one node
     * @return true when every label is within {@code 1E-5} of the first label
     */
    boolean isSame(Sample[] samples) {
        double first = samples[0].label;
        for (Sample sample : samples) {
            if (Math.abs(first - sample.label) >= LABEL_EPS) {
                return false;
            }
        }
        return true;
    }

    /**
     * Finds the best threshold on one feature. The result is written to
     * {@code node.fitness} (summed squared error of both sides, smaller is better) and
     * {@code node.splitValue} (midpoint between the two neighbouring feature values).
     *
     * <p>Side effect: {@code samples} is sorted in place by the given feature.
     *
     * @param samples   samples of the node
     * @param featIndex index of the feature to examine
     * @param node      receives {@code fitness} and {@code splitValue}
     * @return 0 when a usable threshold exists, 1 when the feature offers no split
     */
    int getImpurity(Sample[] samples, final int featIndex, Node node) {
        Arrays.sort(samples, new Comparator<Sample>() {
            @Override
            public int compare(Sample o1, Sample o2) {
                // Must return 0 for ties, otherwise the sort contract is violated.
                return Double.compare(o1.feature.get(featIndex), o2.feature.get(featIndex));
            }
        });

        int n = samples.length;
        double mean = 0;
        for (Sample s : samples) {
            mean += s.label;
        }
        mean /= n;

        // Work on mean-centred labels: SSE = sum(c^2) - sum(c)^2 / count is then far
        // less prone to cancellation than with the raw labels.
        double totalSum = 0;
        double totalSq = 0;
        for (Sample s : samples) {
            double centred = s.label - mean;
            totalSum += centred;
            totalSq += centred * centred;
        }

        double leftSum = 0;
        double leftSq = 0;
        double bestFitness = Double.MAX_VALUE;
        double bestSplit = 0;
        for (int i = 1; i < n; i++) {
            // The left side is samples[0..i-1]; extend it by exactly one sample per step.
            double centred = samples[i - 1].label - mean;
            leftSum += centred;
            leftSq += centred * centred;

            double previous = samples[i - 1].feature.get(featIndex);
            double current = samples[i].feature.get(featIndex);
            if (current - previous < FEATURE_EPS) {
                continue; // equal feature values cannot be separated by a threshold
            }

            double rightSum = totalSum - leftSum;
            double rightSq = totalSq - leftSq;
            double leftError = Math.max(0, leftSq - leftSum * leftSum / i);
            double rightError = Math.max(0, rightSq - rightSum * rightSum / (n - i));
            double error = leftError + rightError;
            if (bestFitness > error) {
                bestFitness = error;
                bestSplit = (current + previous) / 2;
            }
        }
        node.fitness = bestFitness;
        node.splitValue = bestSplit;
        return bestFitness != Double.MAX_VALUE ? 0 : 1;
    }

    /**
     * Chooses the best split feature and threshold among a random subset of the
     * features and stores them in {@code node.featureId}, {@code node.splitValue} and
     * {@code node.fitness}.
     *
     * @param samples non-empty samples of the node (reordered as a side effect)
     * @param node    receives the chosen split
     * @return 0 when a split was found, 1 when no examined feature can be split
     */
    int findSplit(Sample[] samples, Node node) {
        int featureCount = samples[0].feature.size();
        int[] featureIndex = new int[featureCount];
        for (int i = 0; i < featureCount; i++) {
            featureIndex[i] = i;
        }
        // Random permutation. A temp-variable swap is required: an XOR swap would
        // zero the entry whenever both positions coincide.
        for (int i = 0; i < featureCount; i++) {
            int other = mRandom.nextInt(featureCount);
            int tmp = featureIndex[i];
            featureIndex[i] = featureIndex[other];
            featureIndex[other] = tmp;
        }

        int bestFeatIdx = 0;
        double bestFitness = Double.MAX_VALUE;
        double bestSplitValue = 0;
        // Examine the first ceil(featureCount * mFeatureRate) shuffled features, never
        // more than exist (guards against mFeatureRate > 1).
        for (int feat = 0; feat < featureCount && feat < featureCount * mFeatureRate; feat++) {
            int idx = featureIndex[feat];
            if (getImpurity(samples, idx, node) != 0) {
                continue;
            }
            if (bestFitness > node.fitness) {
                bestFitness = node.fitness;
                bestFeatIdx = idx;
                bestSplitValue = node.splitValue;
            }
        }
        node.fitness = bestFitness;
        node.featureId = bestFeatIdx;
        node.splitValue = bestSplitValue;
        return bestFitness != Double.MAX_VALUE ? 0 : 1;
    }

    /**
     * Distributes the samples over three freshly created children of {@code node}:
     * index 0 for values below {@code node.splitValue}, index 1 for the rest, index 2
     * for samples whose split feature equals {@link #MISSINGDATA}.
     *
     * @param samples samples of the node
     * @param node    node whose {@code featureId}/{@code splitValue} are already set
     */
    public void splitData(Sample[] samples, Node node) {
        node.childs = new Node[3];
        for (int i = 0; i < 3; i++) {
            Node child = new Node();
            child.depth = node.depth + 1;
            child.samples = new ArrayList<Sample>();
            node.childs[i] = child;
        }
        int feat = node.featureId;
        for (Sample s : samples) {
            double value = s.feature.get(feat);
            if (value == Cart.MISSINGDATA) {
                node.childs[2].samples.add(s);
            } else if (value < node.splitValue) {
                node.childs[0].samples.add(s);
            } else {
                node.childs[1].samples.add(s);
            }
        }
    }

    /**
     * Recursively grows the subtree rooted at {@code node}.
     *
     * @param samples samples that reached the node
     * @param node    node to fit; its {@code depth} must already be set
     */
    public void fit(Sample[] samples, Node node) {
        node.predict = getAverage(samples);
        if (samples.length == 0 || node.depth == mMaxDepth
                || samples.length < mMinLeaf || isSame(samples)) {
            node.leaf = true;
            return;
        }
        if (findSplit(samples, node) != 0) {
            node.leaf = true;
            return;
        }
        splitData(samples, node);
        if (node.childs[0].samples.isEmpty() || node.childs[1].samples.isEmpty()) {
            node.leaf = true;
            return;
        }
        for (Node child : node.childs) {
            if (child.samples.isEmpty()) {
                // No training sample reached this branch (usually the missing-value
                // child). Make it a leaf predicting this node's mean so classify()
                // never descends into an unfitted node.
                child.leaf = true;
                child.predict = node.predict;
                continue;
            }
            fit(child.samples.toArray(new Sample[child.samples.size()]), child);
        }
    }

    /**
     * Trains the tree on {@link #mSamples}; the root is stored in {@link #mParent}.
     *
     * @throws IllegalStateException if no training samples have been provided
     */
    public void train() {
        if (mSamples == null || mSamples.isEmpty()) {
            throw new IllegalStateException("no training samples: call loadData or setData first");
        }
        if (mRandom == null) {
            mRandom = new Random();
        }
        mParent = new Node();
        mParent.samples = mSamples;
        mParent.depth = 0;
        fit(mSamples.toArray(new Sample[mSamples.size()]), mParent);
    }

    /**
     * Predicts the target of one sample with the trained tree.
     *
     * @param sample sample to evaluate; only its features are read
     * @return predicted value
     * @throws IllegalStateException if the tree has not been trained
     */
    public double classify(Sample sample) {
        if (mParent == null) {
            throw new IllegalStateException("tree has not been trained");
        }
        return classify(mParent, sample);
    }

    /**
     * Predicts the target of one sample, starting the descent at {@code node}.
     *
     * @param node   subtree root
     * @param sample sample to evaluate; only its features are read
     * @return mean label stored in the leaf that the sample reaches
     */
    public double classify(Node node, Sample sample) {
        if (node.leaf) {
            return node.predict;
        }
        double value = sample.feature.get(node.featureId);
        if (value == Cart.MISSINGDATA) {
            return classify(node.childs[2], sample);
        }
        if (value < node.splitValue) {
            return classify(node.childs[0], sample);
        }
        return classify(node.childs[1], sample);
    }

    /**
     * Demo: trains on one CSV file and prints the predictions and the mean squared
     * error on a second one.
     *
     * @param args optional {@code [trainPath, validPath]}; the original hard-coded
     *             paths are used for any argument that is absent
     * @throws Exception if a file cannot be read or parsed
     */
    public static void main(String[] args) throws Exception {
        String trainPath = args.length > 0 ? args[0] : "F:/2016-contest/20161001/train_data_1.csv";
        String validPath = args.length > 1 ? args[1] : "F:/2016-contest/20161001/valid_data_1.csv";

        // Seeded-generator printout kept from the original program.
        Random ran = new Random();
        ran.setSeed(10001);
        for (int i = 0; i < 10; i++) {
            System.out.println(ran.nextInt(10));
        }

        Cart cart = new Cart();
        cart.mFeatureRate = 0.8;
        cart.mMaxDepth = 6;
        cart.mMinLeaf = 1;
        cart.mRandom = new Random();
        cart.mRandom.setSeed(100);
        cart.loadData(trainPath, ",");
        System.out.println(System.currentTimeMillis());
        cart.train();

        List<Sample> samples = Cart.loadTestData(validPath, true, ",");
        double sum = 0;
        for (Sample s : samples) {
            double val = cart.classify(s);
            sum += (val - s.label) * (val - s.label);
            System.out.println(val + " " + s.label);
        }
        System.out.println(sum / samples.size());
        System.out.println(System.currentTimeMillis());
    }
}
/**
 * Least-squares regression tree (CART).
 *
 * <p>Every internal node splits on a single feature at a threshold chosen to minimise
 * the summed squared error of the two resulting children. Samples whose split feature
 * equals {@link #MISSINGDATA} are routed to a third child.
 *
 * <p>NOTE(review): missing values still take part in the threshold search as ordinary
 * numbers; the original author left a principled treatment of missing attributes as an
 * open item.
 *
 * @author ysh 1208706282
 */
public class Cart {
    /** Sentinel feature value that marks a missing attribute. */
    static double MISSINGDATA = -111111111;

    /** Adjacent sorted feature values closer than this are treated as equal (no threshold between them). */
    private static final double FEATURE_EPS = 1E-4;
    /** Labels differing by less than this are treated as identical. */
    private static final double LABEL_EPS = 1E-5;

    int mMaxDepth;        // depth at which a node is forced to be a leaf
    int mMinLeaf;         // nodes holding fewer samples than this are not split
    double mFeatureRate;  // fraction of the features examined at each split
    List<Sample> mSamples;
    Random mRandom;
    Node mParent;         // root of the regression tree

    /** One observation: a regression target plus its feature vector. */
    static class Sample {
        Double label;
        List<Double> feature;
    }

    /** One tree node; {@code childs} is {left, right, missing} once the node has been split. */
    static class Node {
        List<Sample> samples;
        int depth;
        int featureId;
        double splitValue;
        double fitness;
        double predict;
        Node childs[];
        boolean leaf;
    }

    /**
     * Loads the training set into {@link #mSamples}. Column 0 of every line is the
     * label, the remaining columns are the features.
     *
     * @param path  file to read
     * @param regex column separator, passed to {@link String#split(String)}
     * @throws Exception if the file cannot be read or a value is not a number
     */
    public void loadData(String path, String regex) throws Exception {
        // Same format as a labelled validation file, so share the parser.
        mSamples = loadTestData(path, true, regex);
    }

    /**
     * Replaces the training set.
     *
     * @param samples samples to train on
     */
    public void setData(List<Sample> samples) {
        this.mSamples = samples;
    }

    /**
     * Loads a validation/test set. Features are always read from column 1 onwards;
     * column 0 is parsed as the label only when {@code hasLabel} is true. Blank lines
     * are skipped.
     *
     * @param path     file to read
     * @param hasLabel whether column 0 should be parsed into {@link Sample#label}
     * @param regex    column separator, passed to {@link String#split(String)}
     * @return the parsed samples, in file order
     * @throws Exception if the file cannot be read or a value is not a number
     */
    public static List<Sample> loadTestData(String path, boolean hasLabel, String regex) throws Exception {
        List<Sample> samples = new ArrayList<Sample>();
        BufferedReader reader = new BufferedReader(new FileReader(path));
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue; // tolerate blank / trailing empty lines
                }
                String[] splits = line.split(regex);
                Sample sample = new Sample();
                if (hasLabel) {
                    sample.label = Double.valueOf(splits[0]);
                }
                sample.feature = new ArrayList<Double>(Math.max(0, splits.length - 1));
                for (int i = 1; i < splits.length; i++) {
                    sample.feature.add(Double.valueOf(splits[i]));
                }
                samples.add(sample);
            }
        } finally {
            // Close even when parsing fails so the file handle is not leaked.
            reader.close();
        }
        return samples;
    }

    /**
     * Computes the mean label of the given samples.
     *
     * @param samples samples of one node
     * @return arithmetic mean of the labels (NaN for an empty array)
     */
    double getAverage(Sample[] samples) {
        double sum = 0;
        for (Sample sample : samples) {
            sum += sample.label;
        }
        return sum / samples.length;
    }

    /**
     * Tells whether all labels are (numerically) identical, i.e. the node is pure.
     *
     * @param samples non-empty samples of one node
     * @return true when every label is within {@code 1E-5} of the first label
     */
    boolean isSame(Sample[] samples) {
        double first = samples[0].label;
        for (Sample sample : samples) {
            if (Math.abs(first - sample.label) >= LABEL_EPS) {
                return false;
            }
        }
        return true;
    }

    /**
     * Finds the best threshold on one feature. The result is written to
     * {@code node.fitness} (summed squared error of both sides, smaller is better) and
     * {@code node.splitValue} (midpoint between the two neighbouring feature values).
     *
     * <p>Side effect: {@code samples} is sorted in place by the given feature.
     *
     * @param samples   samples of the node
     * @param featIndex index of the feature to examine
     * @param node      receives {@code fitness} and {@code splitValue}
     * @return 0 when a usable threshold exists, 1 when the feature offers no split
     */
    int getImpurity(Sample[] samples, final int featIndex, Node node) {
        Arrays.sort(samples, new Comparator<Sample>() {
            @Override
            public int compare(Sample o1, Sample o2) {
                // Must return 0 for ties, otherwise the sort contract is violated.
                return Double.compare(o1.feature.get(featIndex), o2.feature.get(featIndex));
            }
        });

        int n = samples.length;
        double mean = 0;
        for (Sample s : samples) {
            mean += s.label;
        }
        mean /= n;

        // Work on mean-centred labels: SSE = sum(c^2) - sum(c)^2 / count is then far
        // less prone to cancellation than with the raw labels.
        double totalSum = 0;
        double totalSq = 0;
        for (Sample s : samples) {
            double centred = s.label - mean;
            totalSum += centred;
            totalSq += centred * centred;
        }

        double leftSum = 0;
        double leftSq = 0;
        double bestFitness = Double.MAX_VALUE;
        double bestSplit = 0;
        for (int i = 1; i < n; i++) {
            // The left side is samples[0..i-1]; extend it by exactly one sample per step.
            double centred = samples[i - 1].label - mean;
            leftSum += centred;
            leftSq += centred * centred;

            double previous = samples[i - 1].feature.get(featIndex);
            double current = samples[i].feature.get(featIndex);
            if (current - previous < FEATURE_EPS) {
                continue; // equal feature values cannot be separated by a threshold
            }

            double rightSum = totalSum - leftSum;
            double rightSq = totalSq - leftSq;
            double leftError = Math.max(0, leftSq - leftSum * leftSum / i);
            double rightError = Math.max(0, rightSq - rightSum * rightSum / (n - i));
            double error = leftError + rightError;
            if (bestFitness > error) {
                bestFitness = error;
                bestSplit = (current + previous) / 2;
            }
        }
        node.fitness = bestFitness;
        node.splitValue = bestSplit;
        return bestFitness != Double.MAX_VALUE ? 0 : 1;
    }

    /**
     * Chooses the best split feature and threshold among a random subset of the
     * features and stores them in {@code node.featureId}, {@code node.splitValue} and
     * {@code node.fitness}.
     *
     * @param samples non-empty samples of the node (reordered as a side effect)
     * @param node    receives the chosen split
     * @return 0 when a split was found, 1 when no examined feature can be split
     */
    int findSplit(Sample[] samples, Node node) {
        int featureCount = samples[0].feature.size();
        int[] featureIndex = new int[featureCount];
        for (int i = 0; i < featureCount; i++) {
            featureIndex[i] = i;
        }
        // Random permutation. A temp-variable swap is required: an XOR swap would
        // zero the entry whenever both positions coincide.
        for (int i = 0; i < featureCount; i++) {
            int other = mRandom.nextInt(featureCount);
            int tmp = featureIndex[i];
            featureIndex[i] = featureIndex[other];
            featureIndex[other] = tmp;
        }

        int bestFeatIdx = 0;
        double bestFitness = Double.MAX_VALUE;
        double bestSplitValue = 0;
        // Examine the first ceil(featureCount * mFeatureRate) shuffled features, never
        // more than exist (guards against mFeatureRate > 1).
        for (int feat = 0; feat < featureCount && feat < featureCount * mFeatureRate; feat++) {
            int idx = featureIndex[feat];
            if (getImpurity(samples, idx, node) != 0) {
                continue;
            }
            if (bestFitness > node.fitness) {
                bestFitness = node.fitness;
                bestFeatIdx = idx;
                bestSplitValue = node.splitValue;
            }
        }
        node.fitness = bestFitness;
        node.featureId = bestFeatIdx;
        node.splitValue = bestSplitValue;
        return bestFitness != Double.MAX_VALUE ? 0 : 1;
    }

    /**
     * Distributes the samples over three freshly created children of {@code node}:
     * index 0 for values below {@code node.splitValue}, index 1 for the rest, index 2
     * for samples whose split feature equals {@link #MISSINGDATA}.
     *
     * @param samples samples of the node
     * @param node    node whose {@code featureId}/{@code splitValue} are already set
     */
    public void splitData(Sample[] samples, Node node) {
        node.childs = new Node[3];
        for (int i = 0; i < 3; i++) {
            Node child = new Node();
            child.depth = node.depth + 1;
            child.samples = new ArrayList<Sample>();
            node.childs[i] = child;
        }
        int feat = node.featureId;
        for (Sample s : samples) {
            double value = s.feature.get(feat);
            if (value == Cart.MISSINGDATA) {
                node.childs[2].samples.add(s);
            } else if (value < node.splitValue) {
                node.childs[0].samples.add(s);
            } else {
                node.childs[1].samples.add(s);
            }
        }
    }

    /**
     * Recursively grows the subtree rooted at {@code node}.
     *
     * @param samples samples that reached the node
     * @param node    node to fit; its {@code depth} must already be set
     */
    public void fit(Sample[] samples, Node node) {
        node.predict = getAverage(samples);
        if (samples.length == 0 || node.depth == mMaxDepth
                || samples.length < mMinLeaf || isSame(samples)) {
            node.leaf = true;
            return;
        }
        if (findSplit(samples, node) != 0) {
            node.leaf = true;
            return;
        }
        splitData(samples, node);
        if (node.childs[0].samples.isEmpty() || node.childs[1].samples.isEmpty()) {
            node.leaf = true;
            return;
        }
        for (Node child : node.childs) {
            if (child.samples.isEmpty()) {
                // No training sample reached this branch (usually the missing-value
                // child). Make it a leaf predicting this node's mean so classify()
                // never descends into an unfitted node.
                child.leaf = true;
                child.predict = node.predict;
                continue;
            }
            fit(child.samples.toArray(new Sample[child.samples.size()]), child);
        }
    }

    /**
     * Trains the tree on {@link #mSamples}; the root is stored in {@link #mParent}.
     *
     * @throws IllegalStateException if no training samples have been provided
     */
    public void train() {
        if (mSamples == null || mSamples.isEmpty()) {
            throw new IllegalStateException("no training samples: call loadData or setData first");
        }
        if (mRandom == null) {
            mRandom = new Random();
        }
        mParent = new Node();
        mParent.samples = mSamples;
        mParent.depth = 0;
        fit(mSamples.toArray(new Sample[mSamples.size()]), mParent);
    }

    /**
     * Predicts the target of one sample with the trained tree.
     *
     * @param sample sample to evaluate; only its features are read
     * @return predicted value
     * @throws IllegalStateException if the tree has not been trained
     */
    public double classify(Sample sample) {
        if (mParent == null) {
            throw new IllegalStateException("tree has not been trained");
        }
        return classify(mParent, sample);
    }

    /**
     * Predicts the target of one sample, starting the descent at {@code node}.
     *
     * @param node   subtree root
     * @param sample sample to evaluate; only its features are read
     * @return mean label stored in the leaf that the sample reaches
     */
    public double classify(Node node, Sample sample) {
        if (node.leaf) {
            return node.predict;
        }
        double value = sample.feature.get(node.featureId);
        if (value == Cart.MISSINGDATA) {
            return classify(node.childs[2], sample);
        }
        if (value < node.splitValue) {
            return classify(node.childs[0], sample);
        }
        return classify(node.childs[1], sample);
    }

    /**
     * Demo: trains on one CSV file and prints the predictions and the mean squared
     * error on a second one.
     *
     * @param args optional {@code [trainPath, validPath]}; the original hard-coded
     *             paths are used for any argument that is absent
     * @throws Exception if a file cannot be read or parsed
     */
    public static void main(String[] args) throws Exception {
        String trainPath = args.length > 0 ? args[0] : "F:/2016-contest/20161001/train_data_1.csv";
        String validPath = args.length > 1 ? args[1] : "F:/2016-contest/20161001/valid_data_1.csv";

        // Seeded-generator printout kept from the original program.
        Random ran = new Random();
        ran.setSeed(10001);
        for (int i = 0; i < 10; i++) {
            System.out.println(ran.nextInt(10));
        }

        Cart cart = new Cart();
        cart.mFeatureRate = 0.8;
        cart.mMaxDepth = 6;
        cart.mMinLeaf = 1;
        cart.mRandom = new Random();
        cart.mRandom.setSeed(100);
        cart.loadData(trainPath, ",");
        System.out.println(System.currentTimeMillis());
        cart.train();

        List<Sample> samples = Cart.loadTestData(validPath, true, ",");
        double sum = 0;
        for (Sample s : samples) {
            double val = cart.classify(s);
            sum += (val - s.label) * (val - s.label);
            System.out.println(val + " " + s.label);
        }
        System.out.println(sum / samples.size());
        System.out.println(System.currentTimeMillis());
    }
}