darknet分类源码解析

void validate_classifier_multi(char *datacfg, char *filename, char *weightfile)
{
    int i, j;
    network net = parse_network_cfg(filename);
    set_batch_network(&net, 1);
    if(weightfile){
        load_weights(&net, weightfile);
    }
    srand(time(0));

    list *options = read_data_cfg(datacfg);//读.data文件到option列表中

    char *label_list = option_find_str(options, "labels", "data/labels.list");
	//从读到的.data生成的option列表去找对饮的字段如labels,将labels的配置路径放到label_list指针中,
	//然后如果labels的配置路径是"data/labels.list",打印“使用默认配置”字样
    char *valid_list = option_find_str(options, "valid", "data/train.list");// l,key,def;  return  def
    int classes = option_find_int(options, "classes", 2);
    int topk = option_find_int(options, "top", 1);
    if (topk > classes) topk = classes;//找的比类别还多

    char **labels = get_labels(label_list);
	//将labels.list标签名读到lables字符指针,可以通过labels[i]访问标签
    list *plist = get_paths(valid_list);//得到验证集的数据路径
    int scales[] = {224, 288, 320, 352, 384};
    int nscales = sizeof(scales)/sizeof(scales[0]);

    char **paths = (char **)list_to_array(plist);
    int m = plist->size;
    free_list(plist);

    float avg_acc = 0;
    float avg_topk = 0;
    int* indexes = (int*)calloc(topk, sizeof(int));

    for(i = 0; i < m; ++i){
        int class_id = -1;//一般用负数初始化
        char *path = paths[i];//这里的路径名包括文件名之外的路径吗?
        for(j = 0; j < classes; ++j){
            if(strstr(path, labels[j])){
				//在path字符串中查找labels[j]字符串第一次出现的位置
                class_id = j;
				//这里实现了数据集在训练过程中的类别的确定。还是看匹配,只要标签在文件名中
                break;
            }
        }
        float* pred = (float*)calloc(classes, sizeof(float));
        image im = load_image_color(paths[i], 0, 0);
        for(j = 0; j < nscales; ++j){
            image r = resize_min(im, scales[j]);
            resize_network(&net, r.w, r.h);
            float *p = network_predict(net, r.data);
            if(net.hierarchy) hierarchy_predictions(p, net.outputs, net.hierarchy, 1);
            axpy_cpu(classes, 1, p, 1, pred, 1);
            flip_image(r);
            p = network_predict(net, r.data);
            axpy_cpu(classes, 1, p, 1, pred, 1);
            if(r.data != im.data) free_image(r);
        }
        free_image(im);
        top_k(pred, classes, topk, indexes);
        free(pred);
        if(indexes[0] == class_id) avg_acc += 1;
        for(j = 0; j < topk; ++j){
            if(indexes[j] == class_id) avg_topk += 1;
        }

        printf("%d: top 1: %f, top %d: %f\n", i, avg_acc/(i+1), topk, avg_topk/(i+1));
    }
}

 

void predict_classifier(char *datacfg, char *cfgfile, char *weightfile, char *filename, int top)
{//反初始化主要是类对象的析构
    network net = parse_network_cfg_custom(cfgfile, 1, 0);
    if(weightfile){
        load_weights(&net, weightfile);
    }
    set_batch_network(&net, 1);
    srand(2222222);

    fuse_conv_batchnorm(net);
    calculate_binary_weights(net);

    list *options = read_data_cfg(datacfg);

    char *name_list = option_find_str(options, "names", 0);
    if(!name_list) name_list = option_find_str(options, "labels", "data/labels.list");
    int classes = option_find_int(options, "classes", 2);
    if (top == 0) top = option_find_int(options, "top", 1);
    if (top > classes) top = classes;

    int i = 0;
    char **names = get_labels(name_list);
    clock_t time;
    int* indexes = (int*)calloc(top, sizeof(int));
    char buff[256];
    char *input = buff;
    //int size = net.w;
    while(1){
        if(filename){
            strncpy(input, filename, 256);//将filename的前256个字符复制到input中。
        }else{
            printf("Enter Image Path: ");
            fflush(stdout);
            input = fgets(input, 256, stdin);
            if(!input) return;
            strtok(input, "\n");
        }
        image im = load_image_color(input, 0, 0);
        image r = letterbox_image(im, net.w, net.h);
        //image r = resize_min(im, size);
        //resize_network(&net, r.w, r.h);
        printf("%d %d\n", r.w, r.h);

        float *X = r.data;
        time=clock();
        float *predictions = network_predict(net, X);
        if(net.hierarchy) hierarchy_predictions(predictions, net.outputs, net.hierarchy, 0);
        top_k(predictions, net.outputs, top, indexes);
		//按得分来排top k,indexes是新的排序指针,按升序排列,prediction越大的在indexes里面的id越是靠后。
        printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
        for(i = 0; i < top; ++i){
            int index = indexes[i];
			//hierarchy是一个树形结构体指针变量。应该是没有的。
            if(net.hierarchy) printf("%d, %s: %f, parent: %s \n",index, names[index], predictions[index], (net.hierarchy->parent[index] >= 0) ? names[net.hierarchy->parent[index]] : "Root");
            else printf("%s: %f\n",names[index], predictions[index]);
			//names[index]是分类的对应的类别名称如yb,ye,yf
			//predictions[index]是推理置信度
        }
        if(r.data != im.data) free_image(r);
        free_image(im);
        if (filename) break;//可以批量测试,如果filename是False,跳出
    }
}

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值