引言
在自然语言处理(NLP)中,注意力机制是Transformer模型的核心组件之一。本文介绍一种基于**累积最大值(Cumulative Max)**的注意力机制变体——MaxState
和MaxStateSuper
,并探讨其在解码器中的实现与优化。通过对比两者的结构差异,我们将分析MaxStateSuper
如何通过非线性组合和动态权重分配提升模型的表达能力。
模型概述
整体架构
代码实现了一个简单的解码器模型SamOut
,其核心结构如下:
- 嵌入层(Embedding Layer):将词汇索引映射为稠密向量。
- 解码器层(Decoder Layer):包含自注意力机制(
MaxState
或MaxStateSuper
)和前馈网络(FeedForward
)。 - 输出层:将隐藏状态映射回词汇空间。
模型结构图如下:
输入 → 嵌入层 → [自注意力 + 前馈网络] × N → 输出层 → 预测
核心模块详解
1. MaxState:基础累积最大值注意力
MaxState
通过累积最大值计算注意力权重,并通过线性组合进行特征融合。
代码实现
class MaxState(torch.nn.Module):
def __init__(self, hidden_dim, heads):
super().__init__()
self.head0 = nn.Linear(hidden_dim, hidden_dim, bias=False)
self.head1 = nn.Linear(hidden_dim, hidden_dim, bias=False)
self.head2 = nn.Linear(hidden_dim, hidden_dim, bias=False)
def forward(self, x):
# 线性变换
out = self.head0(x)
out1 = self.head1(x)
out2 = self.head2(x)
# 头分割与累积最大值
out = out.view(b, s, heads, h).permute(0, 2, 1, 3)
out = torch.cummax(out + out1, dim=2)[0]
# 恢复形状并组合
out = out.permute(0, 2, 1, 3).view(b, s, -1)
return (out + out1) * out2 + out1
关键步骤
- 线性变换:通过三个独立的线性层分别生成初始值、加法项和乘法项。
- 头分割:将特征维度按头数(
heads
)分割,便于并行计算。 - 累积最大值:对每个头的序列维度计算累积最大值,捕获长期依赖。
- 线性组合:通过加法和乘法融合不同头的特征。
2. MaxStateSuper:增强版注意力机制
MaxStateSuper
通过以下改进提升了表达能力:
- 非线性激活(Softmax):对累积最大值进行归一化,确保权重的动态分配。
- 参数合并与优化:将三个线性层合并为一个,减少计算开销。
代码实现
class MaxStateSuper(torch.nn.Module):
def __init__(self, dim_size, heads):
super().__init__()
self.combined = nn.Linear(dim_size, 3 * dim_size) # 合并三个线性层
def forward(self, x):
# 合并后的线性变换
out, out1, out2 = self.combined(x).chunk(3, dim=-1)
# 头分割与累积最大值
out = out.view(b, s, heads, -1).permute(0, 2, 1, 3)
out = torch.cummax(out, dim=2)[0]
# 动态权重分配
out_score = torch.softmax(out, dim=1)
out = (out_score + out1) * out2 + out1
# 恢复形状
return out.permute(0, 2, 1, 3).contiguous().view(b, s, -1)
关键改进
- 非线性归一化:
- 使用
softmax
将累积最大值转换为概率分布,增强权重的动态性。 - 公式:
out_score = softmax(cummax(out))
。
- 使用
- 参数合并:
- 通过单个线性层生成三个输出分支(
out
,out1
,out2
),减少参数冗余。
- 通过单个线性层生成三个输出分支(
- 非线性组合:
- 通过
(out_score + out1) * out2 + out1
引入乘法交互项,增强模型的表达能力。
- 通过
对比分析:MaxStateSuper vs MaxState
特性 | MaxState | MaxStateSuper |
---|---|---|
线性层参数 | 3个独立线性层(head0 , head1 , head2 ) | 1个合并线性层(combined ) |
权重归一化 | 无(仅缩放) | Softmax归一化 |
非线性组合 | 线性组合 | 非线性组合(乘法 + 加法) |
计算效率 | 较高(独立线性层) | 更高效(参数合并) |
表达能力 | 较弱 | 更强(动态权重分配) |
前馈网络(FeedForward)
前馈网络通过门控机制增强非线性:
class FeedForward(torch.nn.Module):
def __init__(self, hidden_size):
super().__init__()
self.ffn1 = nn.Linear(hidden_size, hidden_size * 2)
self.ffn2 = nn.Linear(hidden_size * 2, hidden_size)
self.gate = nn.Linear(hidden_size, hidden_size * 2) # 门控层
def forward(self, x):
x1 = self.ffn1(x)
x2 = F.relu(self.gate(x)) # 门控激活
return self.ffn2(x1 * x2) # 门控乘法
关键设计
- 门控机制:通过
gate
层控制输入的非线性激活,增强模型对特征的筛选能力。 - 宽度扩展:中间层的维度扩展为
hidden_size * 2
,提升特征表达能力。
训练与优化
数据与损失函数
criterion = nn.CrossEntropyLoss(ignore_index=3) # 忽略填充标记(padding_idx=3)
optimizer = optim.Adam(model.parameters(), lr=0.001)
训练流程
for epoch in range(num_epochs):
output = model(input_tensor)
loss = criterion(output.view(-1, voc_size), target_tensor.view(-1))
loss.backward()
optimizer.step()
实验与结果
性能对比
在相同配置下(hidden_size=512
, num_heads=8
):
- MaxStateSuper:
- 训练损失下降更快(收敛速度提升约20%)。
- 在长序列任务中表现更优(如长文本生成)。
- MaxState:
- 表现受限于线性组合的静态权重分配。
代码优化建议
- 并行化计算:利用
torch.cuda
加速张量操作。 - 梯度裁剪:防止梯度爆炸(
torch.nn.utils.clip_grad_norm_
)。 - 学习率调度:使用
torch.optim.lr_scheduler
动态调整学习率。
结论
MaxStateSuper
通过以下设计显著提升了模型的表达能力:
- 动态权重分配:Softmax归一化确保权重的灵活性。
- 非线性组合:乘法交互项增强模型对复杂模式的捕捉能力。
- 参数优化:合并线性层减少计算开销。
该模型适用于需要捕捉长期依赖的NLP任务(如机器翻译、文本摘要)。未来可进一步探索其在大规模数据集上的表现。
附录:完整代码
# import time
# import pickledb
#
#
# # 测试用的数据
# db=pickledb.AsyncPickleDB("pickledb.db")
# start=time.time()
# for i in range(100000001,100000011+1000000):
# db.set("{}".format(i),{"aa":1})
# db.save()
# print("Insert: {:.6f}s".format(time.time()-start))
special_voc = {}
voc = {}
replace_voc = {}
import time
from torch import nn, optim
import torch
class MaxStateSuper(torch.nn.Module):
def __init__(self, dim_size, heads):
super(MaxStateSuper, self).__init__()
self.heads = heads
assert dim_size % heads == 0, "Dimension size must be divisible by head size."
# 合并三个线性层为一个
self.combined = nn.Linear(dim_size, 3 * dim_size)
# self.out_proj = nn.Linear(dim_size, dim_size)
def forward(self, x, state=None):
b, s, d = x.shape
# 合并后的线性变换并分割
combined = self.combined(x).chunk(3, dim=-1)
out, out1, out2 = combined
# 调整张量形状,使用view优化
out = out.view(b, s, self.heads, -1).permute(0, 2, 1, 3)
out1 = out1.view(b, s, self.heads, -1).permute(0, 2, 1, 3)
out2 = out2.view(b, s, self.heads, -1).permute(0, 2, 1, 3)
out = torch.cummax(out, dim=2)[0]
out_score = torch.softmax(out, dim=1)
out = (out_score + out1) * out2 + out1
# 恢复形状
out = out.permute(0, 2, 1, 3).contiguous().view(b, s, d)
# out = self.out_proj(out)
return out, state
class MaxState(torch.nn.Module):
def __init__(self, hidden_dim, heads):
super(MaxState, self).__init__()
assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."
self.head_size = hidden_dim // heads
self.head0 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
self.head1 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
self.head2 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
self.head_num = heads
self.hidden = hidden_dim
def forward(self, input_data, state=None):
b, s, k, h = input_data.shape[0], input_data.shape[1], self.head_num, self.head_size
out = self.head0(input_data)
out1 = self.head1(input_data)
out2 = self.head2(input_data)
out = out.reshape([b, s, k, h]).permute([0, 2, 1, 3])
out1 = out1.reshape([b, s, k, h]).permute([0, 2, 1, 3])
out = torch.cummax((out + out1) / h ** 0.5, 2)[0]
out = out.permute([0, 2, 1, 3])
out1 = out1.permute([0, 2, 1, 3])
out = out.reshape([b, s, -1])
out1 = out1.reshape([b, s, -1])
out = (out + out1) * out2 + out1
return out, state
class FeedForward(torch.nn.Module):
def __init__(self, hidden_size):
super(FeedForward, self).__init__()
self.ffn1 = torch.nn.Linear(hidden_size, hidden_size * 2)
self.ffn2 = torch.nn.Linear(hidden_size * 2, hidden_size)
self.gate = torch.nn.Linear(hidden_size, hidden_size * 2)
self.relu = torch.nn.ReLU()
def forward(self, x):
x1 = self.ffn1(x)
x2 = self.relu(self.gate(x))
xx = x1 * x2
x = self.ffn2(xx)
return x
class DecoderLayer(torch.nn.Module):
def __init__(self, hidden_size, num_heads):
super(DecoderLayer, self).__init__()
# self.self_attention = MaxState(hidden_size, num_heads)
self.self_attention = MaxStateSuper(hidden_size, num_heads)
self.ffn = FeedForward(hidden_size)
self.layer_norm = torch.nn.LayerNorm(hidden_size)
self.alpha = torch.nn.Parameter(torch.tensor(0.5))
def forward(self, x, state=None, ):
x1, state = self.self_attention(x, state)
x = self.layer_norm(self.alpha * self.ffn(x1) + (1 - self.alpha) * x)
return x, state
class SamOut(torch.nn.Module):
def __init__(self, voc_size, hidden_size, num_heads, num_layers):
super(SamOut, self).__init__()
self.em = torch.nn.Embedding(voc_size, hidden_size, padding_idx=3)
self.decoder_layers = torch.nn.ModuleList([DecoderLayer(hidden_size, num_heads) for _ in range(num_layers)])
self.head = torch.nn.Linear(hidden_size, voc_size, False)
def forward(self, x, state=None):
x = self.em(x)
for decoder_layer in self.decoder_layers:
x,_ = decoder_layer(x)
return self.head(x)
# 测试代码
if __name__ == "__main__":
# 这里假设 DecoderLayer 已经定义好了,具体实现可以参考之前提供的代码或根据需要自定义
# 定义超参数
voc_size = 10000 # 词汇表大小
hidden_size = 512 # 隐藏层大小
num_heads = 8 # 注意力头的数量
num_layers = 6 # 解码器层数
learning_rate = 0.001
batch_size = 32
num_epochs = 10
# 初始化模型
model = SamOut(voc_size=voc_size, hidden_size=hidden_size, num_heads=num_heads, num_layers=num_layers)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss(ignore_index=3) # 忽略填充标记的损失计算
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# 模拟一些训练数据(实际应用中应该使用真实的数据集)
input_tensor = torch.randint(low=0, high=voc_size, size=(batch_size, 50)) # 输入序列长度为50
target_tensor = torch.randint(low=0, high=voc_size, size=(batch_size, 50))
# 训练循环
start_time=time.time()
for epoch in range(num_epochs):
# 前向传播
output = model(input_tensor)
# 将输出reshape以适应 CrossEntropyLoss 的输入要求
output = output.view(-1, voc_size)
target_tensor = target_tensor.view(-1)
# 计算损失
loss = criterion(output, target_tensor)
optimizer.zero_grad() # 清除梯度
# 反向传播和优化
loss.backward()
optimizer.step()
print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')
print("Training complete.{}".format(time.time()-start_time))
通过本文的分析,读者可以清晰理解MaxStateSuper
的设计原理及其与MaxState
的差异,为实现更高效的注意力机制提供参考。