CLIP模型

简介

CLIP,全称 Contrastive Language-Image Pre-training,是 OpenAI 于 2021 年提出的多模态预训练模型。它在 4 亿(400M)个图像-文本对上通过对比学习进行训练,具备强大的 zero-shot 迁移能力。

论文地址:https://arxiv.org/pdf/2103.00020.pdf
代码地址:https://github.com/openai/CLIP

模型结构图可参见论文中的 Figure 1(展示对比预训练与 zero-shot 推理两个阶段,原图略)。

CLIP代码结构

论文中给出的训练伪代码及逐行注释如下:

# image_encoder    图像编码器:ResNet或者ViT
# text_encoder     文本编码器:CBOW或者Text Transformer
# I[n,h,w,c]       图像输入大小: 比如 [16, 224, 224, 3]
# T[n,l]           文本输入大小:n表示batch size,l表示序列长度
# W_i[d_i, d_e]    图像投射矩阵:将图像特征从单模态空间映射到多模态嵌入空间
# W_t[d_t, d_e]    文本投射矩阵:将文本特征从单模态空间映射到多模态嵌入空间
# t                可学习的温度系数

# 分别提取每个模态的特征
I_f = image_encoder(I)  # 输出大小 [n, d_i]
T_f = text_encoder(T)   # 输出大小 [n, d_t]

# 合并多模态特征
I_e = l2_normalize(np.dot(I_f, W_i), axis=1)  # 输出大小 [n, d_e]
T_e = l2_normalize(np.dot(T_f, W_t), axis=1)  # 输出大小 [n, d_e]

# 计算缩放后的余弦相似度矩阵:I_e 与 T_e 转置的点积。
# t 是可学习的温度参数(这里学习的是其对数),np.exp(t) 将其还原为缩放因子,
# 用于放大相似度差异、控制 softmax 分布的尖锐程度。
logits = np.dot(I_e, T_e.T) * np.exp(t)  # 输出大小 [n, n]

# 计算损失函数
labels = np.arange(n)
loss_i = cross_entropy_loss(logits, labels, axis=0)  # 针对图像的交叉熵损失
loss_t = cross_entropy_loss(logits, labels, axis=1)  # 针对文本的交叉熵损失
loss = (loss_i + loss_t) / 2  # 综合图像和文本的损失
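
上述伪代码可以直接落地为一个简化的 PyTorch 实现。下面是一个最小示意(编码器输出用随机特征代替,n、d_e 等数值均为假设):

import torch
import torch.nn.functional as F

def clip_contrastive_loss(image_features, text_features, logit_scale):
    """对称对比损失:image_features、text_features 形状均为 [n, d_e]"""
    # L2 归一化,对应伪代码中的 l2_normalize
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)

    # 缩放后的相似度矩阵 [n, n]
    logits_per_image = logit_scale.exp() * image_features @ text_features.t()
    logits_per_text = logits_per_image.t()

    # 第 i 张图与第 i 条文本互为正样本,标签即对角线位置
    labels = torch.arange(image_features.shape[0], device=image_features.device)
    loss_i = F.cross_entropy(logits_per_image, labels)  # 图像 -> 文本方向
    loss_t = F.cross_entropy(logits_per_text, labels)   # 文本 -> 图像方向
    return (loss_i + loss_t) / 2

# 用随机特征模拟一个 batch:n=16, d_e=512(假设值)
I_f = torch.randn(16, 512)
T_f = torch.randn(16, 512)
logit_scale = torch.nn.Parameter(torch.log(torch.tensor(1 / 0.07)))  # CLIP 的初始化方式
print(clip_contrastive_loss(I_f, T_f, logit_scale))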

对温度系数 t 的理解:
图像与文本特征经过 L2 归一化后,余弦相似度被限制在 [-1, 1] 区间内,正、负样本对之间的分数差距往往很小,直接送入 softmax 难以拉开距离;除以一个较小的温度系数(CLIP 的初始值为 0.07)相当于把相似度差异放大,使 softmax 分布更尖锐,从而更好地区分正负样本对。

温度系数来源于 InfoNCE 损失函数(以 1 个正样本 $k_i^{+}$、$K$ 个负样本 $k_j^{-}$ 为例):

$$
\mathcal{L}_{\text{InfoNCE}} = -\frac{1}{N}\sum_{i=1}^{N}\log\frac{\exp\left(q_i\cdot k_i^{+}/\tau\right)}{\exp\left(q_i\cdot k_i^{+}/\tau\right)+\sum_{j=1}^{K}\exp\left(q_i\cdot k_j^{-}/\tau\right)}
$$

下面以 MoCo 风格的实现为例(q、k 为一批 query 与正样本特征,self.queue 为负样本特征队列),展示温度系数在代码中的用法:

# 温度系数
self.T = 0.07

# 计算相似度
# positive logits: Nx1
s_pos = torch.sum(q*k, dim=1).unsqueeze(dim=1)
# negative logits: NxK
s_neg = torch.matmul(q, self.queue.clone().detach().T)

# 拼接相似度 logits: Nx(1+K)
logits = torch.cat([s_pos, s_neg], dim=1)
logits /= self.T

# 创建标签
labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()

# 计算InfoNCE损失
loss = F.cross_entropy(logits, labels)
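
温度系数对 softmax 分布的影响可以用一个小例子直观感受(相似度数值为随意假设):

import torch
import torch.nn.functional as F

# 正样本相似度 0.6,两个负样本 0.5、0.4,差距很小
sims = torch.tensor([0.6, 0.5, 0.4])

print(F.softmax(sims, dim=0))         # τ=1:约 [0.367, 0.332, 0.301],几乎分不开
print(F.softmax(sims / 0.07, dim=0))  # τ=0.07:约 [0.771, 0.185, 0.044],正样本被显著放大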

CLIP前向架构

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text
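
forward 返回的两个 logits 矩阵可以直接用于 zero-shot 分类。下面是官方仓库 README 中的用法示例(图片路径仅作示意):

import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

with torch.no_grad():
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)  # 三段文本分别作为候选类别的概率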

CLIP图像编码模块

图像预处理

from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, InterpolationMode

BICUBIC = InterpolationMode.BICUBIC

def _convert_image_to_rgb(image):
    return image.convert("RGB")

def _transform(input_resolution):
    return Compose([
        Resize(input_resolution, interpolation=BICUBIC),
        CenterCrop(input_resolution),
        _convert_image_to_rgb,
        ToTensor(),
        # CLIP 官方使用的归一化均值与标准差
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])
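
得到的预处理函数可以直接作用在 PIL 图片上(一个示意,图片路径为假设):

from PIL import Image

preprocess = _transform(224)  # ViT-B/32、RN50 等模型的输入分辨率为 224
image_tensor = preprocess(Image.open("cat.jpg")).unsqueeze(0)  # [1, 3, 224, 224]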

encode_image

该模块提供了两种编码的backbone,分别是经过修改的ResNet和ViT。
ModifiedResNet主干:

class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.relu3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            x = self.relu1(self.bn1(self.conv1(x)))
            x = self.relu2(self.bn2(self.conv2(x)))
            x = self.relu3(self.bn3(self.conv3(x)))
            x = self.avgpool(x)
            return x

        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x
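
以 RN50 的配置为例(layers=(3, 4, 6, 3)、width=64、embed_dim=1024、heads=32),可以这样实例化并检查输出维度(示意代码,依赖官方仓库中的 Bottleneck 与 AttentionPool2d 定义):

import torch

visual = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, heads=32,
                        input_resolution=224, width=64)
x = torch.randn(2, 3, 224, 224)
print(visual(x).shape)  # torch.Size([2, 1024]),与后文 RN50 参数表中 AttentionPool2d 的输出一致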

Transformer主干:

class VisionTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x
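
以 ViT-B/32 的配置为例(patch_size=32、width=768、layers=12、heads=12、output_dim=512),示意用法如下(同样依赖官方仓库中的 LayerNorm 与 Transformer 定义):

import torch

visual = VisionTransformer(input_resolution=224, patch_size=32, width=768,
                           layers=12, heads=12, output_dim=512)
x = torch.randn(2, 3, 224, 224)
print(visual(x).shape)  # torch.Size([2, 512]),即投射到多模态嵌入空间后的图像特征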

CLIP文本编码模块

def encode_text(self, text):
    x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
    x = x + self.positional_embedding.type(self.dtype)
    x = x.permute(1, 0, 2)  # NLD -> LND
    x = self.transformer(x)
    x = x.permute(1, 0, 2)  # LND -> NLD
    x = self.ln_final(x).type(self.dtype)

    # x.shape = [batch_size, n_ctx, transformer.width]
    # take features from the eot embedding (eot_token is the highest number in each sequence)
    x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

    return x
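
注释中"取 EOT 位置的特征"之所以能用 text.argmax(dim=-1) 实现,是因为 clip.tokenize 生成的序列中 EOT(end-of-text)的 token id(49407)是整个词表中最大的。一个小示例:

import clip

text = clip.tokenize(["a photo of a cat", "a photo of a dog"])  # 形状 [2, 77]
print(text.argmax(dim=-1))        # 每条文本中 EOT token 所在的位置
print(text[0, text[0].argmax()])  # EOT 的 token id,即 49407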

CLIP模型参数
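
下文的参数统计形如 torchsummary 的输出,可以用类似下面的方式自行打印(示意代码,假设已安装 torchsummary,并把模型以 float32 放在 CPU 上):

import clip
from torchsummary import summary

model, _ = clip.load("ViT-B/32", device="cpu")
summary(model.visual.float(), input_size=(3, 224, 224), device="cpu")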

ViT-B-32模型image_encoder模块参数量:

Total params: 59,068,416
Trainable params: 59,068,416
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 349.28
Params size (MB): 225.33
Estimated Total Size (MB): 575.18
----------------------------------------------------------------
        Layer (type)               Output Shape         Param # 
================================================================
            Conv2d-1            [-1, 768, 7, 7]       2,359,296
         LayerNorm-2              [-1, 50, 768]           1,536
         LayerNorm-3               [-1, 2, 768]           1,536
MultiheadAttention-4  [[-1, 2, 768], [-1, 50, 50]]               0
         LayerNorm-5               [-1, 2, 768]           1,536
            Linear-6              [-1, 2, 3072]       2,362,368
         QuickGELU-7              [-1, 2, 3072]               0
            Linear-8               [-1, 2, 768]       2,360,064
ResidualAttentionBlock-9               [-1, 2, 768]               0
        LayerNorm-10               [-1, 2, 768]           1,536
MultiheadAttention-11  [[-1, 2, 768], [-1, 50, 50]]               0
        LayerNorm-12               [-1, 2, 768]           1,536
           Linear-13              [-1, 2, 3072]       2,362,368
        QuickGELU-14              [-1, 2, 3072]               0
           Linear-15               [-1, 2, 768]       2,360,064
ResidualAttentionBlock-16               [-1, 2, 768]               0
        LayerNorm-17               [-1, 2, 768]           1,536
MultiheadAttention-18  [[-1, 2, 768], [-1, 50, 50]]               0
        LayerNorm-19               [-1, 2, 768]           1,536
           Linear-20              [-1, 2, 3072]       2,362,368
        QuickGELU-21              [-1, 2, 3072]               0
           Linear-22               [-1, 2, 768]       2,360,064
ResidualAttentionBlock-23               [-1, 2, 768]               0
        LayerNorm-24               [-1, 2, 768]           1,536
MultiheadAttention-25  [[-1, 2, 768], [-1, 50, 50]]               0
        LayerNorm-26               [-1, 2, 768]           1,536
           Linear-27              [-1, 2, 3072]       2,362,368
        QuickGELU-28              [-1, 2, 3072]               0
           Linear-29               [-1, 2, 768]       2,360,064
ResidualAttentionBlock-30               [-1, 2, 768]               0
        LayerNorm-31               [-1, 2, 768]           1,536
MultiheadAttention-32  [[-1, 2, 768], [-1, 50, 50]]               0
        LayerNorm-33               [-1, 2, 768]           1,536
           Linear-34              [-1, 2, 3072]       2,362,368
        QuickGELU-35              [-1, 2, 3072]               0
           Linear-36               [-1, 2, 768]       2,360,064
ResidualAttentionBlock-37               [-1, 2, 768]               0
        LayerNorm-38               [-1, 2, 768]           1,536
MultiheadAttention-39  [[-1, 2, 768], [-1, 50, 50]]               0
        LayerNorm-40               [-1, 2, 768]           1,536
           Linear-41              [-1, 2, 3072]       2,362,368
        QuickGELU-42              [-1, 2, 3072]               0
           Linear-43               [-1, 2, 768]       2,360,064
ResidualAttentionBlock-44               [-1, 2, 768]               0
        LayerNorm-45               [-1, 2, 768]           1,536
MultiheadAttention-46  [[-1, 2, 768], [-1, 50, 50]]               0
        LayerNorm-47               [-1, 2, 768]           1,536
           Linear-48              [-1, 2, 3072]       2,362,368
        QuickGELU-49              [-1, 2, 3072]               0
           Linear-50               [-1, 2, 768]       2,360,064
ResidualAttentionBlock-51               [-1, 2, 768]               0
        LayerNorm-52               [-1, 2, 768]           1,536
MultiheadAttention-53  [[-1, 2, 768], [-1, 50, 50]]               0
        LayerNorm-54               [-1, 2, 768]           1,536
           Linear-55              [-1, 2, 3072]       2,362,368
        QuickGELU-56              [-1, 2, 3072]               0
           Linear-57               [-1, 2, 768]       2,360,064
ResidualAttentionBlock-58               [-1, 2, 768]               0
        LayerNorm-59               [-1, 2, 768]           1,536
MultiheadAttention-60  [[-1, 2, 768], [-1, 50, 50]]               0
        LayerNorm-61               [-1, 2, 768]           1,536
           Linear-62              [-1, 2, 3072]       2,362,368
        QuickGELU-63              [-1, 2, 3072]               0
           Linear-64               [-1, 2, 768]       2,360,064
ResidualAttentionBlock-65               [-1, 2, 768]               0
        LayerNorm-66               [-1, 2, 768]           1,536
MultiheadAttention-67  [[-1, 2, 768], [-1, 50, 50]]               0
        LayerNorm-68               [-1, 2, 768]           1,536
           Linear-69              [-1, 2, 3072]       2,362,368
        QuickGELU-70              [-1, 2, 3072]               0
           Linear-71               [-1, 2, 768]       2,360,064
ResidualAttentionBlock-72               [-1, 2, 768]               0
        LayerNorm-73               [-1, 2, 768]           1,536
MultiheadAttention-74  [[-1, 2, 768], [-1, 50, 50]]               0
        LayerNorm-75               [-1, 2, 768]           1,536
           Linear-76              [-1, 2, 3072]       2,362,368
        QuickGELU-77              [-1, 2, 3072]               0
           Linear-78               [-1, 2, 768]       2,360,064
ResidualAttentionBlock-79               [-1, 2, 768]               0
        LayerNorm-80               [-1, 2, 768]           1,536
MultiheadAttention-81  [[-1, 2, 768], [-1, 50, 50]]               0
        LayerNorm-82               [-1, 2, 768]           1,536
           Linear-83              [-1, 2, 3072]       2,362,368
        QuickGELU-84              [-1, 2, 3072]               0
           Linear-85               [-1, 2, 768]       2,360,064
ResidualAttentionBlock-86               [-1, 2, 768]               0
      Transformer-87               [-1, 2, 768]               0
        LayerNorm-88                  [-1, 768]           1,536
================================================================

RN50模型image_encoder模块参数量:

Total params: 23,527,264
Trainable params: 23,527,264
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 331.71
Params size (MB): 89.75
Estimated Total Size (MB): 422.04
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 32, 112, 112]             864
       BatchNorm2d-2         [-1, 32, 112, 112]              64
              ReLU-3         [-1, 32, 112, 112]               0
            Conv2d-4         [-1, 32, 112, 112]           9,216
       BatchNorm2d-5         [-1, 32, 112, 112]              64
              ReLU-6         [-1, 32, 112, 112]               0
            Conv2d-7         [-1, 64, 112, 112]          18,432
       BatchNorm2d-8         [-1, 64, 112, 112]             128
              ReLU-9         [-1, 64, 112, 112]               0
        AvgPool2d-10           [-1, 64, 56, 56]               0
           Conv2d-11           [-1, 64, 56, 56]           4,096
      BatchNorm2d-12           [-1, 64, 56, 56]             128
             ReLU-13           [-1, 64, 56, 56]               0
           Conv2d-14           [-1, 64, 56, 56]          36,864
      BatchNorm2d-15           [-1, 64, 56, 56]             128
             ReLU-16           [-1, 64, 56, 56]               0
         Identity-17           [-1, 64, 56, 56]               0
           Conv2d-18          [-1, 256, 56, 56]          16,384
      BatchNorm2d-19          [-1, 256, 56, 56]             512
        AvgPool2d-20           [-1, 64, 56, 56]               0
           Conv2d-21          [-1, 256, 56, 56]          16,384
      BatchNorm2d-22          [-1, 256, 56, 56]             512
             ReLU-23          [-1, 256, 56, 56]               0
       Bottleneck-24          [-1, 256, 56, 56]               0
           Conv2d-25           [-1, 64, 56, 56]          16,384
      BatchNorm2d-26           [-1, 64, 56, 56]             128
             ReLU-27           [-1, 64, 56, 56]               0
           Conv2d-28           [-1, 64, 56, 56]          36,864
      BatchNorm2d-29           [-1, 64, 56, 56]             128
             ReLU-30           [-1, 64, 56, 56]               0
         Identity-31           [-1, 64, 56, 56]               0
           Conv2d-32          [-1, 256, 56, 56]          16,384
      BatchNorm2d-33          [-1, 256, 56, 56]             512
             ReLU-34          [-1, 256, 56, 56]               0
       Bottleneck-35          [-1, 256, 56, 56]               0
           Conv2d-36           [-1, 64, 56, 56]          16,384
      BatchNorm2d-37           [-1, 64, 56, 56]             128
             ReLU-38           [-1, 64, 56, 56]               0
           Conv2d-39           [-1, 64, 56, 56]          36,864
      BatchNorm2d-40           [-1, 64, 56, 56]             128
             ReLU-41           [-1, 64, 56, 56]               0
         Identity-42           [-1, 64, 56, 56]               0
           Conv2d-43          [-1, 256, 56, 56]          16,384
      BatchNorm2d-44          [-1, 256, 56, 56]             512
             ReLU-45          [-1, 256, 56, 56]               0
       Bottleneck-46          [-1, 256, 56, 56]               0
           Conv2d-47          [-1, 128, 56, 56]          32,768
      BatchNorm2d-48          [-1, 128, 56, 56]             256
             ReLU-49          [-1, 128, 56, 56]               0
           Conv2d-50          [-1, 128, 56, 56]         147,456
      BatchNorm2d-51          [-1, 128, 56, 56]             256
             ReLU-52          [-1, 128, 56, 56]               0
        AvgPool2d-53          [-1, 128, 28, 28]               0
           Conv2d-54          [-1, 512, 28, 28]          65,536
      BatchNorm2d-55          [-1, 512, 28, 28]           1,024
        AvgPool2d-56          [-1, 256, 28, 28]               0
           Conv2d-57          [-1, 512, 28, 28]         131,072
      BatchNorm2d-58          [-1, 512, 28, 28]           1,024
             ReLU-59          [-1, 512, 28, 28]               0
       Bottleneck-60          [-1, 512, 28, 28]               0
           Conv2d-61          [-1, 128, 28, 28]          65,536
      BatchNorm2d-62          [-1, 128, 28, 28]             256
             ReLU-63          [-1, 128, 28, 28]               0
           Conv2d-64          [-1, 128, 28, 28]         147,456
      BatchNorm2d-65          [-1, 128, 28, 28]             256
             ReLU-66          [-1, 128, 28, 28]               0
         Identity-67          [-1, 128, 28, 28]               0
           Conv2d-68          [-1, 512, 28, 28]          65,536
      BatchNorm2d-69          [-1, 512, 28, 28]           1,024
             ReLU-70          [-1, 512, 28, 28]               0
       Bottleneck-71          [-1, 512, 28, 28]               0
           Conv2d-72          [-1, 128, 28, 28]          65,536
      BatchNorm2d-73          [-1, 128, 28, 28]             256
             ReLU-74          [-1, 128, 28, 28]               0
           Conv2d-75          [-1, 128, 28, 28]         147,456
      BatchNorm2d-76          [-1, 128, 28, 28]             256
             ReLU-77          [-1, 128, 28, 28]               0
         Identity-78          [-1, 128, 28, 28]               0
           Conv2d-79          [-1, 512, 28, 28]          65,536
      BatchNorm2d-80          [-1, 512, 28, 28]           1,024
             ReLU-81          [-1, 512, 28, 28]               0
       Bottleneck-82          [-1, 512, 28, 28]               0
           Conv2d-83          [-1, 128, 28, 28]          65,536
      BatchNorm2d-84          [-1, 128, 28, 28]             256
             ReLU-85          [-1, 128, 28, 28]               0
           Conv2d-86          [-1, 128, 28, 28]         147,456
      BatchNorm2d-87          [-1, 128, 28, 28]             256
             ReLU-88          [-1, 128, 28, 28]               0
         Identity-89          [-1, 128, 28, 28]               0
           Conv2d-90          [-1, 512, 28, 28]          65,536
      BatchNorm2d-91          [-1, 512, 28, 28]           1,024
             ReLU-92          [-1, 512, 28, 28]               0
       Bottleneck-93          [-1, 512, 28, 28]               0
           Conv2d-94          [-1, 256, 28, 28]         131,072
      BatchNorm2d-95          [-1, 256, 28, 28]             512
             ReLU-96          [-1, 256, 28, 28]               0
           Conv2d-97          [-1, 256, 28, 28]         589,824
      BatchNorm2d-98          [-1, 256, 28, 28]             512
             ReLU-99          [-1, 256, 28, 28]               0
       AvgPool2d-100          [-1, 256, 14, 14]               0
          Conv2d-101         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-102         [-1, 1024, 14, 14]           2,048
       AvgPool2d-103          [-1, 512, 14, 14]               0
          Conv2d-104         [-1, 1024, 14, 14]         524,288
     BatchNorm2d-105         [-1, 1024, 14, 14]           2,048
            ReLU-106         [-1, 1024, 14, 14]               0
      Bottleneck-107         [-1, 1024, 14, 14]               0
          Conv2d-108          [-1, 256, 14, 14]         262,144
     BatchNorm2d-109          [-1, 256, 14, 14]             512
            ReLU-110          [-1, 256, 14, 14]               0
          Conv2d-111          [-1, 256, 14, 14]         589,824
     BatchNorm2d-112          [-1, 256, 14, 14]             512
            ReLU-113          [-1, 256, 14, 14]               0
        Identity-114          [-1, 256, 14, 14]               0
          Conv2d-115         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-116         [-1, 1024, 14, 14]           2,048
            ReLU-117         [-1, 1024, 14, 14]               0
      Bottleneck-118         [-1, 1024, 14, 14]               0
          Conv2d-119          [-1, 256, 14, 14]         262,144
     BatchNorm2d-120          [-1, 256, 14, 14]             512
            ReLU-121          [-1, 256, 14, 14]               0
          Conv2d-122          [-1, 256, 14, 14]         589,824
     BatchNorm2d-123          [-1, 256, 14, 14]             512
            ReLU-124          [-1, 256, 14, 14]               0
        Identity-125          [-1, 256, 14, 14]               0
          Conv2d-126         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-127         [-1, 1024, 14, 14]           2,048
            ReLU-128         [-1, 1024, 14, 14]               0
      Bottleneck-129         [-1, 1024, 14, 14]               0
          Conv2d-130          [-1, 256, 14, 14]         262,144
     BatchNorm2d-131          [-1, 256, 14, 14]             512
            ReLU-132          [-1, 256, 14, 14]               0
          Conv2d-133          [-1, 256, 14, 14]         589,824
     BatchNorm2d-134          [-1, 256, 14, 14]             512
            ReLU-135          [-1, 256, 14, 14]               0
        Identity-136          [-1, 256, 14, 14]               0
          Conv2d-137         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-138         [-1, 1024, 14, 14]           2,048
            ReLU-139         [-1, 1024, 14, 14]               0
      Bottleneck-140         [-1, 1024, 14, 14]               0
          Conv2d-141          [-1, 256, 14, 14]         262,144
     BatchNorm2d-142          [-1, 256, 14, 14]             512
            ReLU-143          [-1, 256, 14, 14]               0
          Conv2d-144          [-1, 256, 14, 14]         589,824
     BatchNorm2d-145          [-1, 256, 14, 14]             512
            ReLU-146          [-1, 256, 14, 14]               0
        Identity-147          [-1, 256, 14, 14]               0
          Conv2d-148         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-149         [-1, 1024, 14, 14]           2,048
            ReLU-150         [-1, 1024, 14, 14]               0
      Bottleneck-151         [-1, 1024, 14, 14]               0
          Conv2d-152          [-1, 256, 14, 14]         262,144
     BatchNorm2d-153          [-1, 256, 14, 14]             512
            ReLU-154          [-1, 256, 14, 14]               0
          Conv2d-155          [-1, 256, 14, 14]         589,824
     BatchNorm2d-156          [-1, 256, 14, 14]             512
            ReLU-157          [-1, 256, 14, 14]               0
        Identity-158          [-1, 256, 14, 14]               0
          Conv2d-159         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-160         [-1, 1024, 14, 14]           2,048
            ReLU-161         [-1, 1024, 14, 14]               0
      Bottleneck-162         [-1, 1024, 14, 14]               0
          Conv2d-163          [-1, 512, 14, 14]         524,288
     BatchNorm2d-164          [-1, 512, 14, 14]           1,024
            ReLU-165          [-1, 512, 14, 14]               0
          Conv2d-166          [-1, 512, 14, 14]       2,359,296
     BatchNorm2d-167          [-1, 512, 14, 14]           1,024
            ReLU-168          [-1, 512, 14, 14]               0
       AvgPool2d-169            [-1, 512, 7, 7]               0
          Conv2d-170           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-171           [-1, 2048, 7, 7]           4,096
       AvgPool2d-172           [-1, 1024, 7, 7]               0
          Conv2d-173           [-1, 2048, 7, 7]       2,097,152
     BatchNorm2d-174           [-1, 2048, 7, 7]           4,096
            ReLU-175           [-1, 2048, 7, 7]               0
      Bottleneck-176           [-1, 2048, 7, 7]               0
          Conv2d-177            [-1, 512, 7, 7]       1,048,576
     BatchNorm2d-178            [-1, 512, 7, 7]           1,024
            ReLU-179            [-1, 512, 7, 7]               0
          Conv2d-180            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-181            [-1, 512, 7, 7]           1,024
            ReLU-182            [-1, 512, 7, 7]               0
        Identity-183            [-1, 512, 7, 7]               0
          Conv2d-184           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-185           [-1, 2048, 7, 7]           4,096
            ReLU-186           [-1, 2048, 7, 7]               0
      Bottleneck-187           [-1, 2048, 7, 7]               0
          Conv2d-188            [-1, 512, 7, 7]       1,048,576
     BatchNorm2d-189            [-1, 512, 7, 7]           1,024
            ReLU-190            [-1, 512, 7, 7]               0
          Conv2d-191            [-1, 512, 7, 7]       2,359,296
     BatchNorm2d-192            [-1, 512, 7, 7]           1,024
            ReLU-193            [-1, 512, 7, 7]               0
        Identity-194            [-1, 512, 7, 7]               0
          Conv2d-195           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-196           [-1, 2048, 7, 7]           4,096
            ReLU-197           [-1, 2048, 7, 7]               0
      Bottleneck-198           [-1, 2048, 7, 7]               0
 AttentionPool2d-199                 [-1, 1024]               0
================================================================

text_encoder网络结构打印:

positional_embedding torch.Size([77, 512])
text_projection torch.Size([512, 1024])
logit_scale torch.Size([])
token_embedding.weight torch.Size([49408, 512])
ln_final.weight torch.Size([512])
ln_final.bias torch.Size([512])
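
这些形状可以通过遍历模型中不带 visual 前缀的参数得到(示意代码;上面的打印省略了 transformer.resblocks 各层):

import clip

model, _ = clip.load("RN50", device="cpu")  # text_projection 为 [512, 1024],对应 RN50
for name, param in model.named_parameters():
    if not name.startswith("visual."):
        print(name, param.shape)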

CLIP支持模型

RN50、RN101、RN50x4、RN50x16、RN50x64、ViT-B/32、ViT-B/16、ViT-L/14、ViT-L/14@336px
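
这一列表也可以在代码中直接查询:

import clip

print(clip.available_models())
# ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64',
#  'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']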

各模型性能对比可参见论文中的实验图表(原图略)。

CLIP模型下载地址

| 模型 | 下载地址 |
| --- | --- |
| RN50 | https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt |
| RN101 | https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt |
| RN50x4 | https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt |
| RN50x16 | https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt |
| RN50x64 | https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt |
| ViT-B/32 | https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt |
| ViT-B/16 | https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt |
| ViT-L/14 | https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt |
| ViT-L/14@336px | https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt |

各类模型参数对比

下表中 mparams 为参数量(单位:百万),gflops 为单次前向的计算量,image_*、text_* 分别对应图像塔与文本塔:

| model | image_size | image_width | text_width | embed_dim | mparams | image_mparams | text_mparams | gflops | image_gflops | text_gflops |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| ViT-S-32-alt | 224 | 384 | 256 | 256 | 43.22 | 22.59 | 20.63 | 3.56 | 2.29 | 1.27 |
| ViT-S-32 | 224 | 384 | 384 | 384 | 63.09 | 22.64 | 40.44 | 5.66 | 2.29 | 3.38 |
| ViT-M-32-alt | 224 | 512 | 384 | 384 | 80.07 | 39.63 | 40.44 | 7.37 | 3.99 | 3.38 |
| ViT-M-32 | 224 | 512 | 512 | 512 | 103.12 | 39.69 | 63.43 | 9.95 | 3.99 | 5.96 |
| ViT-S-16-alt | 224 | 384 | 256 | 256 | 42.4 | 21.76 | 20.63 | 10.47 | 9.2 | 1.27 |
| ViT-S-16 | 224 | 384 | 384 | 384 | 62.26 | 21.81 | 40.44 | 12.58 | 9.2 | 3.38 |
| ViT-B-32 | 224 | 768 | 512 | 512 | 151.28 | 87.85 | 63.43 | 14.78 | 8.82 | 5.96 |
| ViT-B-32-quickgelu | 224 | 768 | 512 | 512 | 151.28 | 87.85 | 63.43 | 14.78 | 8.82 | 5.96 |
| convnext_tiny | 224 | 768 | 512 | 1024 | 92.3 | 28.61 | 63.69 | 14.87 | 8.91 | 5.96 |
| ViT-B-32-256 | 256 | 768 | 512 | 512 | 151.29 | 87.86 | 63.43 | 17.46 | 11.5 | 5.96 |
| RN50 | 224 | 64 | 512 | 1024 | 102.01 | 38.32 | 63.69 | 18.18 | 12.22 | 5.96 |
| RN50-quickgelu | 224 | 64 | 512 | 1024 | 102.01 | 38.32 | 63.69 | 18.18 | 12.22 | 5.96 |
| ViT-M-16-alt | 224 | 512 | 384 | 384 | 78.98 | 38.53 | 40.44 | 19.36 | 15.98 | 3.38 |
| ViT-M-16 | 224 | 512 | 512 | 512 | 102.02 | 38.59 | 63.43 | 21.94 | 15.98 | 5.96 |
| vit_relpos_medium_patch16_cls_224 | 224 | 768 | 512 | 512 | 101.94 | 38.51 | 63.43 | 21.99 | 16.03 | 5.96 |
| mt5-base-ViT-B-32 | 224 | 768 | 512 | 512 | 365.71 | 87.85 | 277.86 | 22.12 | 8.82 | 13.3 |
| convnext_small | 224 | 768 | 512 | 512 | 113.28 | 49.85 | 63.43 | 23.33 | 17.37 | 5.96 |
| ViT-B-32-plus-256 | 256 | 896 | 640 | 640 | 210.3 | 119.13 | 91.16 | 24.83 | 15.56 | 9.27 |
| RN101 | 224 | 64 | 512 | 512 | 119.69 | 56.26 | 63.43 | 25.5 | 19.54 | 5.96 |
| RN101-quickgelu | 224 | 64 | 512 | 512 | 119.69 | 56.26 | 63.43 | 25.5 | 19.54 | 5.96 |
| vit_medium_patch16_gap_256 | 256 | 768 | 512 | 512 | 102.04 | 38.61 | 63.43 | 27.1 | 21.14 | 5.96 |
| coca_ViT-B-32 | 224 | 768 | 512 | 512 | 253.56 | 89.16 | 63.43 | 33.34 | 9.19 | 5.96 |
| convnext_base | 224 | 768 | 512 | 512 | 151.52 | 88.09 | 63.43 | 36.67 | 30.71 | 5.96 |
| swin_base_patch4_window7_224 | 224 | 768 | 640 | 640 | 178.56 | 87.4 | 91.16 | 40.13 | 30.86 | 9.27 |
| ViT-B-16 | 224 | 768 | 512 | 512 | 149.62 | 86.19 | 63.43 | 41.09 | 35.13 | 5.96 |
| ViT-B-16-quickgelu | 224 | 768 | 512 | 512 | 149.62 | 86.19 | 63.43 | 41.09 | 35.13 | 5.96 |
| EVA02-B-16 | 224 | 768 | 512 | 512 | 149.69 | 86.26 | 63.43 | 41.09 | 35.13 | 5.96 |
| ViT-B-16-SigLIP | 224 | 768 | 768 | 768 | 203.16 | 92.88 | 110.27 | 46.44 | 35.42 | 11.02 |
| convnext_base_w | 256 | 768 | 640 | 640 | 179.39 | 88.22 | 91.16 | 49.38 | 40.11 | 9.27 |
| RN50x4 | 288 | 80 | 640 | 640 | 178.3 | 87.14 | 91.16 | 51.82 | 42.56 | 9.27 |
| coca_roberta-ViT-B-32 | 224 | 768 | 768 | 512 | 420.37 | 87.85 | 124.45 | 53.12 | 8.82 | 13.12 |
| ViT-B-16-plus | 224 | 896 | 640 | 640 | 208.35 | 117.19 | 91.16 | 56.75 | 47.49 | 9.27 |
| ViT-B-16-SigLIP-256 | 256 | 768 | 768 | 768 | 203.2 | 92.93 | 110.27 | 57.84 | 46.82 | 11.02 |
| ViT-B-16-SigLIP-i18n-256 | 256 | 768 | 768 | 768 | 370.63 | 92.93 | 277.7 | 57.84 | 46.82 | 11.02 |
| ViT-B-16-plus-240 | 240 | 896 | 640 | 640 | 208.38 | 117.21 | 91.16 | 64.03 | 54.76 | 9.27 |
| convnext_base_w_320 | 320 | 768 | 640 | 640 | 179.39 | 88.22 | 91.16 | 71.94 | 62.67 | 9.27 |
| convnext_large | 224 | 768 | 768 | 768 | 321.06 | 197.41 | 123.65 | 82.02 | 68.72 | 13.3 |
| coca_base | 288 | 768 | 768 | 512 | 440.34 | 86.4 | 134.66 | 99.09 | 46.47 | 13.3 |
| roberta-ViT-B-32 | 224 | 768 | 512 | 512 | 212.72 | 87.85 | 124.87 | 105.87 | 8.82 | 97.05 |
| xlm-roberta-base-ViT-B-32 | 224 | 768 | 512 | 512 | 366.12 | 87.85 | 278.27 | 105.87 | 8.82 | 97.05 |
| convnext_large_d | 256 | 768 | 768 | 768 | 351.77 | 199.77 | 152.0 | 107.5 | 89.76 | 17.73 |
| ViT-B-16-SigLIP-384 | 384 | 768 | 768 | 768 | 203.45 | 93.18 | 110.27 | 123.15 | 112.13 | 11.02 |
| ViT-L-16 | 224 | 1024 | 768 | 768 | 427.74 | 304.09 | 123.65 | 136.41 | 123.11 | 13.3 |
| convnext_large_d_320 | 320 | 768 | 768 | 768 | 351.77 | 199.77 | 152.0 | 157.98 | 140.25 | 17.73 |
| RN50x16 | 384 | 96 | 768 | 768 | 290.98 | 167.33 | 123.65 | 162.69 | 149.39 | 13.3 |
| ViT-L-14-CLIPA | 224 | 1024 | 768 | 768 | 414.21 | 303.96 | 110.25 | 167.5 | 162.03 | 5.47 |
| EVA02-L-14 | 224 | 768 | 768 | 768 | 427.76 | 304.11 | 123.65 | 175.3 | 162.0 | 13.3 |
| ViT-L-14 | 224 | 1024 | 768 | 768 | 427.62 | 303.97 | 123.65 | 175.33 | 162.03 | 13.3 |
| ViT-L-14-quickgelu | 224 | 1024 | 768 | 768 | 427.62 | 303.97 | 123.65 | 175.33 | 162.03 | 13.3 |
| convnext_xlarge | 256 | 768 | 1024 | 1024 | 653.89 | 350.25 | 303.65 | 198.38 | 159.14 | 39.24 |
| ViT-L-16-SigLIP-256 | 256 | 768 | 1024 | 1024 | 652.15 | 315.96 | 336.19 | 201.62 | 162.56 | 39.06 |
| coca_ViT-L-14 | 224 | 1024 | 768 | 768 | 638.45 | 306.72 | 123.65 | 214.52 | 163.64 | 13.3 |
| ViT-B-16-SigLIP-512 | 512 | 768 | 768 | 768 | 203.79 | 93.52 | 110.27 | 227.26 | 216.24 | 11.02 |
| ViT-SO400M-14-SigLIP | 224 | 768 | 1152 | 1152 | 877.36 | 427.68 | 449.68 | 233.54 | 220.35 | 13.19 |
| ViT-L-14-280 | 280 | 1024 | 768 | 768 | 427.76 | 304.11 | 123.65 | 271.79 | 258.49 | 13.3 |
| ViT-L-16-320 | 320 | 1024 | 768 | 768 | 427.95 | 304.3 | 123.65 | 271.93 | 258.63 | 13.3 |
| ViT-H-16 | 224 | 1280 | 1024 | 1024 | 986.26 | 632.23 | 354.03 | 301.72 | 254.63 | 47.09 |
| ViT-H-14-CLIPA | 224 | 1280 | 1024 | 1024 | 968.24 | 632.07 | 336.16 | 354.02 | 334.59 | 19.43 |
| nllb-clip-base | 224 | 768 | 512 | 512 | 501.89 | 87.85 | 414.04 | 369.6 | 8.82 | 360.78 |
| ViT-H-14 | 224 | 1280 | 1024 | 1024 | 986.11 | 632.08 | 354.03 | 381.68 | 334.59 | 47.09 |
| ViT-H-14-quickgelu | 224 | 1280 | 1024 | 1024 | 986.11 | 632.08 | 354.03 | 381.68 | 334.59 | 47.09 |
| ViT-L-14-CLIPA-336 | 336 | 1024 | 768 | 768 | 414.54 | 304.29 | 110.25 | 387.39 | 381.92 | 5.47 |
| EVA02-L-14-336 | 336 | 768 | 768 | 768 | 428.08 | 304.43 | 123.65 | 395.16 | 381.86 | 13.3 |
| ViT-L-14-336 | 336 | 1024 | 768 | 768 | 427.94 | 304.29 | 123.65 | 395.22 | 381.92 | 13.3 |
| ViT-L-16-SigLIP-384 | 384 | 768 | 1024 | 1024 | 652.48 | 316.28 | 336.19 | 422.91 | 383.85 | 39.06 |
| convnext_xxlarge | 256 | 768 | 1024 | 1024 | 1200.58 | 846.54 | 354.03 | 443.03 | 395.94 | 47.09 |
| nllb-clip-base-siglip | 384 | 768 | 512 | 768 | 507.47 | 93.18 | 414.3 | 472.91 | 112.13 | 360.78 |
| mt5-xl-ViT-H-14 | 224 | 1280 | 512 | 1024 | 2306.75 | 632.08 | 1674.68 | 514.04 | 334.59 | 179.45 |
| EVA01-g-14 | 224 | 768 | 768 | 1024 | 1136.44 | 1012.59 | 123.85 | 547.36 | 534.06 | 13.3 |
| RN50x64 | 448 | 128 | 1024 | 1024 | 623.26 | 420.38 | 202.88 | 552.65 | 529.11 | 23.55 |
| EVA01-g-14-plus | 224 | 768 | 1024 | 1024 | 1366.62 | 1012.59 | 354.03 | 581.15 | 534.06 | 47.09 |
| ViT-g-14 | 224 | 1408 | 1024 | 1024 | 1366.68 | 1012.65 | 354.03 | 581.15 | 534.06 | 47.09 |
| convnext_xxlarge_320 | 320 | 768 | 1024 | 1024 | 1200.58 | 846.54 | 354.03 | 665.74 | 618.65 | 47.09 |
| xlm-roberta-large-ViT-H-14 | 224 | 1280 | 512 | 1024 | 1193.01 | 632.08 | 560.94 | 671.01 | 334.59 | 336.42 |
| ViT-SO400M-14-SigLIP-384 | 384 | 768 | 1152 | 1152 | 877.96 | 428.23 | 449.73 | 723.48 | 670.35 | 53.13 |
| ViT-H-14-CLIPA-336 | 336 | 1280 | 1024 | 1024 | 968.64 | 632.48 | 336.16 | 800.88 | 781.45 | 19.43 |
| ViT-bigG-14-CLIPA | 224 | 1664 | 1280 | 1280 | 2517.22 | 1844.9 | 672.32 | 1007.93 | 967.5 | 40.44 |
| ViT-H-14-378-quickgelu | 378 | 1280 | 1024 | 1024 | 986.71 | 632.68 | 354.03 | 1054.05 | 1006.96 | 47.09 |
| ViT-bigG-14 | 224 | 1664 | 1280 | 1280 | 2539.57 | 1844.91 | 694.66 | 1065.36 | 967.5 | 97.86 |
| nllb-clip-large | 224 | 1280 | 512 | 1024 | 1399.22 | 632.08 | 767.14 | 1468.46 | 334.59 | 1133.87 |
| nllb-clip-large-siglip | 384 | 768 | 512 | 1152 | 1195.5 | 428.23 | 767.27 | 1804.22 | 670.35 | 1133.87 |
| ViT-e-14 | 224 | 1792 | 1280 | 1280 | 4581.09 | 3807.72 | 773.37 | 2091.45 | 1981.35 | 110.1 |
| ViT-bigG-14-CLIPA-336 | 336 | 1664 | 1280 | 1280 | 2517.76 | 1845.44 | 672.32 | 2271.58 | 2231.15 | 40.44 |
| EVA02-E-14 | 224 | 768 | 1024 | 1024 | 4704.59 | 4350.56 | 354.03 | 2311.42 | 2264.33 | 47.09 |
| EVA02-E-14-plus | 224 | 768 | 1280 | 1024 | 5044.89 | 4350.56 | 694.33 | 2362.19 | 2264.33 | 97.86 |