Conformer: Example Code Implementation

A detailed Conformer walkthrough (in Chinese): https://mp.weixin.qq.com/s?__biz=Mzk0MzIzODM5MA==&mid=2247487889&idx=1&sn=b2d9274cbb83d409c3cc4326583ef6b8&chksm=c337ac08f440251e62d95a55cf4079dbc14e742d860b0d6a5d07f27e8a1e1f3535700f1277f4#rd

Full reference implementation: https://github.com/sooftware/conformer

Example implementation:

Note: for simplicity, this version takes a few shortcuts: the attention has no relative positional encoding (a sketch of one alternative appears after the attention module below), and the convolution module is simplified.
For the full implementation, see the links above.

import torch
import torch.nn as nn


class FeedForwardModule(nn.Module):
    def __init__(self, dim, expansion_factor=4, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * expansion_factor),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * expansion_factor, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return x + 0.5 * self.net(x)

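The 0.5 factor implements the macaron-style half-step residual from the Conformer paper: each feed-forward module contributes half of a standard FFN residual. A quick shape check (test values assumed for illustration, not from the original post):

ff = FeedForwardModule(dim=256)
x = torch.randn(2, 50, 256)   # (batch, time, dim)
print(ff(x).shape)            # torch.Size([2, 50, 256])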

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, dim, num_heads, dropout=0.1):
        super().__init__()
        # batch_first=True so the module accepts (batch, time, dim) inputs,
        # matching the rest of this code; by default nn.MultiheadAttention
        # expects (time, batch, dim)
        self.attention = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        x = self.norm(x)  # pre-norm
        x, _ = self.attention(x, x, x)
        return residual + self.dropout(x)

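As noted above, this module omits the relative positional encoding of the original paper (which uses Transformer-XL style relative multi-head attention). A minimal sketch of one simpler alternative, a learned relative-position bias added to the attention logits, reusing the imports above (the class and its parameters are hypothetical, not part of the reference implementation):

class RelPositionBias(nn.Module):
    def __init__(self, num_heads, max_distance=128):
        super().__init__()
        self.max_distance = max_distance
        # one learned scalar per head per clipped relative distance
        self.bias = nn.Embedding(2 * max_distance + 1, num_heads)

    def forward(self, seq_len):
        pos = torch.arange(seq_len)
        rel = pos[None, :] - pos[:, None]                       # (T, T) offsets j - i
        rel = rel.clamp(-self.max_distance, self.max_distance) + self.max_distance
        return self.bias(rel).permute(2, 0, 1)                  # (num_heads, T, T)

The resulting bias could be expanded to (batch * num_heads, T, T) and passed to nn.MultiheadAttention as a float attn_mask, which is added to the attention scores before the softmax.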

class ConvolutionModule(nn.Module):
    def __init__(self, dim, kernel_size=31, dropout=0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(dim)
        self.conv_net = nn.Sequential(
            nn.Conv1d(dim, 2 * dim, kernel_size=1),
            nn.GLU(dim=1),
            nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim),
            nn.BatchNorm1d(dim),
            nn.SiLU(),
            nn.Conv1d(dim, dim, kernel_size=1),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        residual = x
        x = self.layer_norm(x)
        x = x.transpose(1, 2)
        x = self.conv_net(x)
        x = x.transpose(1, 2)
        return residual + x

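Inside conv_net the channel count goes dim -> 2*dim (pointwise) -> dim (GLU halves it) -> dim (depthwise) -> dim (pointwise), and padding = kernel_size // 2 preserves the time length for the odd default kernel. A quick check (test values assumed):

conv = ConvolutionModule(dim=256)
x = torch.randn(2, 50, 256)
print(conv(x).shape)   # torch.Size([2, 50, 256]) -- time length unchanged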

class ConformerBlock(nn.Module):
    def __init__(self, dim, num_heads, ff_expansion_factor=4, conv_kernel_size=31, dropout=0.1):
        super().__init__()
        self.ff1 = FeedForwardModule(dim, ff_expansion_factor, dropout)
        self.attention = MultiHeadSelfAttention(dim, num_heads, dropout)
        self.conv = ConvolutionModule(dim, conv_kernel_size, dropout)
        self.ff2 = FeedForwardModule(dim, ff_expansion_factor, dropout)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = self.ff1(x)
        x = self.attention(x)
        x = self.conv(x)
        x = self.ff2(x)
        return self.norm(x)

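A single block thus applies the paper's sandwich structure: half-step FFN -> self-attention -> convolution -> half-step FFN -> final LayerNorm, all shape-preserving (test values assumed):

block = ConformerBlock(dim=256, num_heads=4)
x = torch.randn(2, 50, 256)
print(block(x).shape)  # torch.Size([2, 50, 256])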

class ConvolutionSubsampling(nn.Module):
    def __init__(self, input_dim, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, dim, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Conv2d(dim, dim, kernel_size=3, stride=2),
            nn.ReLU(),
        )
        # frequency dim after two unpadded stride-2, kernel-3 convs: L -> (L - 1) // 2, applied twice
        self.linear = nn.Linear(dim * (((input_dim - 1) // 2 - 1) // 2), dim)

    def forward(self, x):
        x = x.unsqueeze(1)
        x = self.net(x)
        B, C, T, F = x.size()
        x = x.permute(0, 2, 1, 3).reshape(B, T, C * F)
        return self.linear(x)

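Each unpadded stride-2, kernel-3 Conv2d maps a length L to (L - 3) // 2 + 1 = (L - 1) // 2, so time and frequency both shrink by roughly 4x overall. For example (values assumed for illustration):

sub = ConvolutionSubsampling(input_dim=80, dim=256)
x = torch.randn(2, 128, 80)   # (batch, time, mel bins)
print(sub(x).shape)           # torch.Size([2, 31, 256]): 128 -> 63 -> 31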

class ConformerEncoder(nn.Module):
    def __init__(self, input_dim, num_classes, num_layers=4, dim=256, num_heads=4, ff_expansion_factor=4, conv_kernel_size=31, dropout=0.1):
        super().__init__()
        self.subsampling = ConvolutionSubsampling(input_dim, dim)
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.Sequential(
            *[ConformerBlock(dim, num_heads, ff_expansion_factor, conv_kernel_size, dropout) for _ in range(num_layers)]
        )
        self.fc = nn.Linear(dim, num_classes)

    def forward(self, x):
        x = self.subsampling(x)
        x = self.dropout(x)
        x = self.blocks(x)
        return self.fc(x)

# quick smoke test
batch_size = 16
seq_len = 128
input_dim = 80  # e.g., Mel-spectrogram features
num_classes = 100  # vocabulary size

model = ConformerEncoder(input_dim, num_classes, num_layers=4, dim=256, num_heads=4)
inputs = torch.randn(batch_size, seq_len, input_dim)
outputs = model(inputs)
print("Output shape:", outputs.shape)  # torch.Size([16, 31, 100]) -- time subsampled ~4x