OGM-GE(动态梯度调制&泛化增强)

目录

概念解析

OGM (On-the-fly Gradient Modulation)

GE (Generalization Enhancement)

总结


概念解析

OGM (On-the-fly Gradient Modulation)

 On-the-fly Gradient Modulation (动态梯度调制), 简单来说,就是让不同模态间的梯度下降尽可能地贴近,避免模态间学习差异过大与部分模态学习进度过慢。

原文: Balanced Multimodal Learning via On-the-fly Gradient Modulation

链接: https://arxiv.org/abs/2203.15332icon-default.png?t=N7T8https://arxiv.org/abs/2203.15332

图例

计算步骤(因为图像学习更快所以这里取audio作为分母)

1. 对编码器输出特征进行FC(bias取二分之一),并使用softmax函数计算对应模态学习程度s

        根据不同的融合方式计算每个模态所求得对应最终输出预测logit的贡献,例:如是sum则将FC的weight直接乘上输入进FC的模态向量得到对应的贡献;如是concat则将weight平分乘上不同模态的向量+二分之一的偏置(无论融合方式如何所计算的两个模态的logit相加都是最终预测结果) 

if args.fusion_method == 'sum':
            out_v = (torch.mm(v, torch.transpose(model.module.fusion_module.fc_y.weight, 0,                 
                     1)) +model.module.fusion_module.fc_y.bias)
           out_a = (torch.mm(a, torch.transpose(model.module.fusion_module.fc_x.weight, 0, 
                     1)) +model.module.fusion_module.fc_x.bias)
        else:
            weight_size = model.module.fusion_module.fc_out.weight.size(1)
            out_v = (torch.mm(v,torch.transpose(model.module.fusion_module.fc_out.weight[:, 
                     weight_size // 2:], 0, 1))
                     + model.module.fusion_module.fc_out.bias / 2)
            out_a = (torch.mm(a,torch.transpose(model.module.fusion_module.fc_out.weight[:, 
                     :weight_size // 2], 0, 1))+ model.module.fusion_module.fc_out.bias /2)

2. 计算学习差异比率(谁学习程度高谁降低)

将每个模态预测的logit经过softmax之后将对应模态正样本的概率求和计算得分 s

score_v = sum([softmax(out_v)[i][label[i]] for i in range(out_v.size(0))])
score_a = sum([softmax(out_a)[i][label[i]] for i in range(out_a.size(0))])

ratio_v = score_v / score_a
ratio_a = 1 / ratio_v

3. 根据这个比率来调节不同模态的系数

            if ratio_v > 1:
                coeff_v = 1 - tanh(args.alpha * relu(ratio_v))
                coeff_a = 1
            else:
                coeff_a = 1 - tanh(args.alpha * relu(ratio_a))
                coeff_v = 1

4. 动态降低模态(学习程度高)的梯度

更新梯度的同时加上高斯分布噪音 

            if args.modulation_starts <= epoch <= args.modulation_ends: # bug fixed
                for name, parms in model.named_parameters():
                    layer = str(name).split('.')[1]

                    if 'audio' in layer and len(parms.grad.size()) == 4:
                        if args.modulation == 'OGM_GE':  # bug fixed
                            parms.grad = parms.grad * coeff_a + \
                                         torch.zeros_like(parms.grad).normal_(0,                                                       parms.grad.std().item() + 1e-8)
                        elif args.modulation == 'OGM':
                            parms.grad *= coeff_a

                    if 'visual' in layer and len(parms.grad.size()) == 4:
                        if args.modulation == 'OGM_GE':  # bug fixed
                            parms.grad = parms.grad * coeff_v + \
                                         torch.zeros_like(parms.grad).normal_(0, parms.grad.std().item() + 1e-8)
                        elif args.modulation == 'OGM':
                            parms.grad *= coeff_v
            else:
                pass

GE (Generalization Enhancement

Generalization Enhancement (泛化增强),因为在OGM中降低了学习率,容易使模型陷入局部最优解,造成过拟合现象,所以通过增加高斯噪音的方法,提高泛化能力,防止过拟合。


总结

1. 该模型的初始阶段不如正常模型,因为降低了学习率

2. 虽然通过OGM关联多模态之间的学习情况,最终仍然存在差异

3.在CREMA-D 和 VGGSounddataset数据集上,通过结合OGM-GE,效果较为显著

4.这个辅助模型最大的问题应该是他的一个适用性, 因为他在融合后只有一个线性层转换为预测值所以在区分每个模态的贡献时非常的清晰,如果在融合后用一些计算复杂的层比如说自注意力块这种的就很难去区分每个模态对最终预测的贡献,也就不好去计算每个模态的学习程度,如果能处理好这个问题那发个OGM-GE++都行了。笔者认为区分不同模态在复杂模型中对于下游任务的贡献也是非常值得关注的问题。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值