使用keras举例说明:
网络结构如下:
loss定义如下:
model.compile(loss="categorical_crossentropy",optimizer=OPT,metrics=["accuracy"],
loss_weights={'main_output': 1., 'aux_output': 0.2}
网络具有两个输出,分别都要计算loss,那么在计算梯度反向传播的时候,就会有以下问题:
- 整个网络进行几次梯度更新,1次还是2次?
- VGG网络部分会受到Aux_out的梯度更新影响吗?
- concat操作之前的三个CNN网络,梯度是如何计算的?也就是说,网络有concat操作的时候,梯度更新如何作用于cancat之前的网络部分?
问题1
首先给出结论:在反向传播的时候,梯度更新只进行1次。
因为keras是基于计算图构建的网络,并且只有1个optimizer,optimizer也是计算图的一部分,所以最终loss计算是统一进行的,像上面的具有两个输出的网络,loss计算公式如下:
loss = (main_weight * main_loss) + (aux_weight * aux_loss) #you choose the weights in model.compile
其中
- main_loss 是 function_of(main_true_output_data, main_model_output)
- aux_loss 是 function_of(aux_true_output_data, aux_model_output)
网络所有的梯度计算仍然遵循最基本的 ∂(loss)/∂(weight_i)
因此,梯度更新只进行1次,计算方式如上所示。
问题2
- 网路具有两个输出层,对应两个训练集,一个对应于Aux_out,另一个对应Main_out,因此通过fit定义 model.fit(inputs, [main_y, aux_y], ...)
- 网路同样需要两个loss 函数,一个输出分支对应一个,其中main_loss对应main_out和main_y;Aux_loss对应Aux_out和Aux_y
- 网络的loss是两个输出loss 的和,即loss = (main_weight * main_loss) + (aux_weight * aux_loss)
- 网络一次反向传播更新1次梯度,所以整个网路的参数都通过上面的公式计算,因此,结论如下:
-
- Aux_out分支只会影响concat以前的部分,main_out分支会影响整个网络。
- 同时,Aux_out分支之前的网络部分,会同时受到两个分支的梯度更新影响,因此在下次前向传播的时候,VGG网络部分实际上也受到了上一次Aux_out分支更新梯度的影响。
问题3
concat是由三个CNN网络的输出拼接而成的,在反向传播的时候,会找到彼此的对应部分,有点像pooling操作的反向传播。
假设三个CNN输出的特征图尺寸完全相同,都是(H,W,C),concat后特征图为(H,W,3C),在反向传播计算到concat层的时候
- Input1只会计算(H,W,[0 : C])部分的梯度
- nput2只会计算(H,W,[C : 2C])部分的梯度
- nput1只会计算(H,W,[2C : 3C])部分的梯度