CLIP
类
以下是对 CLIP
类的详细注释。这个类实现了一个多模态模型,能够同时处理图像和文本,并生成它们的嵌入表示。CLIP 主要由视觉和文本编码器组成,并使用 Transformer 模块来处理文本。
class CLIP(nn.Module):
def __init__(self,
embed_dim: int,
# vision
image_resolution: int,
vision_layers: Union[Tuple[int, int, int, int], int],
vision_width: int,
vision_patch_size: int,
vision_stride_size: int,
# text
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int,
h_resolution: int,
w_resolution: int):
super().__init__()
self.context_length = context_length
# 根据 vision_layers 的类型选择使用 ModifiedResNet 或 VisionTransformer 作为视觉编码器
if isinstance(vision_layers, (tuple, list)):
vision_heads = vision_width * 32 // 64
self.visual = ModifiedResNet(
layers=vision_layers,
output_dim=embed_dim,
heads=vision_heads,
input_resolution=h_resolution * w_resolution,
width=vision_width
)
else:
vision_heads = vision_width // 64
self.visual = VisionTransformer(
h_resolution=h_resolution,
w_resolution=w_resolution,
patch_size=vision_patch_size,
stride_size=vision_stride_size,
width=vision_width,
layers=vision_layers,
heads=vision_heads,
output_dim=embed_dim
)
# 初始化 Transformer 模块作为文本编码器
self.transformer = Transformer(
width=transformer_width,
layers=transformer_layers,
heads=transformer_heads,
attn_mask=self.build_attention_mask()
)
# 初始化文本编码器的相关参数
self.vocab_size = vocab_size
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
self.ln_final = LayerNorm(transformer_width)
# 初始化用于投影文本特征的参数
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
# 初始化 logit 的缩放参数
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
# 调用initialize_parameters函数初始化所有参数
self.initialize_parameters()
def initialize_parameters(self):
# 初始化文本嵌入层和位置嵌入层的权重
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)
# 如果视觉编码器是 ModifiedResNet,初始化其权重
if isinstance(self.visual, ModifiedResNet):
if self.visual.attnpool is not None:
std = self.visual.attnpool.c_proj.in_features ** -0.5
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
for name, param in resnet_block.named_parameters():
if name.endswith("bn3.weight"):
nn.init.zeros_(param)
# 初始化 Transformer 模块的权重
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
attn_std = self.transformer.width ** -0.5
fc_std = (2 * self.transformer.width) ** -0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
# 初始化文本投影的权重
if self.text_projection is not None:
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
def build_attention_mask(self):
# 构建因果注意力掩码,屏蔽未来的位置
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float("-inf"))
# mask.triu_(1) 表示从主对角线往上数第 1 条对角线开始保留元素,主对角线及其以下的元素都被设置为零。
mask.triu_(1) # 清除下三角
return mask
@property
def dtype(self):
return self.visual.conv1.weight.dtype
def encode_image(self, image):
# 编码图像
return self.visual(image.type(self.dtype))
def encode_text(self, text):
# 编码文本
x = self.token_embedding(text).type(self.dtype)
x = x + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # 转置为 (L, N, E)
x = self.transformer(x)
x = x.permute(1, 0, 2) # 转置回 (N, L, E)
x = self.ln_final(x).type(self.dtype)
# 取出 [EOS] token 的特征并进行投影
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return x
def forward(self, image, text):
# 编码图像和文本
image_features = self.encode_image(image)
text_features = self.encode_text(text)
# 归一化特征
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# 计算余弦相似度作为 logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logit_scale * text_features @ image_features.t()
return logits_per_image, logits_per_text
详细注释
初始化方法 __init__
-
输入参数:
embed_dim
:嵌入维度。- 视觉编码器相关参数:包括图像分辨率、视觉层数、宽度、补丁大小、步幅。
- 文本编码器相关参数:包括上下文长度、词汇表大小、Transformer 的宽度、头数量和层数。
h_resolution
和w_resolution
:图像的高度和宽度分辨率。
-
初始化视觉编码器:
- 根据
vision_layers
的类型(元组或整数)选择使用ModifiedResNet
或VisionTransformer
作为视觉编码器。 - 如果
vision_layers
是元组或列表,使用ModifiedResNet
,否则使用VisionTransformer
。 - 初始化视觉编码器
self.visual
。
- 根据
-
初始化文本编码器:
- 创建
Transformer
实例,传入 Transformer 的宽度、层数、头数量以及注意力掩码。 - 初始化文本嵌入层
self.token_embedding
和位置嵌入参数self.positional_embedding
。 - 初始化层归一化层
self.ln_final
和文本投影参数self.text_projection
。 - 初始化 logit 缩放参数
self.logit_scale
。
- 创建
-
初始化所有参数:
- 调用
self.initialize_parameters()
初始化模型的各个权重参数。
- 调用
初始化参数方法 initialize_parameters
-
初始化文本嵌入层和位置嵌入层:
nn.init.normal_(self.token_embedding.weight, std=0.02)
:用标准差为 0.02 的正态分布初始化 token 嵌入层的权重。nn.init.normal_(self.positional_embedding, std=0.01)
:用标准差为 0.01 的正态分布初始化位置嵌入层的权重。
-
初始化 ModifiedResNet 的权重(如果使用的话):
- 对
attnpool
层的投影权重进行初始化。 - 对 ResNet 块中的
bn3.weight
进行零初始化。
- 对
-
初始化 Transformer 模块的权重:
proj_std
、attn_std
和fc_std
用于设置多头注意力层和前馈网络层的标准差。- 对每个 Transformer 块的注意力投影和输出权重、MLP 的全连接层权重进行初始化。
-
初始化文本投影的权重:
- 用标准差为
transformer.width ** -0.5
的正态分布初始化self.text_projection
。
- 用标准差为
构建注意力掩码方法 build_attention_mask
- 构建因果注意力掩码:
- 创建一个形状为
(context_length, context_length)
的张量mask
,并填充为负无
- 创建一个形状为
穷大。
mask.triu_(1)
将下三角部分置为零,仅保留上三角部分的负无穷大值,从而屏蔽未来位置。
属性方法 dtype
- 返回视觉编码器权重的 dtype:
return self.visual.conv1.weight.dtype
:返回视觉编码器第一个卷积层的权重的数据类型。
编码图像方法 encode_image
- 编码图像输入:
- 将输入图像
image
转换为视觉编码器的 dtype,并传递给self.visual
进行处理。
- 将输入图像
编码文本方法 encode_text
- 编码文本输入:
- 将文本输入
text
通过token_embedding
转换为嵌入表示,并转换为视觉编码器的 dtype。 - 将位置嵌入添加到文本嵌入中。
- 转置嵌入表示,使其形状变为
(L, N, E)
。 - 通过 Transformer 处理嵌入表示。
- 再次转置回
(N, L, E)
。 - 通过层归一化处理文本嵌入。
- 获取每个样本的 [EOS] token 的特征,并进行投影。
- 将文本输入
前向传播方法 forward
-
编码图像和文本:
- 调用
encode_image
和encode_text
分别编码图像和文本。
- 调用
-
归一化特征向量:
- 对图像和文本的特征向量进行归一化,使其范数为 1。
-
计算余弦相似度:
logit_scale.exp()
计算 logit 缩放因子的指数值。- 计算图像特征和文本特征之间的余弦相似度,并乘以 logit 缩放因子,得到
logits_per_image
和logits_per_text
。
-
返回结果:
- 返回图像与文本之间的 logits 矩阵
logits_per_image
和logits_per_text
。
- 返回图像与文本之间的 logits 矩阵
总结
CLIP
类实现了一个多模态模型,可以同时处理图像和文本,生成它们的嵌入表示,并计算它们之间的相似度。模型使用视觉编码器(ModifiedResNet
或 VisionTransformer
)和文本编码器(Transformer)来处理输入,并计算图像和文本之间的相似度。这种设计使得模型可以在图像-文本匹配和检索任务中表现出色。