- <span style="font-size:18px;">darknet.c
- 首先就是各种头文件
- /
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include "parser.h"
#include "utils.h"
#include "cuda.h"
#include "blas.h"
#include "connected_layer.h"
#ifdef OPENCV
#include "opencv2/highgui/highgui_c.h"
#endif
- /
- 各种接口函数的声明（函数体此处省略），供下面的 main 函数调用：
/* Forward declarations of helpers that main() dispatches to below.
 * Their definitions are omitted from this excerpt. */
void change_rate(char *filename, float scale, float add);
void average(int argc, char *argv[]);
void speed(char *cfgfile, int tics);
void operations(char *cfgfile);
/* ... further declarations omitted in this excerpt ... */
- /
- 最主要的main函数
- int main(int argc, char **argv)
- {
- if(argc < 2){
- fprintf(stderr, "usage: %s <function>\n", argv[0]);
- return 0;
- }//参数小于2直接输出提示
- gpu_index = find_int_arg(argc, argv, "-i", 0);
- if(find_arg(argc, argv, "-nogpu")) {
- gpu_index = -1;
- }//设置无gpu格式
- #ifndef GPU
- gpu_index = -1;
- #else
- if(gpu_index >= 0){
- cuda_set_device(gpu_index);
- }
- #endif
- //输入选项
- if (0 == strcmp(argv[1], "average")){
- average(argc, argv);
- } else if (0 == strcmp(argv[1], "yolo")){
- run_yolo(argc, argv);//从这里跳转出去,执行yolo----------------
- } else if (0 == strcmp(argv[1], "voxel")){
- run_voxel(argc, argv);
- } else if (0 == strcmp(argv[1], "super")){
- run_super(argc, argv);
- } else if (0 == strcmp(argv[1], "detector")){
- run_detector(argc, argv);
- } else if (0 == strcmp(argv[1], "cifar")){
- run_cifar(argc, argv);
- } else if (0 == strcmp(argv[1], "go")){
- run_go(argc, argv);
- } else if (0 == strcmp(argv[1], "rnn")){
- run_char_rnn(argc, argv);
- } else if (0 == strcmp(argv[1], "vid")){
- run_vid_rnn(argc, argv);
- } else if (0 == strcmp(argv[1], "coco")){
- run_coco(argc, argv);
- } else if (0 == strcmp(argv[1], "classifier")){
- run_classifier(argc, argv);
- } else if (0 == strcmp(argv[1], "art")){
- run_art(argc, argv);
- } else if (0 == strcmp(argv[1], "tag")){
- run_tag(argc, argv);
- } else if (0 == strcmp(argv[1], "compare")){
- run_compare(argc, argv);
- } else if (0 == strcmp(argv[1], "dice")){
- run_dice(argc, argv);
- } else if (0 == strcmp(argv[1], "writing")){
- run_writing(argc, argv);
- } else if (0 == strcmp(argv[1], "3d")){
- composite_3d(argv[2], argv[3], argv[4], (argc > 5) ? atof(argv[5]) : 0);
- } else if (0 == strcmp(argv[1], "test")){
- test_resize(argv[2]);
- } else if (0 == strcmp(argv[1], "captcha")){
- run_captcha(argc, argv);
- } else if (0 == strcmp(argv[1], "nightmare")){
- run_nightmare(argc, argv);
- } else if (0 == strcmp(argv[1], "change")){
- change_rate(argv[2], atof(argv[3]), (argc > 4) ? atof(argv[4]) : 0);
- } else if (0 == strcmp(argv[1], "rgbgr")){
- rgbgr_net(argv[2], argv[3], argv[4]);
- } else if (0 == strcmp(argv[1], "reset")){
- reset_normalize_net(argv[2], argv[3], argv[4]);
- } else if (0 == strcmp(argv[1], "denormalize")){
- denormalize_net(argv[2], argv[3], argv[4]);
- } else if (0 == strcmp(argv[1], "statistics")){
- statistics_net(argv[2], argv[3]);
- } else if (0 == strcmp(argv[1], "normalize")){
- normalize_net(argv[2], argv[3], argv[4]);
- } else if (0 == strcmp(argv[1], "rescale")){
- rescale_net(argv[2], argv[3], argv[4]);
- } else if (0 == strcmp(argv[1], "ops")){
- operations(argv[2]);
- } else if (0 == strcmp(argv[1], "speed")){
- speed(argv[2], (argc > 3) ? atoi(argv[3]) : 0);
- } else if (0 == strcmp(argv[1], "partial")){
- partial(argv[2], argv[3], argv[4], atoi(argv[5]));
- } else if (0 == strcmp(argv[1], "average")){
- average(argc, argv);
- } else if (0 == strcmp(argv[1], "visualize")){
- visualize(argv[2], (argc > 3) ? argv[3] : 0);
- } else if (0 == strcmp(argv[1], "imtest")){
- test_resize(argv[2]);
- } else {
- fprintf(stderr, "Not an option: %s\n", argv[1]);
- }
- return 0;
- }
- 在 yolo.c 中找到 run_yolo 函数（以下为 yolo.c 中相关函数的声明与实现）：
- void train_yolo(char *cfgfile, char *weightfile)
- void print_yolo_detections(FILE **fps, char *id, box *boxes, float **probs, int total, int classes, int w, int h)
- void validate_yolo(char *cfgfile, char *weightfile)
- void validate_yolo_recall(char *cfgfile, char *weightfile)
- void test_yolo(char *cfgfile, char *weightfile, char *filename, float thresh)
- void run_yolo(int argc, char **argv)
- {
- char *prefix = find_char_arg(argc, argv, "-prefix", 0);
- float thresh = find_float_arg(argc, argv, "-thresh", .2);
- int cam_index = find_int_arg(argc, argv, "-c", 0);
- int frame_skip = find_int_arg(argc, argv, "-s", 0);
- //提取输入参数,4个,格式如下所示
- if(argc < 4){
- fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
- return;
- }
- char *cfg = argv[3];
- char *weights = (argc > 4) ? argv[4] : 0;
- char *filename = (argc > 5) ? argv[5]: 0;
- if(0==strcmp(argv[2], "test")) test_yolo(cfg, weights, filename, thresh);
- //yolo测试图片
- else if(0==strcmp(argv[2], "train")) train_yolo(cfg, weights);
- else if(0==strcmp(argv[2], "valid")) validate_yolo(cfg, weights);
- else if(0==strcmp(argv[2], "recall")) validate_yolo_recall(cfg, weights);
- else if(0==strcmp(argv[2], "demo")) demo(cfg, weights, thresh, cam_index, filename, voc_names, 20, frame_skip, prefix);
- //yolo测试webcam demo
- }
- void test_yolo(char *cfgfile, char *weightfile, char *filename, float thresh)
- {
- image *alphabet = load_alphabet();
- network net = parse_network_cfg(cfgfile);
- if(weightfile){
- load_weights(&net, weightfile);
- }
- //加载网络权重
- detection_layer l = net.layers[net.n-1];
- set_batch_network(&net, 1);
- //设置网络
- srand(2222222);
- clock_t time;
- char buff[256];
- char *input = buff;
- int j;
- float nms=.4;
- box *boxes = calloc(l.side*l.side*l.n, sizeof(box));
- float **probs = calloc(l.side*l.side*l.n, sizeof(float *));
- for(j = 0; j < l.side*l.side*l.n; ++j) probs[j] = calloc(l.classes, sizeof(float *));
- while(1){
- if(filename){
- strncpy(input, filename, 256);
- } else {
- printf("Enter Image Path: ");
- fflush(stdout);
- input = fgets(input, 256, stdin);
- if(!input) return;
- strtok(input, "\n");
- }
- image im = load_image_color(input,0,0);
- image sized = resize_image(im, net.w, net.h);
- float *X = sized.data;
- time=clock();//计时
- network_predict(net, X);//预测
- printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
- get_detection_boxes(l, 1, 1, thresh, probs, boxes, 0);
- if (nms) do_nms_sort(boxes, probs, l.side*l.side*l.n, l.classes, nms);
- draw_detections(im, l.side*l.side*l.n, thresh, boxes, probs, voc_names, alphabet, 20);//在原图像中画出边界框
- save_image(im, "predictions");//保存显示图像
- show_image(im, "predictions");
- free_image(im);
- free_image(sized);
- #ifdef OPENCV
- cvWaitKey(0);
- cvDestroyAllWindows();
- #endif
- if (filename) break;
- }
- }
- </span>
yolo测试代码梳理
最新推荐文章于 2024-08-30 13:28:37 发布