CNN, RNN, LSTM, and Transformer: An Architecture Comparison
1. Architecture Overview
1.1 Overall Structure
At a high level: CNNs stack convolution and pooling layers to build spatial feature hierarchies; RNNs and LSTMs pass a hidden state from one time step to the next; Transformers stack self-attention and feed-forward blocks that process all positions at once.
2. Core Characteristics
2.1 CNN
CNNs rely on local receptive fields, weight sharing, and pooling to extract hierarchical features from grid-structured data.
# Basic CNN implementation
# Shared imports for all code examples in this article
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # Two conv blocks; each max-pool halves the spatial resolution
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # Classifier head; 128 * 8 * 8 assumes 32x32 inputs (32 -> 16 -> 8)
        self.fc_layers = nn.Sequential(
            nn.Linear(128 * 8 * 8, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, features)
        x = self.fc_layers(x)
        return x
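A quick shape check, assuming 32x32 RGB inputs (e.g. CIFAR-10-sized images), which is what the 128 * 8 * 8 flatten size implies:

x = torch.randn(4, 3, 32, 32)   # hypothetical batch of 4 RGB images
logits = CNN()(x)
print(logits.shape)             # torch.Size([4, 10])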
2.2 RNN
RNNs process a sequence one step at a time, carrying context forward in a hidden state.
# Basic RNN implementation
class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleRNN, self).__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        # x: (batch, seq_len, input_size); hidden: (1, batch, hidden_size)
        out, hidden = self.rnn(x, hidden)
        out = self.fc(out[:, -1, :])  # classify from the last time step only
        return out, hidden

    def init_hidden(self, batch_size):
        # one layer, one direction
        return torch.zeros(1, batch_size, self.hidden_size)
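A usage sketch with hypothetical sizes, showing how the external hidden state threads through calls:

model = SimpleRNN(input_size=16, hidden_size=32, output_size=4)
x = torch.randn(8, 20, 16)            # (batch, seq_len, input_size)
hidden = model.init_hidden(batch_size=8)
out, hidden = model(x, hidden)        # hidden can be fed into the next chunk
print(out.shape)                      # torch.Size([8, 4])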
2.3 LSTM
The LSTM adds input, forget, and output gates plus a separate cell state, which stabilizes gradient flow over long sequences.
# Basic LSTM implementation
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # Zero-initialize hidden and cell states on the same device as the input
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
        out, (hn, cn) = self.lstm(x, (h0, c0))
        out = self.fc(out[:, -1, :])  # classify from the last time step
        return out
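The same kind of sketch for the LSTM (sizes are hypothetical); the states are created internally, so only the input is passed:

model = LSTM(input_size=16, hidden_size=32, num_layers=2, output_size=4)
x = torch.randn(8, 20, 16)  # (batch, seq_len, input_size)
print(model(x).shape)       # torch.Size([8, 4])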
2.4 Transformer
The Transformer replaces recurrence with self-attention, letting every position attend to every other position in a single parallel step.
# Multi-head attention, the core mechanism of the Transformer
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_model // num_heads  # per-head dimension
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # scores: (batch, heads, seq_len, seq_len) -- this is the O(n^2) term
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attention = F.softmax(scores, dim=-1)
        return torch.matmul(attention, V)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        # Project, then split into heads: (batch, heads, seq_len, d_k)
        Q = self.w_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.w_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.w_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        x = self.scaled_dot_product_attention(Q, K, V, mask)
        # Merge heads back: (batch, seq_len, d_model)
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.w_o(x)
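Self-attention is the special case where query, key, and value are the same tensor; a minimal sketch with assumed sizes:

x = torch.randn(2, 10, 64)                     # (batch, seq_len, d_model)
mha = MultiHeadAttention(d_model=64, num_heads=8)
out = mha(x, x, x)                             # query = key = value
print(out.shape)                               # torch.Size([2, 10, 64])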
3. Application Scenarios
3.1 Task Suitability

| Network | Image Processing | Sequence Processing | Long-Range Dependencies | Parallelism |
|---|---|---|---|---|
| CNN | ★★★★★ | ★★ | ★ | ★★★★ |
| RNN | ★ | ★★★★ | ★★ | ★ |
| LSTM | ★ | ★★★★★ | ★★★★ | ★★ |
| Transformer | ★★ | ★★★★★ | ★★★★★ | ★★★★★ |
3.2 A Typical Hybrid Application
The sketch below chains the three building blocks: the CNN extracts image features, the LSTM summarizes a sequence, and a Transformer layer models their interaction. The projection layer and the use of nn.TransformerEncoderLayer are assumptions added here so that the dimensions line up.
# Example of combining several model families
class HybridModel(nn.Module):
    def __init__(self):
        super(HybridModel, self).__init__()
        # CNN for image feature extraction: outputs (batch, 10)
        self.cnn = CNN()
        # LSTM for sequence modeling: outputs (batch, 64)
        self.lstm = LSTM(input_size=128, hidden_size=256, num_layers=2, output_size=64)
        # Project the fused features to the transformer width (assumed glue layer)
        self.proj = nn.Linear(10 + 64, 64)
        # Transformer layer for relational modeling
        # (nn.TransformerEncoderLayer stands in for the unspecified TransformerLayer)
        self.transformer = nn.TransformerEncoderLayer(d_model=64, nhead=8, batch_first=True)

    def forward(self, image, sequence):
        # 1. Image feature extraction
        image_features = self.cnn(image)
        # 2. Sequence processing
        sequence_features = self.lstm(sequence)
        # 3. Feature fusion and relational modeling
        combined = torch.cat([image_features, sequence_features], dim=1)
        combined = self.proj(combined).unsqueeze(1)  # (batch, 1, 64)
        output = self.transformer(combined)
        return output.squeeze(1)
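A smoke test with hypothetical input sizes that match the constructor arguments above:

image = torch.randn(4, 3, 32, 32)     # batch for the CNN branch
sequence = torch.randn(4, 20, 128)    # batch for the LSTM branch (input_size=128)
model = HybridModel()
print(model(image, sequence).shape)   # torch.Size([4, 64])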
4. Performance and Efficiency
4.1 Computational Complexity
Per-layer time complexity for a sequence of length n (k = kernel size):
- CNN: O(k·n), and fully parallel across positions
- RNN: O(n) sequential steps, each depending on the previous one
- LSTM: O(n) sequential steps, with a larger per-step constant than a plain RNN
- Transformer: O(n²) for self-attention, but all positions are computed in parallel
In practice the Transformer's parallelism often outweighs its quadratic cost at moderate sequence lengths, as the timing sketch below illustrates.
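A rough timing sketch of that trade-off, comparing a recurrent layer against a self-attention layer as the sequence grows (absolute numbers depend entirely on hardware; the layer sizes are arbitrary):

import time

def time_forward(module, x, repeats=10):
    # Average forward-pass time in milliseconds
    start = time.perf_counter()
    with torch.no_grad():
        for _ in range(repeats):
            module(x)
    return (time.perf_counter() - start) / repeats * 1000

lstm = nn.LSTM(64, 64, batch_first=True)
attn = nn.TransformerEncoderLayer(d_model=64, nhead=8, batch_first=True)
for n in (128, 512, 2048):
    x = torch.randn(1, n, 64)
    print(f"n={n}: LSTM {time_forward(lstm, x):.1f} ms, "
          f"attention {time_forward(attn, x):.1f} ms")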
4.2 Memory Consumption
Memory characteristics of each model:
- CNN: scales with input resolution and network depth
- RNN: scales linearly with sequence length (activations are kept for backpropagation)
- LSTM: like the RNN but higher, since gate activations and a cell state are stored per step
- Transformer: scales with the square of the sequence length, because the attention scores form an n × n matrix per head; see the calculation after this list
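A back-of-the-envelope calculation of that quadratic term, using the score-tensor shape from the MultiHeadAttention code in section 2.4 and assuming float32:

def attention_score_bytes(batch, heads, seq_len, bytes_per_float=4):
    # The scores tensor has shape (batch, heads, seq_len, seq_len)
    return batch * heads * seq_len * seq_len * bytes_per_float

for n in (512, 2048, 8192):
    gib = attention_score_bytes(batch=8, heads=8, seq_len=n) / 1024**3
    print(f"seq_len={n}: {gib:.2f} GiB")  # 16x more memory for 4x the length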
5. Strengths and Weaknesses
5.1 CNN
Strengths:
- Strong local feature extraction
- Parameter sharing keeps the model compact and efficient
- Well suited to grid-structured data such as images
Weaknesses:
- Limited global receptive field
- Poor fit for sequential data
- Pooling can discard spatial information
5.2 RNN
Strengths:
- Natural fit for sequential data
- Relatively few parameters
- Handles variable-length sequences
Weaknesses:
- Hard to train (vanishing/exploding gradients; gradient clipping, sketched after this list, is the standard remedy for the exploding case)
- Poor memory over long sequences
- Very low parallelism
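A minimal sketch of that mitigation inside a training step, reusing SimpleRNN from section 2.2 (the optimizer, data, and max_norm value are illustrative choices):

model = SimpleRNN(input_size=16, hidden_size=32, output_size=4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
x, hidden = torch.randn(8, 20, 16), model.init_hidden(8)
target = torch.randint(0, 4, (8,))

out, _ = model(x, hidden)
loss = F.cross_entropy(out, target)
optimizer.zero_grad()
loss.backward()
# Rescale gradients whose global norm exceeds 1.0, curbing explosions
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()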
5.3 LSTM
Strengths:
- Mitigates the RNN's gradient problems
- Much better at long-range dependencies
- Gates allow flexible control over information flow
Weaknesses:
- Higher computational cost per step
- Long training times
- Still low parallelism, since time steps remain sequential
5.4 Transformer
Strengths:
- Highly parallelizable
- Models long-range dependencies directly
- The attention mechanism is empirically very effective
Weaknesses:
- O(n²) computation and memory in sequence length
- Requires large amounts of training data
- Positional encodings are an imperfect substitute for built-in order; a common scheme is sketched below
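For reference, a sketch of the sinusoidal encoding from the original Transformer paper, which injects the order information that attention alone lacks (d_model is assumed even):

def sinusoidal_positional_encoding(seq_len, d_model):
    # PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    # PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
    pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)
    div = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32)
                    * (-math.log(10000.0) / d_model))
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)
    return pe  # added to token embeddings before the first attention layer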
6. Trends
6.1 Hybrid Architectures
Hybrid designs combine the strengths of the individual blocks, as in the HybridModel example of section 3.2: a CNN front end for perception, with recurrent or attention layers on top for sequence and relational modeling.
6.2 Future Directions
- Model compression and acceleration (see the quantization sketch below)
- Adaptive architecture design
- Better interpretability
- Stronger transfer learning
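As one concrete example of compression, a dynamic-quantization sketch using PyTorch's built-in API, applied to the LSTM class from section 2.3 (which layer types are supported varies by PyTorch version):

model = LSTM(input_size=128, hidden_size=256, num_layers=2, output_size=64)
quantized = torch.quantization.quantize_dynamic(
    model, {nn.LSTM, nn.Linear}, dtype=torch.qint8
)
# Weights are stored as int8; activations are quantized on the fly
print(quantized)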