3.2 Transformer架构
Transformer是文生图模型架构的重要组成部分之一,具体来说,Transformer被广泛应用于文本编码部分,即实现文本编码器的功能。
3.2.1 Transformer的基本结构
Transformer架构是由Vaswani等人在2017年提出的一种基于注意力机制的深度学习模型,广泛应用于自然语言处理(NLP)和其他领域。Transformer的基本结构主要包括编码器(Encoder)和解码器(Decoder),每个部分由多个相同的层(Layers)堆叠而成。
Transformer的整体架构如下所示。
1. 编码器(Encoder)
编码器由多个相同的层(Layer)堆叠而成,每一层包括以下两个子层:
(1)多头自注意力机制(Multi-Head Self-Attention Mechanism)
- Self-Attention:计算输入序列中每个词(token)与其他词之间的相关性,为每个词生成一个加权表示。
- Multi-Head Attention:多个独立的注意力头(Attention Heads)并行工作,捕捉不同的关系模式。每个注意力头生成不同的表示,这些表示随后被连接并投影回原始维度。
(2)前馈神经网络(Feed-Forward Neural Network,FFN):由两个线性变换和一个激活函数(通常是ReLU)组成。第一个线性变换将输入映射到一个更高维度的空间,第二个线性变换将其映射回原始维度。
此外,每个子层周围都有一个残差连接(Residual Connection)和层归一化(Layer Normalization):
Input -> Multi-Head Self-Attention -> Add & Norm -> FFN -> Add & Norm -> Output
2. 解码器(Decoder)
解码器与编码器类似,也由多个相同的层堆叠而成。解码器层包括以下三个子层:
(1)多头自注意力机制(Masked Multi-Head Self-Attention Mechanism):类似于编码器的自注意力机制,但在计算注意力时对未来的词进行遮掩(Masking),确保每个位置只依赖于之前的位置。
(2)编码器-解码器注意力机制(Encoder-Decoder Attention Mechanism):计算解码器当前词与编码器输出之间的相关性,将编码器的输出信息引入解码过程。
(3)前馈神经网络(Feed-Forward Neural Network,FFN):与编码器中的FFN相同。
每个子层周围同样有残差连接和层归一化:
Input -> Masked Multi-Head Self-Attention -> Add & Norm -> Encoder-Decoder Attention -> Add & Norm -> FFN -> Add & Norm -> Output
Transformer中包含的关键组件如下所示:
- 位置编码(Positional Encoding):因为Transformer没有卷积和递归结构,所以没有直接捕捉序列顺序信息的能力。为了解决这个问题,在输入的词向量中加入位置编码,以保留序列信息。位置编码通常通过正弦和余弦函数生成。
- 多头注意力机制(Multi-Head Attention Mechanism):通过多个独立的注意力头并行工作,提高了模型捕捉不同关系模式的能力。每个注意力头执行自注意力计算,然后将结果连接起来,并通过一个线性变换投影回原始维度。
- 自注意力机制(Self-Attention Mechanism):计算输入序列中每个词与其他词之间的相关性,生成加权表示。具体计算包括生成查询(Query)、键(Key)和值(Value)向量,并通过点积计算注意力权重,最后将权重应用于值向量。
Transformer的计算步骤如下所示:
(1)输入嵌入(Input Embedding):将输入文本序列转换为嵌入向量。
(2)位置编码(Positional Encoding):将位置编码加到输入嵌入中,形成包含位置信息的输入。
(3)多头自注意力机制(Multi-Head Self-Attention Mechanism):计算查询、键和值向量,执行多头注意力计算,并将结果连接起来。
(4)前馈神经网络(Feed-Forward Neural Network,FFN):通过两个线性变换和一个激活函数处理注意力输出。
(5)残差连接和层归一化(Residual Connection and Layer Normalization):将输入直接加到子层的输出上,并进行归一化处理。
通过上述步骤,Transformer能够高效地处理序列数据,捕捉复杂的关系模式,是现代自然语言处理和其他任务中最强大的模型之一。
3.2.3 Transformer在文生图大模型中的应用
在文生图大模型(Text-to-Image Models)中,Transformer架构被广泛应用于文本编码、图像生成和多模态融合等环节。在下面的内容中,将详细讲解Transformer在文生图大模型中的具体应用和作用。
1. 文本编码(Text Encoding)
Transformer架构最初在自然语言处理(NLP)任务中表现出色,因此在文生图大模型中,Transformer通常被用作文本编码器。其主要作用是将输入的文本描述转换为高维语义向量。实现文本编码的基本流程如下所示:
(1)输入文本处理
- 将输入文本序列进行分词,并将每个词转换为词向量。
- 加入位置编码(Positional Encoding)以保留文本的顺序信息。
(2)多头自注意力机制(Multi-Head Self-Attention Mechanism):
- 通过自注意力机制计算文本中每个词与其他词之间的关系,生成加权表示。
- 多头机制捕捉不同的关系模式,提高语义表示的丰富性。
(3)前馈神经网络(Feed-Forward Neural Network,FFN):通过两个线性变换和激活函数进一步处理自注意力输出,生成最终的语义向量。
2. 多模态融合(Multimodal Fusion)
在一些高级的文生图模型中,如DALL-E和CLIP,Transformer还用于多模态融合,将文本和图像信息结合起来,从而生成更符合语义的图像。实现多模态融合的流程如下所示:
(1)文本和图像的联合表示
- 使用Transformer对文本进行编码,生成文本的语义向量。
- 使用另一个Transformer或卷积神经网络(CNN)对图像进行编码,生成图像的特征表示。
(2)融合机制
- 将文本和图像的表示结合在一起,形成联合表示。
- 可以使用跨注意力机制(Cross-Attention Mechanism)进一步处理联合表示,使文本信息和图像特征相互影响。
3. 图像生成(Image Generation)
Transformer架构也被用在图像生成任务中,尤其是在一些自回归模型中,如DALL-E。自回归模型通过逐步生成图像像素或特征,生成高质量的图像。图像生成的工作流程如下所示:
(1)文本向量作为初始输入:将经过Transformer编码的文本语义向量作为初始输入,指导图像生成过程。
(2)自回归生成过程:使用Transformer逐步生成图像的像素或特征,每一步生成当前像素或特征时,考虑之前生成的内容和文本语义向量。
(3)反向传播优化:通过反向传播和梯度下降优化生成过程,确保生成的图像与输入文本语义一致。
4. 具体应用示例:DALL-E
DALL-E是OpenAI推出的一种文生图模型,广泛使用了Transformer架构。DALL-E使用Transformer对输入的文本描述进行编码,生成语义向量。使用自回归Transformer生成图像,每一步生成图像的一个部分(如像素或块),生成过程中不断参考文本的语义向量。
请看下面的实例,实现了一个简化版的视觉Transformer模型,用于生成MNIST手写数字图像。训练完成后,使用随机噪声生成图像,并通过Matplotlib可视化生成的手写数字图像。
实例3-1:使用Transformer模型于生成MNIST手写数字图像(源码路径:codes/3/trf.py)
文件trf.py的具体实现代码如下所示。
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
# 超参数
image_size = 16 * 16 # 使用更小的图像尺寸
d_model = image_size # d_model 等于图像展平后的大小
num_classes = 10
num_epochs = 5 # 减少训练轮数以提高测试速度
batch_size = 32 # 减少批量大小
learning_rate = 0.001
# 数据集
transform = transforms.Compose([
transforms.Resize((16, 16)), # 更小的图像尺寸
transforms.ToTensor(),
])
# Vision Transformer模型
class VisionTransformer(nn.Module):
def __init__(self):
super(VisionTransformer, self).__init__()
self.transformer = nn.Transformer(d_model=d_model, nhead=4, num_encoder_layers=3, num_decoder_layers=3, batch_first=True)
self.fc_out = nn.Linear(d_model, 16 * 16) # 输出维度保持一致
def forward(self, x):
batch_size = x.size(0)
x = x.view(batch_size, -1) # Flatten图像
x = x.unsqueeze(1) # 变为 (batch_size, seq_length=1, d_model)
x = self.transformer(x, x) # 使用同一输入作为encoder和decoder
x = x.squeeze(1) # 去掉seq_length维度,变为 (batch_size, d_model)
x = self.fc_out(x) # 生成图像
x = x.view(-1, 1, 16, 16) # 恢复成图像形状
return torch.sigmoid(x) # 使用sigmoid以便输出图像
def main():
# 加载MNIST数据集
dataset = datasets.MNIST(root='data', train=True, transform=transform, download=True)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
# 初始化模型、损失函数和优化器
model = VisionTransformer().to('cuda' if torch.cuda.is_available() else 'cpu')
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# 训练模型
for epoch in range(num_epochs):
for images, _ in data_loader:
images = images.to('cuda' if torch.cuda.is_available() else 'cpu')
optimizer.zero_grad()
output = model(images)
loss = criterion(output.view(-1, 16 * 16), images.view(-1, 16 * 16).float())
loss.backward()
optimizer.step()
print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')
# 生成图像
with torch.no_grad():
noise = torch.randn(batch_size, 1, 16, 16).to('cuda' if torch.cuda.is_available() else 'cpu') # 随机噪声输入
generated_images = model(noise)
# 可视化生成的图像
plt.figure(figsize=(8, 8))
for i in range(16):
plt.subplot(4, 4, i + 1)
plt.imshow(generated_images[i][0].cpu(), cmap='gray')
plt.axis('off')
plt.show()
if __name__ == "__main__":
main()
上述代码的实现流程如下所示:
(1)导入库:导入必要的PyTorch、Torchvision和Matplotlib库。
(2)设置超参数:定义潜在空间的维度(latent_size)、批量大小(batch_size)、训练轮数(num_epochs)、学习率(learning_rate)和设备(GPU或CPU)。
(3)数据预处理:使用transforms.Compose对MNIST数据集中的图像进行预处理,包括调整大小、转换为张量和归一化。
(4)定义模型:创建一个名为TransformerModel的类,初始化一个Transformer模块和一个全连接层。
(5)在forward方法中,首先将输入图像展平,并通过Transformer进行处理,最后通过全连接层生成图像。
(6)初始化模型、损失函数和优化器:实例化模型、定义损失函数(如均方误差)和优化器(如Adam)。
(7)训练模型:使用数据加载器迭代训练数据。对每一批数据进行前向传播,计算损失,执行反向传播并更新模型参数。
(8)生成图像:在训练结束后,使用随机噪声作为输入通过模型生成手写数字图像。
(9)可视化生成的图像:使用Matplotlib将生成的图像显示出来,如图3-8所示。
图3-8 生成的手写数字图像