共涉及 3 个文件：svm-train.c、svm.cpp 和 svm.h。建议使用 Source Insight 等代码阅读工具为这 3 个文件建立工程，以便阅读代码。下面从 svm-train.c 文件中的 main() 函数切入。
- int main(int argc, <span
class="keyword">char **argv)
- {
- char input_file_name[1024]; //训练样本文件名
- char model_file_name[1024]; //输出模型的文件名
- const char *error_msg;
- parse_command_line(argc, argv, input_file_name, model_file_name); //解析运行程序时,命令行输入的参数
- read_problem(input_file_name); //读入训练样本,存入到struct svm_problem prob结构体中
- error_msg = svm_check_parameter(&prob,¶m); //检查训练样本数据格式是否正确
- if(error_msg)
- {
- fprintf(stderr,“ERROR: %s\n”,error_msg);
- exit(1);
- }
- if(cross_validation)
- {
- do_cross_validation(); //根据设置进行交叉验证训练
- }
- else
- {
- model = svm_train(&prob,¶m); //根据问题数据(&prob)和参数(¶m)训练模型
- if(svm_save_model(model_file_name,model))//保存模型到输出
文件中
- {
- fprintf(stderr, “can’t save model to file %s\n”, model_file_name);
- exit(1);
- }
- svm_free_and_destroy_model(&model); //释放模型结构空间
- }
- svm_destroy_param(¶m); //释放使用的其他结构空间
- free(prob.y);
- free(prob.x);
- free(x_space);
- free(line);
- return 0;
- }
下面分析 main() 函数中调用的主要函数。首先是命令行参数解析函数 parse_command_line()，其代码及注释如下：
- void parse_command_line(int argc, <span
class="keyword">char **argv, char *input_file_name,char
*model_file_name)
- {
- int i;
- void (*print_func)(const <span
class="keyword">char*) = NULL; // default printing to stdout
- // default values
- param.svm_type = C_SVC;
- param.kernel_type = RBF;
- param.degree = 3;
- param.gamma = 0; // 1/num_features
- param.coef0 = 0;
- param.nu = 0.5;
- param.cache_size = 100;
- param.C = 1;
- param.eps = 1e-3;
- param.p = 0.1;
- param.shrinking = 1;
- param.probability = 0;
- param.nr_weight = 0;
- param.weight_label = NULL;
- param.weight = NULL;
- cross_validation = 0;
- // parse options
- for(i=1;i<argc;i++) //argc中存放的是命令行程序运行时的参数
个数
- {
- if(argv[i][0] != ‘-’) break; <span
class="comment">//开头处是否为参数类型标识,若不是跳出循环
- if(++i>=argc) //判断参数类型后是否有其他参数,如样本文件名
- exit_with_help(); //如果没有则退出并打印帮助提示
- switch(argv[i-1][1]) //根据参数标识,转换参数值为正确类型或相应设置
- {
- case ‘s’:
- param.svm_type = atoi(argv[i]);
- break;
- case ‘t’:
- param.kernel_type = atoi(argv[i]);
- break;
- case ‘d’:
- param.degree = atoi(argv[i]);
- break;
- case ‘g’:
- param.gamma = atof(argv[i]);
- break;
- case ‘r’:
- param.coef0 = atof(argv[i]);
- break;
- case ‘n’:
- param.nu = atof(argv[i]);
- break;
- case ‘m’:
- param.cache_size = atof(argv[i]);
- break;
- case ‘c’:
- param.C = atof(argv[i]);
- break;
- case ‘e’:
- param.eps = atof(argv[i]);
- break;
- case ‘p’:
- param.p = atof(argv[i]);
- break;
- case ‘h’:
- param.shrinking = atoi(argv[i]);
- break;
- case ‘b’:
- param.probability = atoi(argv[i]);
- break;
- case ‘q’:
- print_func = &print_null;
- i–;
- break;
- case ‘v’: //设置交叉验证的参数标识
- cross_validation = 1;
- nr_fold = atoi(argv[i]);
- if(nr_fold < 2)
- {
- fprintf(stderr,“n-fold cross validation: n must >= 2\n”);
- exit_with_help();
- }
- break;
- case ‘w’:
- ++param.nr_weight;
- param.weight_label = (int*)realloc(param.weight_label,<span
class="keyword">sizeof(int)*param.nr_weight);
- param.weight = (double*)realloc(param.weight,<span
class="keyword">sizeof(double)*param.nr_weight);
- param.weight_label[param.nr_weight-1] = atoi(&argv[i-1][2]);
- param.weight[param.nr_weight-1] = atof(argv[i]);
- break;
- default:
- fprintf(stderr,“Unknown option: -%c\n”, argv[i-1][1]);
- exit_with_help();
- }
- }
- svm_set_print_string_function(print_func);
- // determine filenames
- if(i>=argc)
- exit_with_help();
- strcpy(input_file_name, argv[i]); //将命令行中的训练文件名,赋值给main中的字符数组.
- if(i<argc-1) //如果自定义了输出模型名,则赋值给变量,否则使用默认命名方式
构造文件名
- strcpy(model_file_name,argv[i+1]);
- else
- {
- char *p = strrchr(argv[i],’/');
- if(p==NULL)
- p = argv[i];
- else
- ++p;
- sprintf(model_file_name,“%s.model”,p);
- }
- }