Machine Learning: Multi-class Classification

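I. Background

Why try an SVM for multi-class classification? The figures below compare how a decision tree and an SVM behave on a multi-class problem:

Decision tree
[Figure 1]
Tuned decision tree (random forest)
[Figure 2]
SVM
[Figure 3]
For details, see: http://blog.jobbole.com/98635/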

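II. SVM: Support Vector Machine

How SVMs work: http://en.wikipedia.org/wiki/Support_vector_machine
An SVM is a binary (two-class) classifier. A k-class problem can be handled in the following ways:

  1. One-vs-rest: train k binary SVMs, each separating one class from all the others.
  2. One-vs-one: train k(k-1)/2 binary SVMs, one for every pair of classes.
  3. Extend the linear formulation so that the whole problem is solved in a single optimization. For each class i, introduce w_i and b_i, and train under the constraint that, for every sample, the score w_i x + b_i of its true class i exceeds the score w_j x + b_j of every other class j. The objective is to minimize the sum of the norms of all w_i. Slack variables can also be introduced, (number of samples) x (number of classes) of them.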
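The one-vs-one strategy in item 2 can be sketched as a simple vote over all class pairs. This is only an illustration: the class `OneVsOneVoting`, the `PairwiseClassifier` interface and the toy "nearest centre" rule are hypothetical stand-ins for trained binary SVMs, not LIBSVM code.

```java
/** Hypothetical sketch of one-vs-one voting; not LIBSVM's implementation. */
final class OneVsOneVoting {

    /** Stand-in for a trained binary SVM that separates class i from class j. */
    interface PairwiseClassifier {
        /** Returns true if sample x looks more like class i than class j. */
        boolean prefersFirst(int i, int j, double[] x);
    }

    /** Number of pairwise classifiers needed for k classes: k(k-1)/2. */
    static int pairCount(int k) {
        return k * (k - 1) / 2;
    }

    /** Runs every pair (i, j) with i < j and returns the class with the most votes. */
    static int predict(int k, double[] x, PairwiseClassifier clf) {
        int[] votes = new int[k];
        for (int i = 0; i < k; i++) {
            for (int j = i + 1; j < k; j++) {
                if (clf.prefersFirst(i, j, x)) {
                    votes[i]++;
                } else {
                    votes[j]++;
                }
            }
        }
        int best = 0;
        for (int c = 1; c < k; c++) {
            if (votes[c] > votes[best]) {
                best = c; // ties go to the lower class index
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Toy rule (assumption): class c has its centre at c on the real line,
        // and each pairwise decision picks the nearer centre.
        PairwiseClassifier nearest =
                (i, j, x) -> Math.abs(x[0] - i) <= Math.abs(x[0] - j);
        System.out.println(pairCount(6));                              // 15
        System.out.println(predict(6, new double[] {3.2}, nearest));   // 3
    }
}
```

With 6 classes, as in the model file shown later in this post, the vote runs over 15 pairwise decisions.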

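III. SVR: Support Vector Regression

Principle: replacing the linear term of a linear model with a kernel function makes the originally linear algorithm non-linear, so it can perform non-linear regression. The kernel implicitly maps the data into a higher-dimensional space, while the additional tunable parameters mean that overfitting can still be controlled.

  1. Regression is essentially fitting: a function is fitted to the relationship between x and y. In SVR, x is a vector and y is a scalar, and the fitted function has the form y = W^T g(x) + b, where g(x) is the feature-space vector corresponding to the kernel.
  2. SVR treats an estimate as correct, with zero loss, as long as it lies within a fixed distance (epsilon) on either side of the true y.
  3. The optimization objective is to minimize |W|, which gives the fitted function the smallest slope, i.e. makes it as flat as possible. This is said to make the estimate more robust.
  4. The rest follows as for SVM: a soft margin can be used, controlled by a positive constant, and the problem is solved in its dual form. One difference is that a suitable value for epsilon, which controls the width of the zero-loss band, is hard to choose. The nu-SVR variant therefore adds a term C*\nu*\epsilon to the objective being minimized, where epsilon becomes a variable and nu is a positive constant fixed in advance.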
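Item 2 above describes the epsilon-insensitive loss. A minimal sketch, assuming a plain linear model without a kernel; the class name and the numbers are made up for illustration:

```java
/** Illustrative sketch of the epsilon-insensitive loss used by SVR. */
final class EpsilonInsensitiveLoss {

    /** Linear model f(x) = w^T x + b (no kernel, for simplicity). */
    static double predict(double[] w, double b, double[] x) {
        double sum = b;
        for (int i = 0; i < w.length; i++) {
            sum += w[i] * x[i];
        }
        return sum;
    }

    /** Zero inside the epsilon band around y, linear outside it. */
    static double loss(double y, double predicted, double epsilon) {
        return Math.max(0.0, Math.abs(y - predicted) - epsilon);
    }

    public static void main(String[] args) {
        double[] w = {2.0};
        double b = 1.0;
        double epsilon = 0.1;
        double f = predict(w, b, new double[] {1.0}); // 2*1 + 1 = 3.0

        // Target within 0.1 of the prediction: no loss at all.
        System.out.println(loss(3.05, f, epsilon));
        // Target 0.5 away: only the part beyond epsilon is penalized.
        System.out.println(loss(3.5, f, epsilon));
    }
}
```

The first call prints 0.0, because the error stays inside the band; the second yields a loss of about 0.4.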


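IV. SVM option settings

-s svm_type: type of SVM (default 0)

0 -- C-SVC
1 -- nu-SVC
2 -- one-class SVM
3 -- epsilon-SVR
4 -- nu-SVR

-t kernel_type: type of kernel function (default 2)

0 -- linear: u'v
1 -- polynomial: (gamma*u'v + coef0)^degree
2 -- RBF: exp(-gamma*|u-v|^2)
3 -- sigmoid: tanh(gamma*u'v + coef0)

-d degree: degree in the kernel function (default 3)
-g gamma: gamma in the kernel function (default 1/k, where k is the number of features)
-r coef0: coef0 in the kernel function (default 0)
-c cost: parameter C of C-SVC, epsilon-SVR and nu-SVR (default 1)
-n nu: parameter nu of nu-SVC, one-class SVM and nu-SVR (default 0.5)
-p epsilon: epsilon in the loss function of epsilon-SVR (default 0.1)
-m cachesize: cache memory size in MB (default 100 in the code listed below)
-e epsilon: tolerance of the termination criterion (default 0.001)
-h shrinking: whether to use the shrinking heuristics, 0 or 1 (default 1)
-wi weight: set parameter C of class i to weight*C, for C-SVC (default 1)
-v n: n-fold cross-validation mode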
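The four kernel formulas listed under -t can be written out directly. The sketch below works on dense arrays for readability; the class and method names are illustrative and are not LIBSVM's internal API, which operates on sparse vectors.

```java
/** Illustrative dense-vector versions of the four kernels selected by -t. */
final class Kernels {

    static double dot(double[] u, double[] v) {
        double sum = 0.0;
        for (int i = 0; i < u.length; i++) {
            sum += u[i] * v[i];
        }
        return sum;
    }

    /** -t 0: u'v */
    static double linear(double[] u, double[] v) {
        return dot(u, v);
    }

    /** -t 1: (gamma*u'v + coef0)^degree */
    static double polynomial(double[] u, double[] v, double gamma, double coef0, int degree) {
        return Math.pow(gamma * dot(u, v) + coef0, degree);
    }

    /** -t 2: exp(-gamma*|u-v|^2) */
    static double rbf(double[] u, double[] v, double gamma) {
        double squaredDistance = 0.0;
        for (int i = 0; i < u.length; i++) {
            double d = u[i] - v[i];
            squaredDistance += d * d;
        }
        return Math.exp(-gamma * squaredDistance);
    }

    /** -t 3: tanh(gamma*u'v + coef0) */
    static double sigmoid(double[] u, double[] v, double gamma, double coef0) {
        return Math.tanh(gamma * dot(u, v) + coef0);
    }

    public static void main(String[] args) {
        double[] u = {1.0, 2.0};
        double[] v = {3.0, 4.0};
        double gamma = 1.0 / u.length; // the default: 1 / number of features
        System.out.println(linear(u, v));
        System.out.println(polynomial(u, v, gamma, 1.0, 2));
        System.out.println(rbf(u, v, gamma));
        System.out.println(sigmoid(u, v, gamma, 0.0));
    }
}
```

Note that the RBF kernel of a vector with itself is always 1, whatever the value of gamma.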

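V. SVM model file parameters

  • svm_type: the selected SVM type, c_svc by default
  • kernel_type rbf: kernel used for training, here the RBF kernel
  • gamma 0.0078125: the RBF kernel parameter γ
  • nr_class 6: number of classes; this is a 6-class problem
  • total_sv 18: total number of support vectors
  • rho 0.004423136341674322 -0.02055338568924989 0.03588086612165208 0.24771746047322893 0.00710699773513259 -0.008734834466328766 0.02297409269106355 0.24299467083662166 -0.07400614425237287 -0.0050679463881033344 0.18446534035305884 0.004123018419961004 0.22127259896446397 -0.012677989710344693 -0.2178023679167552 : bias terms of the decision functions (LIBSVM stores rho = -b), one per pairwise classifier, so 6*5/2 = 15 values here
  • label 0 9 99 999 100 101: class labels as they appear in the original data file
  • nr_sv 2 2 3 3 4 4: number of support vectors in each class (they sum to total_sv)
  • SV: the lines that follow list the coefficients and the corresponding support vectors for each class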
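The header fields above constrain each other: rho should hold k(k-1)/2 values for k classes, and the nr_sv entries should add up to total_sv. A small sketch that checks this on the header from this post; `ModelHeaderCheck` is a hypothetical helper written for illustration, not part of LIBSVM.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical helper that sanity-checks the header of a multi-class model file. */
final class ModelHeaderCheck {

    /** Splits header lines of the form "key v1 v2 ..." into key -> values. */
    static Map<String, String[]> parse(String[] lines) {
        Map<String, String[]> header = new LinkedHashMap<>();
        for (String line : lines) {
            String[] tokens = line.trim().split("\\s+");
            String[] values = new String[tokens.length - 1];
            System.arraycopy(tokens, 1, values, 0, values.length);
            header.put(tokens[0], values);
        }
        return header;
    }

    /** Checks the counts that a k-class one-vs-one model should satisfy. */
    static boolean isConsistent(Map<String, String[]> header) {
        int k = Integer.parseInt(header.get("nr_class")[0]);
        int totalSv = Integer.parseInt(header.get("total_sv")[0]);
        int sum = 0;
        for (String n : header.get("nr_sv")) {
            sum += Integer.parseInt(n);
        }
        return header.get("rho").length == k * (k - 1) / 2
                && header.get("label").length == k
                && header.get("nr_sv").length == k
                && sum == totalSv;
    }

    public static void main(String[] args) {
        String[] lines = {
            "nr_class 6",
            "total_sv 18",
            "rho 0.004423136341674322 -0.02055338568924989 0.03588086612165208"
                + " 0.24771746047322893 0.00710699773513259 -0.008734834466328766"
                + " 0.02297409269106355 0.24299467083662166 -0.07400614425237287"
                + " -0.0050679463881033344 0.18446534035305884 0.004123018419961004"
                + " 0.22127259896446397 -0.012677989710344693 -0.2178023679167552",
            "label 0 9 99 999 100 101",
            "nr_sv 2 2 3 3 4 4"
        };
        System.out.println(isConsistent(parse(lines))); // true: 15 rho values, 18 SVs
    }
}
```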

VI. Classification example

[Figure 6]

Dataset download
libsvm Java package download

VII. Reference links

VIII. Full code

// File: SvmTest3.java
import java.io.IOException;

public class SvmTest3 {
    public static void main(String[] args) {
        String[] arg = {"trainfile/train1.txt",   // path of the training data
                "trainfile/model_r.txt"};         // path where the trained model is saved

        String[] parg = {"trainfile/test2.txt",   // path of the test data
                "trainfile/model_r.txt",          // the model produced by training
                "trainfile/out_r.txt"};           // path where the predictions are written
        System.out.println("........SVM run started..........");
        try {
            // Optional scaling step (SvmScale prints the scaled data to stdout):
            //String[] testArgs = {"-l","0", "-u","1","-s","trainfile/trainscale.txt","trainfile/train.txt"};
            //SvmScale.main(testArgs);
            //String[] argvScaleTest ={"-r","trainfile/trainscale.txt","trainfile/train.txt"};
            //SvmScale.main(argvScaleTest);
            SvmTrain.main(arg);     // train and save the model (main is static, no instance needed)
            SvmPredict.main(parg);  // classify the test data with the saved model
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}

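/** Scaling usage example
 * String[] testArgs = {"-l","0", "-u","1","-s","chao-test-scale","UCI-breast-cancer-tra"};
 * SvmScale.main(testArgs);
 * String[] argvScaleTest = {"-r","chao-test-scale","UCI-breast-cancer-test"};
 * SvmScale.main(argvScaleTest);
 *
 * svm_scale has no option for writing the scaled data to a file; it prints to stdout,
 * so redirect the output on the command line. Save the scaling parameters with -s on
 * the training set and restore them with -r on the test set:
 * java svm_scale -s chao-test-scale train > train.scale
 * java svm_scale -r chao-test-scale test > test.scale
 */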
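// File: SvmTrain.java
import libsvm.*;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.StringTokenizer;
import java.util.Vector;

public class SvmTrain {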
    private svm_parameter param;        // set by parse_command_line
    private svm_problem prob;       // set by read_problem
    private svm_model model;
    private String input_file_name;     // set by parse_command_line
    private String model_file_name;     // set by parse_command_line
    private String error_msg;
    private int cross_validation;
    private int nr_fold;

    private static svm_print_interface svm_print_null = new svm_print_interface()
    {
        public void print(String s) {}
    };

    private static void exit_with_help()
    {
        System.out.print(
                "Usage: svm_train [options] training_set_file [model_file]\n"
                        +"options:\n"
                        +"-s svm_type : set type of SVM (default 0)\n"
                        +"  0 -- C-SVC      (multi-class classification)\n"
                        +"  1 -- nu-SVC     (multi-class classification)\n"
                        +"  2 -- one-class SVM\n"
                        +"  3 -- epsilon-SVR    (regression)\n"
                        +"  4 -- nu-SVR     (regression)\n"
                        +"-t kernel_type : set type of kernel function (default 2)\n"
                        +"  0 -- linear: u'*v\n"
                        +"  1 -- polynomial: (gamma*u'*v + coef0)^degree\n"
                        +"  2 -- radial basis function: exp(-gamma*|u-v|^2)\n"
                        +"  3 -- sigmoid: tanh(gamma*u'*v + coef0)\n"
                        +"  4 -- precomputed kernel (kernel values in training_set_file)\n"
                        +"-d degree : set degree in kernel function (default 3)\n"
                        +"-g gamma : set gamma in kernel function (default 1/num_features)\n"
                        +"-r coef0 : set coef0 in kernel function (default 0)\n"
                        +"-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)\n"
                        +"-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)\n"
                        +"-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)\n"
                        +"-m cachesize : set cache memory size in MB (default 100)\n"
                        +"-e epsilon : set tolerance of termination criterion (default 0.001)\n"
                        +"-h shrinking : whether to use the shrinking heuristics, 0 or 1 (default 1)\n"
                        +"-b probability_estimates : whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)\n"
                        +"-wi weight : set the parameter C of class i to weight*C, for C-SVC (default 1)\n"
                        +"-v n : n-fold cross validation mode\n"
                        +"-q : quiet mode (no outputs)\n"
        );
        System.exit(1);
    }

    private void do_cross_validation()
    {
        int i;
        int total_correct = 0;
        double total_error = 0;
        double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
        double[] target = new double[prob.l];

        libsvm.svm.svm_cross_validation(prob, param, nr_fold, target);
        if(param.svm_type == svm_parameter.EPSILON_SVR ||
                param.svm_type == svm_parameter.NU_SVR)
        {
            for(i=0;i<prob.l;i++)
            {
                double y = prob.y[i];
                double v = target[i];
                total_error += (v-y)*(v-y);
                sumv += v;
                sumy += y;
                sumvv += v*v;
                sumyy += y*y;
                sumvy += v*y;
            }
            System.out.print("Cross Validation Mean squared error = "+total_error/prob.l+"\n");
            System.out.print("Cross Validation Squared correlation coefficient = "+
                            ((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/
                                    ((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy))+"\n"
            );
        }
        else
        {
            for(i=0;i<prob.l;i++)
                if(target[i] == prob.y[i])
                    ++total_correct;
            System.out.print("Cross Validation Accuracy = "+100.0*total_correct/prob.l+"%\n");
        }
    }

    private void run(String argv[]) throws IOException
    {
        parse_command_line(argv);
        read_problem();
        error_msg = libsvm.svm.svm_check_parameter(prob, param);

        if(error_msg != null)
        {
            System.err.print("ERROR: "+error_msg+"\n");
            System.exit(1);
        }

        if(cross_validation != 0)
        {
            do_cross_validation();
        }
        else
        {
            model = libsvm.svm.svm_train(prob, param);
            libsvm.svm.svm_save_model(model_file_name, model);
        }
    }

    public static void main(String argv[]) throws IOException
    {
        SvmTrain t = new SvmTrain();
        t.run(argv);
    }

    private static double atof(String s)
    {
        double d = Double.valueOf(s).doubleValue();
        if (Double.isNaN(d) || Double.isInfinite(d))
        {
            System.err.print("NaN or Infinity in input\n");
            System.exit(1);
        }
        return(d);
    }

    private static int atoi(String s)
    {
        return Integer.parseInt(s);
    }

    private void parse_command_line(String argv[])
    {
        int i;
        svm_print_interface print_func = null;  // default printing to stdout

        param = new svm_parameter();
        // Default values. The usage text advertises RBF (-t 2) as the default kernel,
        // but this version starts from a polynomial kernel unless -t is given.
        // Alternatives tried earlier: C_SVC with RBF, and NU_SVR with POLY.
        param.svm_type = svm_parameter.C_SVC;
        param.kernel_type = svm_parameter.POLY;
        param.degree = 3;
        param.gamma = 0;    // 1/num_features
        param.coef0 = 0;
        param.nu = 0.5;
        param.cache_size = 100;
        param.C = 1;
        param.eps = 1e-3;
        param.p = 0.1;
        param.shrinking = 1;
        param.probability = 0;
        param.nr_weight = 0;
        param.weight_label = new int[0];
        param.weight = new double[0];
        cross_validation = 0;

        // parse options
        for(i=0;i<argv.length;i++)
        {
            if(argv[i].charAt(0) != '-') break;
            if(++i>=argv.length)
                exit_with_help();
            switch(argv[i-1].charAt(1))
            {
                case 's':
                    param.svm_type = atoi(argv[i]);
                    break;
                case 't':
                    param.kernel_type = atoi(argv[i]);
                    break;
                case 'd':
                    param.degree = atoi(argv[i]);
                    break;
                case 'g':
                    param.gamma = atof(argv[i]);
                    break;
                case 'r':
                    param.coef0 = atof(argv[i]);
                    break;
                case 'n':
                    param.nu = atof(argv[i]);
                    break;
                case 'm':
                    param.cache_size = atof(argv[i]);
                    break;
                case 'c':
                    param.C = atof(argv[i]);
                    break;
                case 'e':
                    param.eps = atof(argv[i]);
                    break;
                case 'p':
                    param.p = atof(argv[i]);
                    break;
                case 'h':
                    param.shrinking = atoi(argv[i]);
                    break;
                case 'b':
                    param.probability = atoi(argv[i]);
                    break;
                case 'q':
                    print_func = svm_print_null;
                    i--;
                    break;
                case 'v':
                    cross_validation = 1;
                    nr_fold = atoi(argv[i]);
                    if(nr_fold < 2)
                    {
                        System.err.print("n-fold cross validation: n must >= 2\n");
                        exit_with_help();
                    }
                    break;
                case 'w':
                    ++param.nr_weight;
                {
                    int[] old = param.weight_label;
                    param.weight_label = new int[param.nr_weight];
                    System.arraycopy(old,0,param.weight_label,0,param.nr_weight-1);
                }

                {
                    double[] old = param.weight;
                    param.weight = new double[param.nr_weight];
                    System.arraycopy(old,0,param.weight,0,param.nr_weight-1);
                }

                param.weight_label[param.nr_weight-1] = atoi(argv[i-1].substring(2));
                param.weight[param.nr_weight-1] = atof(argv[i]);
                break;
                default:
                    System.err.print("Unknown option: " + argv[i-1] + "\n");
                    exit_with_help();
            }
        }

        svm.svm_set_print_string_function(print_func);

        // determine filenames

        if(i>=argv.length)
            exit_with_help();

        input_file_name = argv[i];

        if(i<argv.length-1)
            model_file_name = argv[i+1];
        else
        {
            int p = argv[i].lastIndexOf('/');
            ++p;    // whew...
            model_file_name = argv[i].substring(p)+".model";
        }
    }

    // read in a problem (in svmlight format)

    private void read_problem() throws IOException
    {
        BufferedReader fp = new BufferedReader(new FileReader(input_file_name));
        Vector<Double> vy = new Vector<Double>();
        Vector<svm_node[]> vx = new Vector<svm_node[]>();
        int max_index = 0;

        while(true)
        {
            String line = fp.readLine();
            if(line == null) break;

            StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");

            vy.addElement(atof(st.nextToken()));
            int m = st.countTokens()/2;
            svm_node[] x = new svm_node[m];
            for(int j=0;j<m;j++)
            {
                x[j] = new svm_node();
                x[j].index = atoi(st.nextToken());
                x[j].value = atof(st.nextToken());
            }
            if(m>0) max_index = Math.max(max_index, x[m-1].index);
            vx.addElement(x);
        }

        prob = new svm_problem();
        prob.l = vy.size();
        prob.x = new svm_node[prob.l][];
        for(int i=0;i<prob.l;i++)
            prob.x[i] = vx.elementAt(i);
        prob.y = new double[prob.l];
        for(int i=0;i<prob.l;i++)
            prob.y[i] = vy.elementAt(i);

        if(param.gamma == 0 && max_index > 0)
            param.gamma = 1.0/max_index;

        if(param.kernel_type == svm_parameter.PRECOMPUTED)
            for(int i=0;i<prob.l;i++)
            {
                if (prob.x[i][0].index != 0)
                {
                    System.err.print("Wrong kernel matrix: first column must be 0:sample_serial_number\n");
                    System.exit(1);
                }
                if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index)
                {
                    System.err.print("Wrong input format: sample_serial_number out of range\n");
                    System.exit(1);
                }
            }

        fp.close();
    }
}
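read_problem expects each input line in the sparse "label index:value index:value ..." format. The tokenizing step can be sketched in isolation as below; `SparseLineParser` and `Sample` are illustrative names that replace the svm_node type so the example runs without the libsvm jar.

```java
import java.util.StringTokenizer;

/** Illustrative parser for one line of the sparse training-file format. */
final class SparseLineParser {

    /** One parsed sample: a label plus parallel arrays of feature indices and values. */
    static final class Sample {
        final double label;
        final int[] index;
        final double[] value;

        Sample(double label, int[] index, double[] value) {
            this.label = label;
            this.index = index;
            this.value = value;
        }
    }

    static Sample parse(String line) {
        // Same delimiters as read_problem: whitespace and ':' both separate tokens.
        StringTokenizer st = new StringTokenizer(line, " \t\n\r\f:");
        double label = Double.parseDouble(st.nextToken());
        int m = st.countTokens() / 2; // remaining tokens come in index/value pairs
        int[] index = new int[m];
        double[] value = new double[m];
        for (int j = 0; j < m; j++) {
            index[j] = Integer.parseInt(st.nextToken());
            value[j] = Double.parseDouble(st.nextToken());
        }
        return new Sample(label, index, value);
    }

    public static void main(String[] args) {
        Sample s = parse("1 1:0.5 3:-1.2"); // feature 2 is omitted, i.e. zero
        System.out.println(s.label + " -> " + s.index.length + " non-zero features");
    }
}
```

Features that are absent from a line are treated as zero, which is why max_index is taken from the last index seen on each line.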
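// File: SvmPredict.java
import libsvm.*;

import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.DataOutputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.util.StringTokenizer;

public class SvmPredict {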
    private static svm_print_interface svm_print_null = new svm_print_interface()
    {
        public void print(String s) {}
    };

    private static svm_print_interface svm_print_stdout = new svm_print_interface()
    {
        public void print(String s)
        {
            System.out.print(s);
        }
    };

    private static svm_print_interface svm_print_string = svm_print_stdout;

    static void info(String s)
    {
        svm_print_string.print(s);
    }

    private static double atof(String s)
    {
        return Double.valueOf(s).doubleValue();
    }

    private static int atoi(String s)
    {
        return Integer.parseInt(s);
    }

    private static void predict(BufferedReader input, DataOutputStream output, svm_model model, int predict_probability) throws IOException
    {
        int correct = 0;
        int total = 0;
        double error = 0;
        double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;

        int svm_type= libsvm.svm.svm_get_svm_type(model);
        int nr_class= libsvm.svm.svm_get_nr_class(model);
        double[] prob_estimates=null;

        if(predict_probability == 1)
        {
            if(svm_type == svm_parameter.EPSILON_SVR ||
                    svm_type == svm_parameter.NU_SVR)
            {
                SvmPredict.info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma="+ libsvm.svm.svm_get_svr_probability(model)+"\n");
            }
            else
            {
                int[] labels=new int[nr_class];
                libsvm.svm.svm_get_labels(model, labels);
                prob_estimates = new double[nr_class];
                output.writeBytes("labels");
                for(int j=0;j<nr_class;j++)
                    output.writeBytes(" "+labels[j]);
                output.writeBytes("\n");
            }
        }
        while(true)
        {
            String line = input.readLine();
            if(line == null) break;

            StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");

            double target = atof(st.nextToken());
            int m = st.countTokens()/2;
            svm_node[] x = new svm_node[m];
            for(int j=0;j<m;j++)
            {
                x[j] = new svm_node();
                x[j].index = atoi(st.nextToken());
                x[j].value = atof(st.nextToken());
            }

            double v;
            if (predict_probability==1 && (svm_type==svm_parameter.C_SVC || svm_type==svm_parameter.NU_SVC))
            {
                v = libsvm.svm.svm_predict_probability(model, x, prob_estimates);
                output.writeBytes(v+" ");
                for(int j=0;j<nr_class;j++)
                    output.writeBytes(prob_estimates[j]+" ");
                output.writeBytes("\n");
            }
            else
            {
                v = libsvm.svm.svm_predict(model, x);
                output.writeBytes(v+"\n");
            }

            if(v == target)
                ++correct;
            error += (v-target)*(v-target);
            sumv += v;
            sumy += target;
            sumvv += v*v;
            sumyy += target*target;
            sumvy += v*target;
            ++total;
        }
        if(svm_type == svm_parameter.EPSILON_SVR ||
                svm_type == svm_parameter.NU_SVR)
        {
            SvmPredict.info("Mean squared error = "+error/total+" (regression)\n");
            SvmPredict.info("Squared correlation coefficient = "+
                    ((total*sumvy-sumv*sumy)*(total*sumvy-sumv*sumy))/
                            ((total*sumvv-sumv*sumv)*(total*sumyy-sumy*sumy))+
                    " (regression)\n");
        }
        else
            SvmPredict.info("Accuracy = "+(double)correct/total*100+
                    "% ("+correct+"/"+total+") (classification)\n");
    }

    private static void exit_with_help()
    {
        System.err.print("usage: svm_predict [options] test_file model_file output_file\n"
                +"options:\n"
                +"-b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0); one-class SVM not supported yet\n"
                +"-q : quiet mode (no outputs)\n");
        System.exit(1);
    }

    public static void main(String argv[]) throws IOException
    {
        int i, predict_probability=0;
        svm_print_string = svm_print_stdout;

        // parse options
        for(i=0;i<argv.length;i++)
        {
            if(argv[i].charAt(0) != '-') break;
            ++i;
            switch(argv[i-1].charAt(1))
            {
                case 'b':
                    predict_probability = atoi(argv[i]);
                    break;
                case 'q':
                    svm_print_string = svm_print_null;
                    i--;
                    break;
                default:
                    System.err.print("Unknown option: " + argv[i-1] + "\n");
                    exit_with_help();
            }
        }
        if(i>=argv.length-2)
            exit_with_help();
        try
        {
            BufferedReader input = new BufferedReader(new FileReader(argv[i]));
            DataOutputStream output = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(argv[i+2])));
            svm_model model = libsvm.svm.svm_load_model(argv[i + 1]);
            if (model == null)
            {
                System.err.print("can't open model file "+argv[i+1]+"\n");
                System.exit(1);
            }
            if(predict_probability == 1)
            {
                if(libsvm.svm.svm_check_probability_model(model)==0)
                {
                    System.err.print("Model does not support probability estimates\n");
                    System.exit(1);
                }
            }
            else
            {
                if(svm.svm_check_probability_model(model)!=0)
                {
                    SvmPredict.info("Model supports probability estimates, but disabled in prediction.\n");
                }
            }
            predict(input,output,model,predict_probability);
            input.close();
            output.close();
        }
        catch(FileNotFoundException e)
        {
            exit_with_help();
        }
        catch(ArrayIndexOutOfBoundsException e)
        {
            exit_with_help();
        }
    }
}
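For regression models, predict reports the mean squared error and the squared correlation coefficient from running sums. The same two formulas, pulled out into a standalone sketch with made-up data; the class name is illustrative:

```java
/** Illustrative versions of the two regression metrics printed by predict. */
final class RegressionMetrics {

    /** Mean of (v - y)^2 over all samples; v = predicted, y = target. */
    static double meanSquaredError(double[] v, double[] y) {
        double error = 0.0;
        for (int i = 0; i < v.length; i++) {
            error += (v[i] - y[i]) * (v[i] - y[i]);
        }
        return error / v.length;
    }

    /** Squared Pearson correlation between predictions and targets. */
    static double squaredCorrelation(double[] v, double[] y) {
        int n = v.length;
        double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
        for (int i = 0; i < n; i++) {
            sumv += v[i];
            sumy += y[i];
            sumvv += v[i] * v[i];
            sumyy += y[i] * y[i];
            sumvy += v[i] * y[i];
        }
        double numerator = (n * sumvy - sumv * sumy) * (n * sumvy - sumv * sumy);
        double denominator = (n * sumvv - sumv * sumv) * (n * sumyy - sumy * sumy);
        return numerator / denominator;
    }

    public static void main(String[] args) {
        double[] predicted = {1.0, 2.0, 3.0};
        double[] target = {2.0, 4.0, 6.0};
        // Predictions are perfectly correlated with the targets but off in scale,
        // so the correlation is 1 while the squared error is not 0.
        System.out.println(meanSquaredError(predicted, target));
        System.out.println(squaredCorrelation(predicted, target));
    }
}
```

The example shows why both numbers are reported: a high squared correlation does not by itself mean that the errors are small.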
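// File: SvmScale.java
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Formatter;
import java.util.StringTokenizer;

public class SvmScale {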
    private String line = null;
    private double lower = -1.0;
    private double upper = 1.0;
    private double y_lower;
    private double y_upper;
    private boolean y_scaling = false;
    private double[] feature_max;
    private double[] feature_min;
    private double y_max = -Double.MAX_VALUE;
    private double y_min = Double.MAX_VALUE;
    private int max_index;
    private long num_nonzeros = 0;
    private long new_num_nonzeros = 0;

    private static void exit_with_help()
    {
        System.out.print(
                "Usage: svm-scale [options] data_filename\n"
                        +"options:\n"
                        +"-l lower : x scaling lower limit (default -1)\n"
                        +"-u upper : x scaling upper limit (default +1)\n"
                        +"-y y_lower y_upper : y scaling limits (default: no y scaling)\n"
                        +"-s save_filename : save scaling parameters to save_filename\n"
                        +"-r restore_filename : restore scaling parameters from restore_filename\n"
        );
        System.exit(1);
    }

    private BufferedReader rewind(BufferedReader fp, String filename) throws IOException
    {
        fp.close();
        return new BufferedReader(new FileReader(filename));
    }

    private void output_target(double value)
    {
        if(y_scaling)
        {
            if(value == y_min)
                value = y_lower;
            else if(value == y_max)
                value = y_upper;
            else
                value = y_lower + (y_upper-y_lower) *
                        (value-y_min) / (y_max-y_min);
        }

        System.out.print(value + " ");
    }

    private void output(int index, double value)
    {
        /* skip single-valued attribute */
        if(feature_max[index] == feature_min[index])
            return;

        if(value == feature_min[index])
            value = lower;
        else if(value == feature_max[index])
            value = upper;
        else
            value = lower + (upper-lower) *
                    (value-feature_min[index])/
                    (feature_max[index]-feature_min[index]);

        if(value != 0)
        {
            System.out.print(index + ":" + value + " ");
            new_num_nonzeros++;
        }
    }

    private String readline(BufferedReader fp) throws IOException
    {
        line = fp.readLine();
        return line;
    }

    private void run(String []argv) throws IOException
    {
        int i,index;
        BufferedReader fp = null, fp_restore = null;
        String save_filename = null;
        String restore_filename = null;
        String data_filename = null;


        for(i=0;i<argv.length;i++)
        {
            if (argv[i].charAt(0) != '-')   break;
            ++i;
            switch(argv[i-1].charAt(1))
            {
                case 'l': lower = Double.parseDouble(argv[i]);  break;
                case 'u': upper = Double.parseDouble(argv[i]);  break;
                case 'y':
                    y_lower = Double.parseDouble(argv[i]);
                    ++i;
                    y_upper = Double.parseDouble(argv[i]);
                    y_scaling = true;
                    break;
                case 's': save_filename = argv[i];  break;
                case 'r': restore_filename = argv[i];   break;
                default:
                    System.err.println("unknown option");
                    exit_with_help();
            }
        }

        if(!(upper > lower) || (y_scaling && !(y_upper > y_lower)))
        {
            System.err.println("inconsistent lower/upper specification");
            System.exit(1);
        }
        if(restore_filename != null && save_filename != null)
        {
            System.err.println("cannot use -r and -s simultaneously");
            System.exit(1);
        }

        if(argv.length != i+1) // modified by yehui
       // if(argv.length != i)
            exit_with_help();

        data_filename = argv[i];// modified by yehui
        //data_filename = argv[i-1];
        try {
            fp = new BufferedReader(new FileReader(data_filename));
        } catch (Exception e) {
            System.err.println("can't open file " + data_filename);
            System.exit(1);
        }

        /* assumption: min index of attributes is 1 */
        /* pass 1: find out max index of attributes */
        max_index = 0;

        if(restore_filename != null)
        {
            int idx, c;

            try {
                fp_restore = new BufferedReader(new FileReader(restore_filename));
            }
            catch (Exception e) {
                System.err.println("can't open file " + restore_filename);
                System.exit(1);
            }
            if((c = fp_restore.read()) == 'y')
            {
                fp_restore.readLine();
                fp_restore.readLine();
                fp_restore.readLine();
            }
            fp_restore.readLine();
            fp_restore.readLine();

            String restore_line = null;
            while((restore_line = fp_restore.readLine())!=null)
            {
                StringTokenizer st2 = new StringTokenizer(restore_line);
                idx = Integer.parseInt(st2.nextToken());
                max_index = Math.max(max_index, idx);
            }
            fp_restore = rewind(fp_restore, restore_filename);
        }

        while (readline(fp) != null)
        {
            StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");
            st.nextToken();
            while(st.hasMoreTokens())
            {
                try {
                    index = Integer.parseInt(st.nextToken());
                    max_index = Math.max(max_index, index);
                    st.nextToken();
                    num_nonzeros++;
                } catch (NumberFormatException e){
                    // report on stderr: stdout carries the scaled data
                    System.err.println(e);
                }
            }
        }

        try {
            feature_max = new double[(max_index+1)];
            feature_min = new double[(max_index+1)];
        } catch(OutOfMemoryError e) {
            System.err.println("can't allocate enough memory");
            System.exit(1);
        }

        for(i=0;i<=max_index;i++)
        {
            feature_max[i] = -Double.MAX_VALUE;
            feature_min[i] = Double.MAX_VALUE;
        }

        fp = rewind(fp, data_filename);

        /* pass 2: find out min/max value */
        while(readline(fp) != null)
        {
            int next_index = 1;
            double target;
            double value;

            StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");
            target = Double.parseDouble(st.nextToken());
            y_max = Math.max(y_max, target);
            y_min = Math.min(y_min, target);

            while (st.hasMoreTokens())
            {
                index = Integer.parseInt(st.nextToken());
                value = Double.parseDouble(st.nextToken());

                for (i = next_index; i<index; i++)
                {
                    feature_max[i] = Math.max(feature_max[i], 0);
                    feature_min[i] = Math.min(feature_min[i], 0);
                }

                feature_max[index] = Math.max(feature_max[index], value);
                feature_min[index] = Math.min(feature_min[index], value);
                next_index = index + 1;
            }

            for(i=next_index;i<=max_index;i++)
            {
                feature_max[i] = Math.max(feature_max[i], 0);
                feature_min[i] = Math.min(feature_min[i], 0);
            }
        }

        fp = rewind(fp, data_filename);

        /* pass 2.5: save/restore feature_min/feature_max */
        if(restore_filename != null)
        {
            // fp_restore rewinded in finding max_index
            int idx, c;
            double fmin, fmax;

            fp_restore.mark(2);             // for reset
            if((c = fp_restore.read()) == 'y')
            {
                fp_restore.readLine();      // pass the '\n' after 'y'
                StringTokenizer st = new StringTokenizer(fp_restore.readLine());
                y_lower = Double.parseDouble(st.nextToken());
                y_upper = Double.parseDouble(st.nextToken());
                st = new StringTokenizer(fp_restore.readLine());
                y_min = Double.parseDouble(st.nextToken());
                y_max = Double.parseDouble(st.nextToken());
                y_scaling = true;
            }
            else
                fp_restore.reset();

            if(fp_restore.read() == 'x') {
                fp_restore.readLine();      // pass the '\n' after 'x'
                StringTokenizer st = new StringTokenizer(fp_restore.readLine());
                lower = Double.parseDouble(st.nextToken());
                upper = Double.parseDouble(st.nextToken());
                String restore_line = null;
                while((restore_line = fp_restore.readLine())!=null)
                {
                    StringTokenizer st2 = new StringTokenizer(restore_line);
                    idx = Integer.parseInt(st2.nextToken());
                    fmin = Double.parseDouble(st2.nextToken());
                    fmax = Double.parseDouble(st2.nextToken());
                    if (idx <= max_index)
                    {
                        feature_min[idx] = fmin;
                        feature_max[idx] = fmax;
                    }
                }
            }
            fp_restore.close();
        }

        if(save_filename != null)
        {
            Formatter formatter = new Formatter(new StringBuilder());
            BufferedWriter fp_save = null;

            try {
                fp_save = new BufferedWriter(new FileWriter(save_filename));
            } catch(IOException e) {
                System.err.println("can't open file " + save_filename);
                System.exit(1);
            }

            if(y_scaling)
            {
                formatter.format("y\n");
                formatter.format("%.16g %.16g\n", y_lower, y_upper);
                formatter.format("%.16g %.16g\n", y_min, y_max);
            }
            formatter.format("x\n");
            formatter.format("%.16g %.16g\n", lower, upper);
            for(i=1;i<=max_index;i++)
            {
                if(feature_min[i] != feature_max[i])
                    formatter.format("%d %.16g %.16g\n", i, feature_min[i], feature_max[i]);
            }
            fp_save.write(formatter.toString());
            fp_save.close();
        }

        /* pass 3: scale */
        while(readline(fp) != null)
        {
            int next_index = 1;
            double target;
            double value;

            StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");
            target = Double.parseDouble(st.nextToken());
            output_target(target);
            while(st.hasMoreElements())
            {
                index = Integer.parseInt(st.nextToken());
                value = Double.parseDouble(st.nextToken());
                for (i = next_index; i<index; i++)
                    output(i, 0);
                output(index, value);
                next_index = index + 1;
            }

            for(i=next_index;i<= max_index;i++)
                output(i, 0);
            System.out.print("\n");
        }
        if (new_num_nonzeros > num_nonzeros)
            System.err.print(
                    "WARNING: original #nonzeros " + num_nonzeros+"\n"
                            +"         new      #nonzeros " + new_num_nonzeros+"\n"
                            +"Use -l 0 if many original feature values are zeros\n");

        fp.close();
    }

    public static void main(String argv[]) throws IOException
    {
        SvmScale s = new SvmScale();
        s.run(argv);
    }
}
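
For reference, the scaling pass above maps each feature linearly into [lower, upper], using the per-feature minimum and maximum gathered earlier or restored from the range file. Below is a minimal sketch of that mapping, assuming the usual min-max formula. ScaleSketch and scale are hypothetical names used only for illustration and are not part of libsvm.

```java
final class ScaleSketch {
    /**
     * Maps value from [fmin, fmax] to [lower, upper].
     * Assumes fmin < fmax; features with fmin == fmax carry no
     * information and are expected to be skipped by the caller.
     */
    static double scale(double value, double fmin, double fmax,
                        double lower, double upper) {
        if (value == fmin) return lower;   // exact end points, no rounding error
        if (value == fmax) return upper;
        return lower + (upper - lower) * (value - fmin) / (fmax - fmin);
    }

    public static void main(String[] args) {
        // 5 is the midpoint of [0, 10], so it maps to the midpoint of [-1, 1]
        System.out.println(scale(5.0, 0.0, 10.0, -1.0, 1.0)); // prints 0.0
    }
}
```

Saving the ranges and restoring them later lets the test set be scaled with exactly the same fmin and fmax as the training set.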
public class SvmToy extends Applet {
    static final String DEFAULT_PARAM="-t 2 -c 100";
    int XLEN;
    int YLEN;

    // off-screen buffer

    Image buffer;
    Graphics buffer_gc;

    // pre-allocated colors

    final static Color colors[] =
            {
                    new Color(0,0,0),
                    new Color(0,120,120),
                    new Color(120,120,0),
                    new Color(120,0,120),
                    new Color(0,200,200),
                    new Color(200,200,0),
                    new Color(200,0,200)
            };

    class point {
        point(double x, double y, byte value)
        {
            this.x = x;
            this.y = y;
            this.value = value;
        }
        double x, y;
        byte value;
    }

    Vector<point> point_list = new Vector<point>();
    byte current_value = 1;

    public void init()
    {
        setSize(getSize());

        final Button button_change = new Button("Change");
        Button button_run = new Button("Run");
        Button button_clear = new Button("Clear");
        Button button_save = new Button("Save");
        Button button_load = new Button("Load");
        final TextField input_line = new TextField(DEFAULT_PARAM);

        BorderLayout layout = new BorderLayout();
        this.setLayout(layout);

        Panel p = new Panel();
        GridBagLayout gridbag = new GridBagLayout();
        p.setLayout(gridbag);

        GridBagConstraints c = new GridBagConstraints();
        c.fill = GridBagConstraints.HORIZONTAL;
        c.weightx = 1;
        c.gridwidth = 1;
        gridbag.setConstraints(button_change,c);
        gridbag.setConstraints(button_run,c);
        gridbag.setConstraints(button_clear,c);
        gridbag.setConstraints(button_save,c);
        gridbag.setConstraints(button_load,c);
        c.weightx = 5;
        c.gridwidth = 5;
        gridbag.setConstraints(input_line,c);

        button_change.setBackground(colors[current_value]);

        p.add(button_change);
        p.add(button_run);
        p.add(button_clear);
        p.add(button_save);
        p.add(button_load);
        p.add(input_line);
        this.add(p,BorderLayout.SOUTH);

        button_change.addActionListener(new ActionListener()
        { public void actionPerformed (ActionEvent e)
            { button_change_clicked(); button_change.setBackground(colors[current_value]); }});

        button_run.addActionListener(new ActionListener()
        { public void actionPerformed (ActionEvent e)
            { button_run_clicked(input_line.getText()); }});

        button_clear.addActionListener(new ActionListener()
        { public void actionPerformed (ActionEvent e)
            { button_clear_clicked(); }});

        button_save.addActionListener(new ActionListener()
        { public void actionPerformed (ActionEvent e)
            { button_save_clicked(input_line.getText()); }});

        button_load.addActionListener(new ActionListener()
        { public void actionPerformed (ActionEvent e)
            { button_load_clicked(); }});

        input_line.addActionListener(new ActionListener()
        { public void actionPerformed (ActionEvent e)
            { button_run_clicked(input_line.getText()); }});

        this.enableEvents(AWTEvent.MOUSE_EVENT_MASK);
    }

    void draw_point(point p)
    {
        Color c = colors[p.value+3];

        Graphics window_gc = getGraphics();
        buffer_gc.setColor(c);
        buffer_gc.fillRect((int)(p.x*XLEN),(int)(p.y*YLEN),4,4);
        window_gc.setColor(c);
        window_gc.fillRect((int)(p.x*XLEN),(int)(p.y*YLEN),4,4);
    }

    void clear_all()
    {
        point_list.removeAllElements();
        if(buffer != null)
        {
            buffer_gc.setColor(colors[0]);
            buffer_gc.fillRect(0,0,XLEN,YLEN);
        }
        repaint();
    }

    void draw_all_points()
    {
        int n = point_list.size();
        for(int i=0;i<n;i++)
            draw_point(point_list.elementAt(i));
    }

    void button_change_clicked()
    {
        ++current_value;
        if(current_value > 3) current_value = 1;
    }

    private static double atof(String s)
    {
        return Double.parseDouble(s);
    }

    private static int atoi(String s)
    {
        return Integer.parseInt(s);
    }

    void button_run_clicked(String args)
    {
        // guard
        if(point_list.isEmpty()) return;

        svm_parameter param = new svm_parameter();

        // default values
        param.svm_type = svm_parameter.C_SVC;   // an svm type; POLY is a kernel constant and belongs in kernel_type
        param.kernel_type = svm_parameter.RBF;
        param.degree = 3;
        param.gamma = 0;
        param.coef0 = 0;
        param.nu = 0.5;
        param.cache_size = 40;
        param.C = 1;
        param.eps = 1e-3;
        param.p = 0.1;
        param.shrinking = 1;
        param.probability = 0;
        param.nr_weight = 0;
        param.weight_label = new int[0];
        param.weight = new double[0];

        // parse options
        StringTokenizer st = new StringTokenizer(args);
        String[] argv = new String[st.countTokens()];
        for(int i=0;i<argv.length;i++)
            argv[i] = st.nextToken();

        for(int i=0;i<argv.length;i++)
        {
            if(argv[i].charAt(0) != '-') break;
            if(++i>=argv.length)
            {
                System.err.print("unknown option\n");
                break;
            }
            switch(argv[i-1].charAt(1))
            {
                case 's':
                    param.svm_type = atoi(argv[i]);
                    break;
                case 't':
                    param.kernel_type = atoi(argv[i]);
                    break;
                case 'd':
                    param.degree = atoi(argv[i]);
                    break;
                case 'g':
                    param.gamma = atof(argv[i]);
                    break;
                case 'r':
                    param.coef0 = atof(argv[i]);
                    break;
                case 'n':
                    param.nu = atof(argv[i]);
                    break;
                case 'm':
                    param.cache_size = atof(argv[i]);
                    break;
                case 'c':
                    param.C = atof(argv[i]);
                    break;
                case 'e':
                    param.eps = atof(argv[i]);
                    break;
                case 'p':
                    param.p = atof(argv[i]);
                    break;
                case 'h':
                    param.shrinking = atoi(argv[i]);
                    break;
                case 'b':
                    param.probability = atoi(argv[i]);
                    break;
                case 'w':
                {
                    // grow both weight arrays by one slot, keeping earlier entries
                    ++param.nr_weight;
                    int[] old_label = param.weight_label;
                    param.weight_label = new int[param.nr_weight];
                    System.arraycopy(old_label,0,param.weight_label,0,param.nr_weight-1);

                    double[] old_weight = param.weight;
                    param.weight = new double[param.nr_weight];
                    System.arraycopy(old_weight,0,param.weight,0,param.nr_weight-1);

                    // "-wN value": N is the class label, value is its weight
                    param.weight_label[param.nr_weight-1] = atoi(argv[i-1].substring(2));
                    param.weight[param.nr_weight-1] = atof(argv[i]);
                    break;
                }
                default:
                    System.err.print("unknown option\n");
            }
        }

        // build problem
        svm_problem prob = new svm_problem();
        prob.l = point_list.size();
        prob.y = new double[prob.l];

        if(param.kernel_type == svm_parameter.PRECOMPUTED)
        {
            // precomputed kernels need a kernel matrix as input, so nothing is trained or drawn here
        }
        else if(param.svm_type == svm_parameter.EPSILON_SVR ||
                param.svm_type == svm_parameter.NU_SVR)
        {
            if(param.gamma == 0) param.gamma = 1;
            prob.x = new svm_node[prob.l][1];
            for(int i=0;i<prob.l;i++)
            {
                point p = point_list.elementAt(i);
                prob.x[i][0] = new svm_node();
                prob.x[i][0].index = 1;
                prob.x[i][0].value = p.x;
                prob.y[i] = p.y;
            }

            // build model & classify
            svm_model model = svm.svm_train(prob, param);
            svm_node[] x = new svm_node[1];
            x[0] = new svm_node();
            x[0].index = 1;
            int[] j = new int[XLEN];

            Graphics window_gc = getGraphics();
            for (int i = 0; i < XLEN; i++)
            {
                x[0].value = (double) i / XLEN;
                j[i] = (int)(YLEN*svm.svm_predict(model, x));
            }

            buffer_gc.setColor(colors[0]);
            buffer_gc.drawLine(0,0,0,YLEN-1);
            window_gc.setColor(colors[0]);
            window_gc.drawLine(0,0,0,YLEN-1);

            int p = (int)(param.p * YLEN);
            for(int i=1;i<XLEN;i++)
            {
                buffer_gc.setColor(colors[0]);
                buffer_gc.drawLine(i,0,i,YLEN-1);
                window_gc.setColor(colors[0]);
                window_gc.drawLine(i,0,i,YLEN-1);

                buffer_gc.setColor(colors[5]);
                window_gc.setColor(colors[5]);
                buffer_gc.drawLine(i-1,j[i-1],i,j[i]);
                window_gc.drawLine(i-1,j[i-1],i,j[i]);

                if(param.svm_type == svm_parameter.EPSILON_SVR)
                {
                    buffer_gc.setColor(colors[2]);
                    window_gc.setColor(colors[2]);
                    buffer_gc.drawLine(i-1,j[i-1]+p,i,j[i]+p);
                    window_gc.drawLine(i-1,j[i-1]+p,i,j[i]+p);

                    buffer_gc.setColor(colors[2]);
                    window_gc.setColor(colors[2]);
                    buffer_gc.drawLine(i-1,j[i-1]-p,i,j[i]-p);
                    window_gc.drawLine(i-1,j[i-1]-p,i,j[i]-p);
                }
            }
        }
        else
        {
            if(param.gamma == 0) param.gamma = 0.5;
            prob.x = new svm_node [prob.l][2];
            for(int i=0;i<prob.l;i++)
            {
                point p = point_list.elementAt(i);
                prob.x[i][0] = new svm_node();
                prob.x[i][0].index = 1;
                prob.x[i][0].value = p.x;
                prob.x[i][1] = new svm_node();
                prob.x[i][1].index = 2;
                prob.x[i][1].value = p.y;
                prob.y[i] = p.value;
            }

            // build model & classify
            svm_model model = svm.svm_train(prob, param);
            svm_node[] x = new svm_node[2];
            x[0] = new svm_node();
            x[1] = new svm_node();
            x[0].index = 1;
            x[1].index = 2;

            Graphics window_gc = getGraphics();
            for (int i = 0; i < XLEN; i++)
                for (int j = 0; j < YLEN ; j++) {
                    x[0].value = (double) i / XLEN;
                    x[1].value = (double) j / YLEN;
                    double d = svm.svm_predict(model, x);
                    if (param.svm_type == svm_parameter.ONE_CLASS && d<0) d=2;
                    buffer_gc.setColor(colors[(int)d]);
                    window_gc.setColor(colors[(int)d]);
                    buffer_gc.drawLine(i,j,i,j);
                    window_gc.drawLine(i,j,i,j);
                }
        }

        draw_all_points();
    }

    void button_clear_clicked()
    {
        clear_all();
    }

    void button_save_clicked(String args)
    {
        FileDialog dialog = new FileDialog(new Frame(),"Save",FileDialog.SAVE);
        dialog.setVisible(true);
        if (dialog.getFile() == null) return;   // dialog was cancelled
        String filename = dialog.getDirectory() + dialog.getFile();
        try {
            DataOutputStream fp = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(filename)));

            int svm_type = svm_parameter.C_SVC;
            int svm_type_idx = args.indexOf("-s ");
            if(svm_type_idx != -1)
            {
                StringTokenizer svm_str_st = new StringTokenizer(args.substring(svm_type_idx+2).trim());
                svm_type = atoi(svm_str_st.nextToken());
            }

            int n = point_list.size();
            if(svm_type == svm_parameter.EPSILON_SVR || svm_type == svm_parameter.NU_SVR)
            {
                for(int i=0;i<n;i++)
                {
                    point p = point_list.elementAt(i);
                    fp.writeBytes(p.y+" 1:"+p.x+"\n");
                }
            }
            else
            {
                for(int i=0;i<n;i++)
                {
                    point p = point_list.elementAt(i);
                    fp.writeBytes(p.value+" 1:"+p.x+" 2:"+p.y+"\n");
                }
            }
            fp.close();
        } catch (IOException e) { System.err.print(e); }
    }

    void button_load_clicked()
    {
        FileDialog dialog = new FileDialog(new Frame(),"Load",FileDialog.LOAD);
        dialog.setVisible(true);
        if (dialog.getFile() == null) return;   // dialog was cancelled
        String filename = dialog.getDirectory() + dialog.getFile();
        clear_all();
        try {
            BufferedReader fp = new BufferedReader(new FileReader(filename));
            String line;
            while((line = fp.readLine()) != null)
            {
                StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");
                if(st.countTokens() == 5)
                {
                    byte value = (byte)atoi(st.nextToken());
                    st.nextToken();
                    double x = atof(st.nextToken());
                    st.nextToken();
                    double y = atof(st.nextToken());
                    point_list.addElement(new point(x,y,value));
                }
                else if(st.countTokens() == 3)
                {
                    double y = atof(st.nextToken());
                    st.nextToken();
                    double x = atof(st.nextToken());
                    point_list.addElement(new point(x,y,current_value));
                }else
                    break;
            }
            fp.close();
        } catch (IOException e) { System.err.print(e); }
        draw_all_points();
    }

    protected void processMouseEvent(MouseEvent e)
    {
        if(e.getID() == MouseEvent.MOUSE_PRESSED)
        {
            if(e.getX() >= XLEN || e.getY() >= YLEN) return;
            point p = new point((double)e.getX()/XLEN,
                    (double)e.getY()/YLEN,
                    current_value);
            point_list.addElement(p);
            draw_point(p);
        }
    }

    public void paint(Graphics g)
    {
        // create buffer first time
        if(buffer == null) {
            buffer = this.createImage(XLEN,YLEN);
            buffer_gc = buffer.getGraphics();
            buffer_gc.setColor(colors[0]);
            buffer_gc.fillRect(0,0,XLEN,YLEN);
        }
        g.drawImage(buffer,0,0,this);
    }

    public Dimension getPreferredSize() { return new Dimension(XLEN,YLEN+50); }

    public void setSize(Dimension d) { setSize(d.width,d.height); }
    public void setSize(int w,int h) {
        super.setSize(w,h);
        XLEN = w;
        YLEN = h-50;
        clear_all();
    }

    public static void main(String[] argv)
    {
        new AppletFrame("svm_toy",new SvmToy(),500,500+50);
    }
}
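
The Save and Load buttons above exchange points in the sparse "label index:value" text format. Below is a minimal sketch of that round trip for the two-feature classification case. LibsvmLineSketch, parsePoint, and formatPoint are hypothetical helper names, not part of libsvm.

```java
import java.util.StringTokenizer;

final class LibsvmLineSketch {
    /** Parses a line such as "1 1:0.25 2:0.75" into {label, x, y}. */
    static double[] parsePoint(String line) {
        // ':' is treated as a delimiter, so the line splits into 5 tokens:
        // label, index, value, index, value
        StringTokenizer st = new StringTokenizer(line, " \t\n\r\f:");
        if (st.countTokens() != 5)
            throw new IllegalArgumentException("expected a label and two features: " + line);
        double label = Double.parseDouble(st.nextToken());
        st.nextToken();                                   // skip index 1
        double x = Double.parseDouble(st.nextToken());
        st.nextToken();                                   // skip index 2
        double y = Double.parseDouble(st.nextToken());
        return new double[] {label, x, y};
    }

    /** Formats one labeled 2-D point as a sparse-format line (no newline). */
    static String formatPoint(int label, double x, double y) {
        return label + " 1:" + x + " 2:" + y;
    }

    public static void main(String[] args) {
        String line = formatPoint(1, 0.25, 0.75);
        double[] p = parsePoint(line);
        System.out.println(line + " -> x=" + p[1] + ", y=" + p[2]);
    }
}
```

The regression case has a single feature per line (three tokens), which is how the load routine above distinguishes the two layouts.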
public class AppletFrame extends Frame {
    AppletFrame(String title, Applet applet, int width, int height)
    {
        super(title);
        this.addWindowListener(new WindowAdapter() {
            public void windowClosing(WindowEvent e) {
                System.exit(0);
            }
        });
        applet.init();
        applet.setSize(width,height);
        applet.start();
        this.add(applet);
        this.pack();
        this.setVisible(true);
    }
}