多分支梯度反向传播,concat后梯度如何反向传播

使用keras举例说明:

网络结构如下:

loss定义如下:

model.compile(loss="categorical_crossentropy",optimizer=OPT,metrics=["accuracy"],
            loss_weights={'main_output': 1., 'aux_output': 0.2}

网络具有两个输出,分别都要计算loss,那么在计算梯度反向传播的时候,就会有以下问题:

  1. 整个网络进行几次梯度更新,1次还是2次?
  2. VGG网络部分会受到Aux_out的梯度更新影响吗?
  3. concat操作之前的三个CNN网络,梯度是如何计算的?也就是说,网络有concat操作的时候,梯度更新如何作用于cancat之前的网络部分?

问题1

首先给出结论:在反向传播的时候,梯度更新只进行1次。

因为keras是基于计算图构建的网络,并且只有1个optimizer,optimizer也是计算图的一部分,所以最终loss计算是统一进行的,像上面的具有两个输出的网络,loss计算公式如下:

loss = (main_weight * main_loss) + (aux_weight * aux_loss) #you choose the weights in model.compile

其中

  • main_loss 是 function_of(main_true_output_data, main_model_output)
  • aux_loss 是 function_of(aux_true_output_data, aux_model_output)

网络所有的梯度计算仍然遵循最基本的 ∂(loss)/∂(weight_i)

因此,梯度更新只进行1次,计算方式如上所示。

问题2

  • 网路具有两个输出层,对应两个训练集,一个对应于Aux_out,另一个对应Main_out,因此通过fit定义 model.fit(inputs, [main_y, aux_y], ...)
  • 网路同样需要两个loss 函数,一个输出分支对应一个,其中main_loss对应main_out和main_y;Aux_loss对应Aux_out和Aux_y
  • 网络的loss是两个输出loss 的和,即loss = (main_weight * main_loss) + (aux_weight * aux_loss)
  • 网络一次反向传播更新1次梯度,所以整个网路的参数都通过上面的公式计算,因此,结论如下:
    • Aux_out分支只会影响concat以前的部分,main_out分支会影响整个网络。
    • 同时,Aux_out分支之前的网络部分,会同时受到两个分支的梯度更新影响,因此在下次前向传播的时候,VGG网络部分实际上也受到了上一次Aux_out分支更新梯度的影响。

问题3

concat是由三个CNN网络的输出拼接而成的,在反向传播的时候,会找到彼此的对应部分,有点像pooling操作的反向传播。

假设三个CNN输出的特征图尺寸完全相同,都是(H,W,C),concat后特征图为(H,W,3C),在反向传播计算到concat层的时候

  • Input1只会计算(H,W,[0 : C])部分的梯度
  • nput2只会计算(H,W,[C : 2C])部分的梯度
  • nput1只会计算(H,W,[2C : 3C])部分的梯度
  • 17
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值