完整代码可参考：https://github.com/sooftware/conformer
下面是一个示例实现：
注意：为了简单起见，这里做了几处简化：自注意力模块中没有使用相对位置编码（直接使用了 PyTorch 自带的 nn.MultiheadAttention），卷积部分的实现也做了简化。
想了解完整实现的同学可以参考上面的链接。
import torch
import torch.nn as nn
class FeedForwardModule(nn.Module):
    """Conformer feed-forward module with a half-step residual connection.

    Computes ``x + 0.5 * FFN(x)`` where FFN is
    LayerNorm -> Linear(dim -> dim * expansion_factor) -> activation
    -> Dropout -> Linear(dim * expansion_factor -> dim) -> Dropout.

    Args:
        dim: Size of the last (feature) dimension of the input.
        expansion_factor: Width multiplier of the hidden layer.
        dropout: Dropout probability applied after the activation and after
            the output projection.
        activation: Zero-argument callable (typically an ``nn.Module`` class)
            that returns the hidden activation module. Defaults to
            ``nn.SiLU`` (Swish), the activation used by the Conformer paper;
            pass ``nn.GELU`` for a GELU variant.
    """

    def __init__(self, dim, expansion_factor=4, dropout=0.1, activation=nn.SiLU):
        super().__init__()
        hidden_dim = dim * expansion_factor
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            activation(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Return ``x + 0.5 * FFN(x)``; the output has the same shape as ``x``."""
        # Half-step residual: a Conformer block uses two of these FFNs, each
        # contributing only half of its output.
        return x + 0.5 * self.net(x)
class MultiHeadSelfAttention(nn.Module):
def __init__(self, dim, num_heads, dropout=0.1):
super().__init__()
self.attention = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout)
self.norm = nn.LayerNorm(dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
residual = x
x = self.norm(x)
x, _ = self.attention(x, x, x)
return residual + self.dropout(x)
class ConvolutionModule(nn.Module):
    """Conformer convolution module with a residual connection.

    LayerNorm -> pointwise Conv1d (dim -> 2*dim) -> GLU -> depthwise Conv1d
    -> BatchNorm1d -> SiLU (Swish) -> pointwise Conv1d -> Dropout, then the
    result is added back to the input.

    Args:
        dim: Feature size of the input.
        kernel_size: Kernel size of the depthwise convolution. Odd and even
            sizes are both supported; the time length is preserved.
        dropout: Dropout probability on the module output.
    """

    def __init__(self, dim, kernel_size=31, dropout=0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(dim)
        self.conv_net = nn.Sequential(
            nn.Conv1d(dim, 2 * dim, kernel_size=1),
            # GLU over the channel axis halves 2*dim back to dim.
            nn.GLU(dim=1),
            # padding="same" keeps the time length for any kernel size. For odd
            # kernels it equals symmetric kernel_size // 2 padding; with an even
            # kernel that fixed padding would add one frame and break the
            # residual addition in forward().
            nn.Conv1d(dim, dim, kernel_size=kernel_size, padding="same", groups=dim),
            nn.BatchNorm1d(dim),
            nn.SiLU(),
            nn.Conv1d(dim, dim, kernel_size=1),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the module to ``x`` of shape (batch, time, dim); same shape out."""
        residual = x
        # Conv1d expects (batch, channels, time), so swap time and features.
        x = self.layer_norm(x).transpose(1, 2)
        x = self.conv_net(x).transpose(1, 2)
        return residual + x
class ConformerBlock(nn.Module):
    """One Conformer block: FFN/2 -> MHSA -> Conv -> FFN/2 -> LayerNorm.

    Each sub-module applies its own residual connection, so this block only
    chains them and normalizes the final result.

    Args:
        dim: Feature size of the input and output.
        num_heads: Number of attention heads.
        ff_expansion_factor: Hidden-size multiplier of both feed-forward modules.
        conv_kernel_size: Depthwise kernel size of the convolution module.
        dropout: Dropout probability shared by all sub-modules.
    """

    def __init__(self, dim, num_heads, ff_expansion_factor=4, conv_kernel_size=31, dropout=0.1):
        super().__init__()
        # Keep this construction order: it fixes parameter initialisation
        # under a given seed and the state_dict layout.
        self.ff1 = FeedForwardModule(dim, expansion_factor=ff_expansion_factor, dropout=dropout)
        self.attention = MultiHeadSelfAttention(dim, num_heads, dropout=dropout)
        self.conv = ConvolutionModule(dim, kernel_size=conv_kernel_size, dropout=dropout)
        self.ff2 = FeedForwardModule(dim, expansion_factor=ff_expansion_factor, dropout=dropout)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Run the four sub-modules in sequence and layer-normalize the output."""
        for stage in (self.ff1, self.attention, self.conv, self.ff2):
            x = stage(x)
        return self.norm(x)
class ConvolutionSubsampling(nn.Module):
    """Downsample a (batch, time, input_dim) feature sequence about 4x in time.

    Two unpadded 3x3 Conv2d layers with stride 2, each followed by ReLU, are
    applied to the (time, feature) plane. The channel and remaining feature
    axes are then flattened and projected to ``dim``.

    Args:
        input_dim: Number of features per input frame. Must be at least 7 so
            that at least one feature bin survives both convolutions.
        dim: Number of conv channels and size of the output features.

    Raises:
        ValueError: If ``input_dim`` is too small for two stride-2 convolutions.
    """

    def __init__(self, input_dim, dim):
        super().__init__()
        freq_out = self._conv_out_size(self._conv_out_size(input_dim))
        if freq_out < 1:
            raise ValueError(
                f"input_dim={input_dim} is too small for two kernel-3/stride-2 "
                "convolutions (need input_dim >= 7)"
            )
        self.net = nn.Sequential(
            nn.Conv2d(1, dim, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Conv2d(dim, dim, kernel_size=3, stride=2),
            nn.ReLU(),
        )
        # The feature size is computed layer by layer. A single closed form
        # such as (input_dim - 1) // 4 is wrong whenever input_dim % 4 is 1 or 2
        # (e.g. 13, 81, 161) and leads to a shape mismatch in this Linear.
        self.linear = nn.Linear(dim * freq_out, dim)

    @staticmethod
    def _conv_out_size(size):
        """Return the output length of an unpadded kernel-3, stride-2 conv."""
        return (size - 3) // 2 + 1

    def forward(self, x):
        """Map (batch, time, input_dim) to (batch, time', dim).

        ``time'`` equals ``((time - 1) // 2 - 1) // 2``.
        """
        x = self.net(x.unsqueeze(1))  # add a single input channel for Conv2d
        batch, channels, time, freq = x.size()
        # Put time second, then flatten channels x remaining frequency bins.
        x = x.permute(0, 2, 1, 3).reshape(batch, time, channels * freq)
        return self.linear(x)
class ConformerEncoder(nn.Module):
    """Conformer encoder followed by a per-frame linear output layer.

    Pipeline: convolutional subsampling -> dropout -> ``num_layers`` stacked
    Conformer blocks -> linear projection to ``num_classes`` outputs for every
    subsampled frame.

    Args:
        input_dim: Number of input features per frame.
        num_classes: Size of the output layer.
        num_layers: Number of stacked ConformerBlocks.
        dim: Model (hidden) size.
        num_heads: Attention heads per block.
        ff_expansion_factor: Feed-forward hidden-size multiplier.
        conv_kernel_size: Depthwise convolution kernel size.
        dropout: Dropout probability used throughout.
    """

    def __init__(
        self,
        input_dim,
        num_classes,
        num_layers=4,
        dim=256,
        num_heads=4,
        ff_expansion_factor=4,
        conv_kernel_size=31,
        dropout=0.1,
    ):
        super().__init__()
        self.subsampling = ConvolutionSubsampling(input_dim, dim)
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.Sequential(
            *(
                ConformerBlock(dim, num_heads, ff_expansion_factor, conv_kernel_size, dropout)
                for _ in range(num_layers)
            )
        )
        self.fc = nn.Linear(dim, num_classes)

    def forward(self, x):
        """Encode ``x`` and return per-frame outputs of size ``num_classes``."""
        hidden = self.dropout(self.subsampling(x))
        hidden = self.blocks(hidden)
        return self.fc(hidden)
if __name__ == "__main__":
    # Smoke test: push a random batch through the model and print the shape.
    batch_size = 16
    seq_len = 128
    input_dim = 80  # e.g., Mel-spectrogram features
    num_classes = 100  # Vocabulary size
    model = ConformerEncoder(input_dim, num_classes, num_layers=4, dim=256, num_heads=4)
    model.eval()  # inference mode: no dropout, BatchNorm uses running stats
    inputs = torch.randn(batch_size, seq_len, input_dim)
    with torch.no_grad():
        outputs = model(inputs)
    # Two unpadded kernel-3 / stride-2 convolutions shrink time to
    # ((seq_len - 1) // 2 - 1) // 2, i.e. 31 frames for seq_len = 128.
    print("[DEBUG] Output shape:", outputs.shape)  # [batch_size, 31, num_classes]