建议先理解TwoWayAttentionBlock类,再理解TwoWayTransformer,并对照流程图看。
查询为点嵌入,键为图像嵌入。先通过两个重复的TwoWayAttentionBlock,最后通过一个token-image交叉注意力。
class TwoWayTransformer(nn.Module):
# 定义一个自定义的双向 Transformer 模块,用于处理图像和点的嵌入
def __init__(
self,
depth: int, # Transformer 的层数
embedding_dim: int, # 嵌入向量的维度
num_heads: int, # 多头注意力的头数
mlp_dim: int, # 前馈网络的隐藏层维度
activation: Type[nn.Module] = nn.ReLU, # 使用的激活函数
attention_downsample_rate: int = 2, # 注意力的下采样率
) -> None:
"""
初始化 TwoWayTransformer 模块,并根据指定配置参数构建结构。
参数:
depth (int): Transformer 的层数。
embedding_dim (int): 输入嵌入的维度。
num_heads (int): 多头注意力的头数(必须整除 embedding_dim)。
mlp_dim (int): 多层感知机(MLP)块中的隐藏层维度。
activation (nn.Module): MLP 块中使用的激活函数。
"""
super().__init__()
self.depth = depth # 保存 Transformer 的层数
self.embedding_dim = embedding_dim # 保存嵌入维度
self.num_heads = num_heads # 保存多头注意力的头数
self.mlp_dim = mlp_dim # 保存 MLP 的隐藏层维度
self.layers = nn.ModuleList() # 初始化存储 Transformer 层的列表
for i in range(depth):
# 添加 TwoWayAttentionBlock 到层列表中
self.layers.append(
TwoWayAttentionBlock(
embedding_dim=embedding_dim,
num_heads=num_heads