lib/python3.8/site-packages/clip/model.py#L206
class VisionTransformer(nn.Module):
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
super().__init__()
self.input_resolution = input_resolution
self.output_dim = output_dim
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
scale = width ** -0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
self.ln_pre = LayerNorm(width)
self.transformer = Transformer(width, layers, heads)
self.ln_post = LayerNorm(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
def forward(self, x: torch.Tensor):
# x: 输入原始图像,经过缩放,统一的大小为 224*224
# 一幅输入224 x 224的图像,首先经过卷积处理得到16 x 16个patch,那么每一个patch的大小就是14 x 14
# 将每一个patch的矩阵拉伸成为一个1维向量,从而获得了近似词向量堆叠的效果。上一步得道的14 x 14的patch就转换为长度为196的向量
x = self.conv1(x) # shape = [*, width, grid, grid]
# 每个patch拉伸为1*196
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
# 加上class embedding变为1*197的向量
# class_embedding主要借鉴了BERT模型的用于文本分类时的思想,在每一个word vector之前增加一个类别值,通常是加在向量的第一位,上一步得到的196维的向量加上class_embedding后变为197维。
# 增加的class_embedding是一个可以学习的参数,经过网络的不断训练,最终以输出向量的第一个维度的输出来决定最后的输出类别;由于输入是16 x 16个patch,所以输出进行分类时是取 16 x 16个class_embedding进行分类。
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
# 加上1*197的position embedding
# pos_embedding也是一组可以学习的参数,会被加入到经过处理的patch矩阵中
# 它的加入类似于全链接网络和卷积的bias
x = x + self.positional_embedding.to(x.dtype)
# pre layer norm
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
# post layer norm
x = self.ln_post(x[:, 0, :])
# 由于增加的class_embedding是一个可以学习的参数,经过网络的不断训练
# 最终以输出向量的第一个维度的输出来决定最后的输出类别
# [bs, n_patch=257, dim=1024] -> [bs, dim=1024]
if self.proj is not None:
x = x @ self.proj
return x