How Keras computes the loss and runs backpropagation for a model with multiple outputs

Source: https://stackoverflow.com/questions/57149476/how-is-a-multiple-outputs-deep-learning-model-trained

Keras computations are graph-based, and training uses a single optimizer.

The optimizer is also part of the graph, and it receives the gradients of the whole set of weights at once. It does not get two sets of gradients, one per output. It gets a single set of gradients for the entire model.

Mathematically, it is not complicated. The final loss function is:

loss = (main_weight * main_loss) + (aux_weight * aux_loss) #you choose the weights in model.compile

You choose both weights. Other factors can also enter the final loss, such as sample weights, class weights, and regularization terms.
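The weighted sum can be sketched in plain Python. The helper combined_loss and its extra_terms argument are hypothetical and exist only for illustration. Keras builds the equivalent sum internally from the loss_weights given to compile. This sketch does not model how Keras handles sample weights or regularizers.

```python
def combined_loss(losses, loss_weights, extra_terms=()):
    """Hypothetical helper: weighted sum of per-output losses.

    losses and loss_weights are dicts keyed by output name, e.g.
    {'main_output': ..., 'aux_output': ...}. extra_terms stands in for
    additive contributions such as regularization penalties.
    """
    total = 0.0
    for name, value in losses.items():
        # In this sketch, an output without an explicit weight counts with weight 1.0
        total += loss_weights.get(name, 1.0) * value
    # Regularization-style terms are added on top of the weighted output losses
    total += sum(extra_terms)
    return total


losses = {'main_output': 0.8, 'aux_output': 0.5}
weights = {'main_output': 1.0, 'aux_output': 0.2}
print(combined_loss(losses, weights))  # 1.0 * 0.8 + 0.2 * 0.5
```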

Where:

  • main_loss is a function_of(main_true_output_data, main_model_output)
  • aux_loss is a function_of(aux_true_output_data, aux_model_output)

And the gradients are just ∂(loss)/∂(weight_i) for all weights.
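The same statement can be written as an equation. Here L_m, L_a, w_m, w_a and θ_i are shorthand for main_loss, aux_loss, main_weight, aux_weight and weight_i, and the extra terms are left out. By linearity of differentiation:

```latex
\begin{aligned}
L &= w_m L_m + w_a L_a \\
\frac{\partial L}{\partial \theta_i}
  &= w_m \frac{\partial L_m}{\partial \theta_i}
   + w_a \frac{\partial L_a}{\partial \theta_i}
\end{aligned}
```

A term is zero whenever its loss does not depend on θ_i. As a result, each weight's gradient only contains the weighted contributions of the outputs it actually feeds.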

Once the optimizer has these gradients, it performs a single optimization step.
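The toy model below illustrates "one set of gradients, one step" in plain Python rather than Keras. Several parts of it are assumptions chosen for illustration:

  • the model itself: a shared weight s and two heads, a and m
  • the squared-error losses
  • plain SGD as the optimizer
  • all of the numbers

A real Keras optimizer such as Adam uses a different update rule. It still applies one update to all weights, computed from the gradient of the combined loss.

```python
# Toy two-output model (hypothetical, for illustration only):
#   h        = s * x   shared layer (stands in for layers feeding both outputs)
#   aux_out  = a * h   auxiliary head
#   main_out = m * h   layer that sits only on the main path
# Both losses are squared errors.

def forward(params, x):
    h = params['s'] * x
    return h, params['a'] * h, params['m'] * h


def total_loss(params, x, y_main, y_aux, main_weight, aux_weight):
    _, aux_out, main_out = forward(params, x)
    main_loss = (main_out - y_main) ** 2
    aux_loss = (aux_out - y_aux) ** 2
    return main_weight * main_loss + aux_weight * aux_loss


def gradients(params, x, y_main, y_aux, main_weight, aux_weight):
    """Analytic d(total_loss)/d(param) for every parameter, via the chain rule."""
    h, aux_out, main_out = forward(params, x)
    d_main = 2 * (main_out - y_main)  # d main_loss / d main_out
    d_aux = 2 * (aux_out - y_aux)     # d aux_loss / d aux_out
    return {
        'm': main_weight * d_main * h,
        'a': aux_weight * d_aux * h,
        # The shared weight collects weighted contributions from both heads
        's': (main_weight * d_main * params['m'] + aux_weight * d_aux * params['a']) * x,
    }


def sgd_step(params, grads, lr):
    # One update applied to ALL parameters at once, from one set of gradients
    return {name: value - lr * grads[name] for name, value in params.items()}


params = {'s': 0.5, 'a': 1.0, 'm': 2.0}
args = dict(x=1.0, y_main=3.0, y_aux=1.0, main_weight=1.0, aux_weight=0.2)
grads = gradients(params, **args)
new_params = sgd_step(params, grads, lr=0.05)
print(total_loss(params, **args), '->', total_loss(new_params, **args))
```

With these numbers, the combined loss drops after the single step. The gradient for the shared weight s is the sum of the weighted contributions from both heads.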

Questions:

How are the auxiliary branch's weights updated, given that the branch is not directly connected to the main output?

  • You have two output datasets: one for main_output and another for aux_output. You must pass both to fit: model.fit(inputs, [main_y, aux_y], ...)
  • You also have two loss functions, one per output. main_loss takes main_y and main_out, and aux_loss takes aux_y and aux_out.
  • The two losses are summed: loss = (main_weight * main_loss) + (aux_weight * aux_loss)
  • The gradients are computed once for this combined loss, and this function depends on the entire model.
    • During backpropagation, the aux term will affect lstm_1 and embedding_1, the layers that feed both outputs in the question's model.
    • Consequently, in the next forward pass (after the weights are updated), it will end up influencing the main branch. Whether this makes the main output better or worse depends only on whether the aux output is useful.
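A similar toy model shows how the auxiliary term reaches the main branch. It is hypothetical plain Python with squared-error losses and plain SGD. Two runs start from identical weights and differ only in aux_weight. The weight m, which sits on the main-only path, gets the same update in both runs. The shared weight s does not, so the main prediction after the step differs.

```python
# Hypothetical toy model (plain Python, not Keras):
#   h = s * x          shared layer (stands in for embedding_1 / lstm_1)
#   aux_out = a * h    auxiliary head
#   main_out = m * h   main head
def train_step(params, x, y_main, y_aux, main_weight, aux_weight, lr):
    s, a, m = params
    h = s * x
    d_main = 2 * (m * h - y_main)  # derivative of squared-error main_loss
    d_aux = 2 * (a * h - y_aux)    # derivative of squared-error aux_loss
    grad_s = (main_weight * d_main * m + aux_weight * d_aux * a) * x
    grad_a = aux_weight * d_aux * h
    grad_m = main_weight * d_main * h
    return (s - lr * grad_s, a - lr * grad_a, m - lr * grad_m)


def main_prediction(params, x):
    s, _, m = params
    return m * s * x


start = (0.5, 1.0, 2.0)
data = dict(x=1.0, y_main=3.0, y_aux=-1.0, main_weight=1.0, lr=0.05)
without_aux = train_step(start, aux_weight=0.0, **data)
with_aux = train_step(start, aux_weight=0.5, **data)

# m gets the same update in both runs; the shared weight s does not ...
print(without_aux, with_aux)
# ... so the main prediction on the next forward pass differs.
print(main_prediction(without_aux, 1.0), main_prediction(with_aux, 1.0))
```

Here the auxiliary target (-1.0) pulls the shared weight away from what the main target needs, so the aux term slows the main branch down. With an auxiliary target that agrees with the main task, it could help instead.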

Is the part of the network between the root of the auxiliary branch and the main output affected by the weighting of the loss? Or does the weighting only influence the part of the network connected to the auxiliary output?

The loss weights are plain arithmetic. You define them in compile:

model.compile(optimizer=one_optimizer,
              # you choose each loss; the keys must match the output layer names
              loss={'main_output': main_loss, 'aux_output': aux_loss},
              # you choose each weight
              loss_weights={'main_output': main_weight, 'aux_output': aux_weight},
              metrics=...)

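The loss function then uses them in loss = (weight1 * loss1) + (weight2 * loss2).
The rest is just computing ∂(loss)/∂(weight_i) for each weight. Layers that lie only on the path to the main output do not affect aux_loss, so their gradients come only from main_weight * main_loss. Layers shared by both outputs, such as lstm_1 and embedding_1, receive both weighted contributions.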
