darknet源码学习(一) 入口及测试总体流程

        最近在使用YOLO fastest训练检测模型,之前听过很多次YOLO这个系列,跑过相关测试demo及现成的训练代码,但对里面的细节几乎一无所知。这次在对自己的数据格式进行转换调整后,进行训练,发现检测效果及性能都很好,在iPhone6s及以上机型可以轻松做到实时。但在训练的过程中,日志的输出却存在一些异常的地方。后来跟一些技术群里面的同学讨论过,应该是代码部分存在点小问题,所以需要研究下代码的细节。

        darknet的主函数在darknet.c文件中,入口main函数根据命令行输入参数不同进入不同的功能函数。

if (0 == strcmp(argv[1], "average")){
        average(argc, argv);
    } else if (0 == strcmp(argv[1], "yolo")){
        run_yolo(argc, argv);
    } else if (0 == strcmp(argv[1], "voxel")){
        run_voxel(argc, argv);
    } else if (0 == strcmp(argv[1], "super")){
        run_super(argc, argv);
    } else if (0 == strcmp(argv[1], "detector")){
        run_detector(argc, argv);
    } else if (0 == strcmp(argv[1], "detect")){
        float thresh = find_float_arg(argc, argv, "-thresh", .24);
        int ext_output = find_arg(argc, argv, "-ext_output");
        char *filename = (argc > 4) ? argv[4]: 0;
        test_detector("cfg/coco.data", argv[2], argv[3], filename, thresh, 0.5, 0, ext_output, 0, NULL, 0, 0);
    } else if (0 == strcmp(argv[1], "cifar")){
        run_cifar(argc, argv);
    } else if (0 == strcmp(argv[1], "go")){
        run_go(argc, argv);
    } else if (0 == strcmp(argv[1], "rnn")){
        run_char_rnn(argc, argv);
    } else if (0 == strcmp(argv[1], "vid")){
        run_vid_rnn(argc, argv);
    } else if (0 == strcmp(argv[1], "coco")){
        run_coco(argc, argv);
    } else if (0 == strcmp(argv[1], "classify")){
        predict_classifier("cfg/imagenet1k.data", argv[2], argv[3], argv[4], 5);
    } else if (0 == strcmp(argv[1], "classifier")){
        run_classifier(argc, argv);
    } else if (0 == strcmp(argv[1], "art")){
        run_art(argc, argv);
    } else if (0 == strcmp(argv[1], "tag")){
        run_tag(argc, argv);
    } else if (0 == strcmp(argv[1], "compare")){
        run_compare(argc, argv);
    } else if (0 == strcmp(argv[1], "dice")){
        run_dice(argc, argv);
    } else if (0 == strcmp(argv[1], "writing")){
        run_writing(argc, argv);
    } else if (0 == strcmp(argv[1], "3d")){
        composite_3d(argv[2], argv[3], argv[4], (argc > 5) ? atof(argv[5]) : 0);
    } else if (0 == strcmp(argv[1], "test")){
        test_resize(argv[2]);
    } else if (0 == strcmp(argv[1], "captcha")){
        run_captcha(argc, argv);
    } else if (0 == strcmp(argv[1], "nightmare")){
        run_nightmare(argc, argv);
    } else if (0 == strcmp(argv[1], "rgbgr")){
        rgbgr_net(argv[2], argv[3], argv[4]);
    } else if (0 == strcmp(argv[1], "reset")){
        reset_normalize_net(argv[2], argv[3], argv[4]);
    } else if (0 == strcmp(argv[1], "denormalize")){
        denormalize_net(argv[2], argv[3], argv[4]);
    } else if (0 == strcmp(argv[1], "statistics")){
        statistics_net(argv[2], argv[3]);
    } else if (0 == strcmp(argv[1], "normalize")){
        normalize_net(argv[2], argv[3], argv[4]);
    } else if (0 == strcmp(argv[1], "rescale")){
        rescale_net(argv[2], argv[3], argv[4]);
    } else if (0 == strcmp(argv[1], "ops")){
        operations(argv[2]);
    } else if (0 == strcmp(argv[1], "speed")){
        speed(argv[2], (argc > 3 && argv[3]) ? atoi(argv[3]) : 0);
    } else if (0 == strcmp(argv[1], "oneoff")){
        oneoff(argv[2], argv[3], argv[4]);
    } else if (0 == strcmp(argv[1], "partial")){
        partial(argv[2], argv[3], argv[4], atoi(argv[5]));
    } else if (0 == strcmp(argv[1], "visualize")){
        visualize(argv[2], (argc > 3) ? argv[3] : 0);
    } else if (0 == strcmp(argv[1], "imtest")){
        test_resize(argv[2]);
    } else {
        fprintf(stderr, "Not an option: %s\n", argv[1]);
    }

