loss = train_network(net, train)backward_network(net)

29 篇文章 1 订阅
27 篇文章 2 订阅

train_network(net, train)函数主体

float train_network(network *net, data d)
{
    assert(d.X.rows % net->batch == 0);//判断data d.X中保存的训练图像数据能否被min_batch整除
    int batch = net->batch;//将min_batch赋给batch,表示每次训练的数据
    int n = d.X.rows / batch;//d.X中的训练数据分几次训练
    int i;
    float sum = 0;//用来保存损失值
    for(i = 0; i < n; ++i){//遍历训练所有的min_batch
        get_next_batch(d, batch, i*batch, net->input, net->truth);//将每次训练的数据导入网络中去,每次训练min_batch数据
        /*
		void get_next_batch(data d, int n, int offset, float *X, float *y)
		{
		    int j;
		    for(j = 0; j < n; ++j){//遍历每张图像的数据和标签,将其导入网络中
		        int index = offset + j;//每批min_batch的偏移目录地址
		        memcpy(X+j*d.X.cols, d.X.vals[index], d.X.cols*sizeof(float));//将d.X.vals[index]中的d.X.cols个float类型的数据复制到X+j*d.X.cols中去,即导入到net->input中;这里d.X.cols就是每一幅图像的数据量,即高*宽*通道,我的是设置的448x448x3=602112
		        if(y) memcpy(y+j*d.y.cols, d.y.vals[index], d.y.cols*sizeof(float));//同上,将与之对应的标签导入到net->truth中,这里我设置了一幅图上最大的检测框数量为90,每一个检测框包括5个标签数据,所以每幅图上的最大标签数据数量为90x5=450;
		    }
		}
		*/
        float err = train_network_datum(net);//核心训练函数
        sum += err;//将每次min_batch训练得出的损失累加
    }
    return (float)sum/(n*batch);//得到一个大的batch的平均损失值
}

float err = train_network_datum(net);函数主体

float train_network_datum(network *net)
{
    *net->seen += net->batch;//记录训练了多少图像
    net->train = 1;
    forward_network(net);//前向网络参考[博客](https://blog.csdn.net/m0_37799466/article/details/106192969)
    backward_network(net);//返回后向网络[博客](https://blog.csdn.net/m0_37799466/article/details/106195131)
    float error = *net->cost;//一次min_batch的损失值
    if(((*net->seen)/net->batch)%net->subdivisions == 0) update_network(net);//每训练一个大batch就更新一次参数,即net->batch*net->subdivisions,关注[博客](https://blog.csdn.net/m0_37799466/article/details/106285677)
    return error;
}
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值