多模态学习中四种常用的跨模态特征融合方法定义与PyTorch实现

本文介绍了四种特征融合方法在深度学习中的应用,包括SumFusion(直接相加),ConcatFusion(堆叠后通过全连接层),FiLM(Feature-wiseLinearModulation)和GatedFusion(门控融合)。这些方法主要用于不同模态数据的融合,提高模型的表示能力和推理性能。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

本文共介绍四种方法,分别是SumFusion、ConcatFusion、FiLM以及GatedFusion

FiLM参考paper-FiLM: Visual Reasoning with a General Conditioning Layer

GatedFusion参考paper-Efficient Large-Scale Multi-Modal Classification

import torch
import torch.nn as nn

#------------------------------------------#
# SumFusion的定义,为两者过全连接层后进行直接相加
#------------------------------------------#
class SumFusion(nn.Module):
    def __init__(self, input_dim=512, output_dim=100):
        super(SumFusion, self).__init__()
        #---------------------------------------#
        # 针对x以及y两个特征张量,分别定义了两个全连接层
        #---------------------------------------#
        self.fc_x = nn.Linear(input_dim, output_dim)
        self.fc_y = nn.Linear(input_dim, output_dim)

    def forward(self, x, y):
        output = self.fc_x(x) + self.fc_y(y)
        return x, y, output

#------------------------------------------#
# ConcatFusion的定义,只定义一个全连接层
# 首先将两者堆叠,之后再将堆叠后的向量送入至全连接层
#------------------------------------------#
class ConcatFusion(nn.Module):
    def __init__(self, input_dim=1024, output_dim=100):
        super(ConcatFusion, self).__init__()
        self.fc_out = nn.Linear(input_dim, output_dim)

    def forward(self, x, y):
        output = torch.cat((x, y), dim=1)
        output = self.fc_out(output)
        return x, y, output

#------------------------------------------#
# FiLM融合方法的定义,只定义一个全连接层
#------------------------------------------#
class FiLM(nn.Module):
    """
    FiLM: Visual Reasoning with a General Conditioning Layer,
    https://arxiv.org/pdf/1709.07871.pdf.
    """
    def __init__(self, input_dim=512, dim=512, output_dim=100, x_film=True):
        super(FiLM, self).__init__()
        self.dim    = input_dim
        self.fc     = nn.Linear(input_dim, 2 * dim)
        self.fc_out = nn.Linear(dim, output_dim)
        self.x_film = x_film

    def forward(self, x, y):
        if self.x_film:
            film = x
            to_be_film = y
        else:
            film = y
            to_be_film = x

        gamma, beta = torch.split(self.fc(film), self.dim, 1)

        output = gamma * to_be_film + beta
        output = self.fc_out(output)

        return x, y, output

#------------------------------------------#
# GatedFusion方法的定义
#------------------------------------------#
class GatedFusion(nn.Module):
    """
    Efficient Large-Scale Multi-Modal Classification,
    https://arxiv.org/pdf/1802.02892.pdf.
    """

    def __init__(self, input_dim=512, dim=512, output_dim=100, x_gate=True):
        super(GatedFusion, self).__init__()
        self.fc_x    = nn.Linear(input_dim, dim)
        self.fc_y    = nn.Linear(input_dim, dim)
        self.fc_out  = nn.Linear(dim, output_dim)
        self.x_gate  = x_gate  # whether to choose the x to obtain the gate
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, y):
        out_x = self.fc_x(x)
        out_y = self.fc_y(y)

        if self.x_gate:
            gate   = self.sigmoid(out_x)
            output = self.fc_out(torch.mul(gate, out_y))
        else:
            gate   = self.sigmoid(out_y)
            output = self.fc_out(torch.mul(out_x, gate))

        return out_x, out_y, output

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

XuecWu3

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值