YOLO fastest测试命令:

./darknet detector  test ./data/voc.data  ./Yolo-Fastest/VOC/yolo-fastest-xl.cfg  ./Yolo-Fastest/VOC/yolo-fastest-xl.weights  data/dog.jpg  -thresh 0.55

训练命令:

./darknet detector train voc.data yolo-fastest.cfg yolo-fastest.conv.109 

由于argv[1]==“detector”,所以进入的分支函数是run_detector(),并将argv参数一并传入。

在run_detector()函数内,又会根据参数argv[2]的不同选择执行不同的分支:

if (0 == strcmp(argv[2], "test")) test_detector(datacfg, cfg, weights, filename, thresh, hier_thresh, dont_show, ext_output, save_labels, outfile, letter_box, benchmark_layers);
    else if (0 == strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, gpus, ngpus, clear, dont_show, calc_map, mjpeg_port, show_imgs, benchmark_layers, chart_path);
    else if (0 == strcmp(argv[2], "valid")) validate_detector(datacfg, cfg, weights, outfile);
    else if (0 == strcmp(argv[2], "recall")) validate_detector_recall(datacfg, cfg, weights);
    else if (0 == strcmp(argv[2], "map")) validate_detector_map(datacfg, cfg, weights, thresh, iou_thresh, map_points, letter_box, NULL);
    else if (0 == strcmp(argv[2], "calc_anchors")) calc_anchors(datacfg, num_of_clusters, width, height, show);
    else if (0 == strcmp(argv[2], "draw")) {
        int it_num = 100;
        draw_object(datacfg, cfg, weights, filename, thresh, dont_show, it_num, letter_box, benchmark_layers);
    }
    else if (0 == strcmp(argv[2], "demo")) {
        list *options = read_data_cfg(datacfg);
        int classes = option_find_int(options, "classes", 20);
        char *name_list = option_find_str(options, "names", "data/names.list");
        char **names = get_labels(name_list);
        if (filename)
            if (strlen(filename) > 0)
                if (filename[strlen(filename) - 1] == 0x0d) filename[strlen(filename) - 1] = 0;
        demo(cfg, weights, thresh, hier_thresh, cam_index, filename, names, classes, avgframes, frame_skip, prefix, out_filename,
            mjpeg_port, dontdraw_bbox, json_port, dont_show, ext_output, letter_box, time_limit_sec, http_post_host, benchmark, benchmark_layers);

        free_list_contents_kvp(options);
        free_list(options);
    }
    else printf(" There isn't such command: %s", argv[2]);

当第二个参数为test时,执行test_detector(),在这个函数内部:

1.network net = parse_network_cfg_custom(cfgfile, 1, 1);

建立网络,cfgfile为模型配置文件,batch设置为1(大于0),用来表明为detection模式分配内存,同时也代表了batch的数量,time_step设置为1,解释下这个参数,darknet是一个开源深度学习框架,它不仅仅支持yolov这一系列的目标检测,同时还可以分类,分割甚至RNN,LSTM等。只是后面这些功能用的人比较少,而且time_steps是RNN中的概念,即在更新梯度时,不仅仅考虑当前的输入的一个batch,而且还考虑前面time_steps个batch,所以才在后面又对net->batch做了乘time_steps的运算。由于这里是目标检测,不用考虑以前的历史数据,所以time_steps设置为1.

2.load_weights(&net, weightfile)

对建立好的网络结构加载权重文件。

3.load_image(input, 0, 0, net.c)

读取图像input为图像地址,net.c为输入层通道数,然后根据参数,选择缩放的方式:

if(letter_box) 
        sized = letterbox_image(im, net.w, net.h);
else 
        sized = resize_image(im, net.w, net.h);
4.float *X = sized.data;
network_predict(net, X)

模型进行推理,X为输入图像数据

5.detection *dets = get_network_boxes(&net, im.w, im.h, thresh, hier_thresh, 0, 1, &nboxes, letter_box)

通过网络提取出检测到的目标的位置以及类别。

6.draw_detections_v3(im, dets, nboxes, thresh, names, alphabet, l.classes, ext_output)

将目标的位置以及类别标注在图片中。

后面还有一些保存结果及释放内存的操作,至此一个完整的推理过程结束。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值