darknet源码剖析(三)

进入train_detector函数。

list *options = read_data_cfg(datacfg);

list的定义位于darknet.h,为链表。read_data_cfg位于option_list.c中。

read_data_cfg的作用在于将数据集配置转化为链表。

    char *train_images = option_find_str(options, "train", "data/train.list");
    char *backup_directory = option_find_str(options, "backup", "/backup/");

option_find_str用于寻找选项指定的内容,若没有指定则使用默认值。在当前程序执行环境中“train”与“backup”均已指定。

    srand(time(0));
    char *base = basecfg(cfgfile);
    printf("%s\n", base);
    float avg_loss = -1;
    network **nets = calloc(ngpus, sizeof(network));

其中比较重要的basecfg(cfgfile),这一句的作用在于提取模型文件的名称(“.”之前的部分)。在当前程序中是“yolov3-voc”。

network **nets = calloc(ngpus, sizeof(network));用于为网络分配内存空间。network的定义为darknet.h中。为每个gpu分配一个network。

    srand(time(0));
    int seed = rand();
    int i;
    for(i = 0; i < ngpus; ++i){
        srand(seed);
#ifdef GPU
        cuda_set_device(gpus[i]);
#endif
        nets[i] = load_network(cfgfile, weightfile, clear);
        nets[i]->learning_rate *= ngpus;
    }
    srand(time(0));
    network *net = nets[0];

若在makefile中设置了GPU,则执行cuda_set_device,该函数位于cuda.c文件中。

load_network函数位于network.c文件中。

network *load_network(char *cfg, char *weights, int clear)
{
    network *net = parse_network_cfg(cfg);
    if(weights && weights[0] != 0){
        load_weights(net, weights);
    }
    if(clear) (*net->seen) = 0;
    return net;
}

parse_network_cfg函数用于解析模型配置文件,load_weights函数用于加载预训练参数。在此不详细分析,load_network函数执行完毕后返回yolov3模型。

learning_rate*ngpus的作用暂不清楚。

    srand(time(0));
    network *net = nets[0];

    int imgs = net->batch * net->subdivisions * ngpus;
    printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
    data train, buffer;

imgs可能是图片的总张数。

    layer l = net->layers[net->n - 1];

    int classes = l.classes;
    float jitter = l.jitter;

net->n代表网络的总层数,net->layers[net->n-1]代表网络最后一层,net从0开始计数。

layer l的内容结合yolov3-voc.cfg的内容可以知道,

[yolo]
mask = 0,1,2
anchors = 10,13,  16,30,  33,23,  30,61,  62,45,  59,119,  116,90,  156,198,  373,326
classes=20
num=9
jitter=.3
ignore_thresh = .5
truth_thresh = 1
random=1

因此classes为20,jitter为0.3。

    list *plist = get_paths(train_images);
    //int N = plist->size;
    char **paths = (char **)list_to_array(plist);

get_paths将train_images文件中的训练数据转化为list,train_images文件为txt格式,存储的是所有训练数据存储的地址。

list_to_array将list转化为二维字符矩阵,用于存储训练数据存储的地址。

    load_args args = get_base_args(net);
    args.coords = l.coords;
    args.paths = paths;
    args.n = imgs;
    args.m = plist->size;
    args.classes = classes;
    args.jitter = jitter;
    args.num_boxes = l.max_boxes;
    args.d = &buffer;
    args.type = DETECTION_DATA;
    //args.type = INSTANCE_DATA;
    args.threads = 64;

上述代码的功能是设置模型参数

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值