YOLO源码(四)关于train_network()函数介绍

上一次了解了yolo是怎么读取txt路径文件的,这次主要了解一下train_network这个函数,另外.推荐大牛的github地址,对darknet讲解的很好!https://github.com/hgpvision/darknet

 float loss = train_network(net, train);

先简单清扫一下这个函数

float train_network(network net, data d)
{
    assert(d.X.rows % net.batch == 0);//断言非真,程序停止.d.X.row一次性加载到样本的个数.前面提过,darknet中的batch*net.subdivisions是打一次大的batch
    int batch = net.batch;
    int n = d.X.rows / batch;//算出一次大batch要跑多少次,如果大batch很大,cpu不能一次性跑完真么多图片,所以分几次跑,loss相加再取平均就行了

    int i;
    float sum = 0;
    for(i = 0; i < n; ++i){
        get_next_batch(d, batch, i*batch, net.input, net.truth);//从d中读取batch张图片到net.input中
        //第一个参数d包含了一次大batch的数据,也就是net.batch*net.subdivision张图片,第二个参数batch是每次循环读取到 net.input中的数据,参与训练图片的张数
        //第三个参数是d中偏移量,第四个参数为网络的输入数据,第五个参数为输入数据net.input对应的标签
        float err = train_network_datum(net);//训练网路,本次训练共有net.batch张图片
        //训练包括一次前向传播,计算每一层网络的输出并计算cost;一次反向,计算敏感度\权重更新值\偏置更新值\实时更新过程\更新权重和偏置
        sum += err;//err为loss,sum是总loss
    }
    return (float)sum/(n*batch);//算平均loss,就是一次大batch的loss了
}

看一下get_next_batch函数

void get_next_batch(data d, int n, int offset, float *X, float *y)
{
    //从d中读取batch张图片到net.input中
        //第一个参数d包含了一次大batch的数据,也就是net.batch*net.subdivision张图片,第二个参数batch是每次循环读取到 net.input中的数据,参与训练图片的张数
        //第三个参数是d中偏移量,第四个参数为网络的输入数据,第五个参数为输入数据net.input对应的标签
    int j;
    for(j = 0; j < n; ++j){
        int index = offset + j;
        memcpy(X+j*d.X.cols, d.X.vals[index], d.X.cols*sizeof(float));//d.X.cols是指样本的维度,把d.X.cols[index]拷贝到X+j*d.X.cols为头指针的内存里
        if(y) memcpy(y+j*d.y.cols, d.y.vals[index], d.y.cols*sizeof(float));//同样,拷贝标签
    }
}

重点就是train_network_datum(net)函数了

float train_network_datum(network net)
{
#ifdef GPU//如果使用GPU则调用Gpu版的train
    if(gpu_index >= 0) return train_network_datum_gpu(net);
#endif
    *net.seen += net.batch;//net.seen指的是已经训练过的图片的数量
    net.train = 1;
    forward_network(net);//重点
    backward_network(net);//重点
    float error = *net.cost;
    if(((*net.seen)/net.batch)%net.subdivisions == 0) update_network(net);
    return error;
}
下次重点介绍一下forward_network和backward_network函数
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值