[Prompt-Learning Code] A Close Reading of the CoOp Code

Dassl

  • A PyTorch-based toolbox
  • Why is it called "Dassl"?
  • Dassl combines the initials of Domain Adaptation (DA) and Semi-Supervised Learning (SSL)

A Close Reading of the CoOp Code

CoOp is an improvement on CLIP: it learns the prompt (context tokens) so that prompts no longer have to be designed by hand.

1 load_clip_to_cpu

Loads the CLIP model.

def load_clip_to_cpu(cfg):
    backbone_name = cfg.MODEL.BACKBONE.NAME
    url = clip._MODELS[backbone_name]
    model_path = clip._download(url)

    try:
        # loading a JIT archive
        model = torch.jit.load(model_path, map_location="cpu").eval()
        state_dict = None

    except RuntimeError:
        # fall back to a plain state_dict checkpoint
        state_dict = torch.load(model_path, map_location="cpu")

    model = clip.build_model(state_dict or model.state_dict())

    return model

1.1

 model = torch.jit.load(model_path, map_location="cpu").eval()
  • torch.jit.load() loads the (TorchScript) model
  • .eval() switches it to evaluation mode, disabling training-time behavior such as dropout and batch-norm statistics updates

1.2

map_location = "cpu"模型加载在cpu上
map_location = "cuda" if torch.cuda.is_available() else "cpu"模型加载在gpu上
  • map_location = "cpu"模型加载在cpu上
  • map_location = "cuda" if torch.cuda.is_available() else "cpu"模型加载在gpu上

2 TextEncoder

The text encoder: it takes the prompt embeddings as input and outputs the corresponding text representations.

class TextEncoder(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        x = prompts + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
        return x

2.1

def forward(self, prompts, tokenized_prompts):
  • Defines the forward pass of the module
  • The module takes two inputs: prompts (the prompt embeddings) and tokenized_prompts (the corresponding token ids)

2.2

x = x.permute(1, 0, 2)
  • .permute() reorders the dimensions of a tensor
  • The text input has shape NLD (batch size N, sequence length L, embedding dimension D)
  • After permuting, the shape is LND (sequence length L, batch size N, embedding dimension D)

Suppose the original tensor has shape (2, 4, 3):

[
  [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]],
  [[13, 14, 15], [16, 17, 18], [19, 20, 21], [22, 23, 24]]
]

After permute it becomes (4, 2, 3):

[
  [[1, 2, 3], [13, 14, 15]],
  [[4, 5, 6], [16, 17, 18]],
  [[7, 8, 9], [19, 20, 21]],
  [[10, 11, 12], [22, 23, 24]]
]
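
A quick runnable check of this (2, 4, 3) example:

import torch

x = torch.arange(1, 25).reshape(2, 4, 3)  # (N=2, L=4, D=3)
y = x.permute(1, 0, 2)                    # (L=4, N=2, D=3)
print(y.shape)   # torch.Size([4, 2, 3])
print(y[0])      # rows [1, 2, 3] and [13, 14, 15], as in the example above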

2.3

 x = self.transformer(x)
  • The sequence is fed into the self-attention Transformer for encoding

2.4

x = x.permute(1, 0, 2)  
  • LND -> NLD: permutes back to the original dimension order of the embeddings

2.5

x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
  • x.shape[0] is the first dimension of x (NLD), i.e. the batch size N
  • torch.arange(x.shape[0]) generates the consecutive integers 0 to N-1, one row index per sample
  • torch.argmax() returns the index of the maximum value in a tensor
  • tokenized_prompts.argmax(dim=-1) finds, for each prompt, the position of the largest token id; in CLIP's vocabulary the EOT (end-of-text) token has the largest id, so this locates each prompt's EOT position
  • x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] selects, for each sample, the encoder output at its EOT position, which serves as the sentence-level representation
  • @ is the matrix-multiplication operator
  • self.text_projection is the weight matrix that projects the pooled representation into the joint embedding space
  • @ self.text_projection: multiplies the selected representations by this matrix to obtain the final text features (a toy sketch follows below)
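
A toy sketch of this EOT-pooling step, with made-up shapes (N=2, L=5, D=4); 49406/49407 are CLIP's SOT/EOT token ids, the other ids are illustrative:

import torch

x = torch.randn(2, 5, 4)                              # transformer output, NLD
tokenized = torch.tensor([[49406, 320, 1125, 49407, 0],
                          [49406, 539, 49407, 0, 0]])  # EOT (49407) has the largest id
eot_pos = tokenized.argmax(dim=-1)                    # tensor([3, 2])
pooled = x[torch.arange(x.shape[0]), eot_pos]         # (2, 4): one vector per prompt
text_projection = torch.randn(4, 3)                   # illustrative projection, D=4 -> 3
text_features = pooled @ text_projection              # (2, 3)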

3 PromptLearner

The prompt learner, which learns the context vectors used to build the prompts.

class PromptLearner(nn.Module):
    def __init__(self, cfg, classnames, clip_model):
        super().__init__()
        n_cls = len(classnames)  
        n_ctx = cfg.TRAINER.COOP.N_CTX  
        ctx_init = cfg.TRAINER.COOP.CTX_INIT 
        dtype = clip_model.dtype  # data type used by clip_model
        ctx_dim = clip_model.ln_final.weight.shape[0]  
        clip_imsize = clip_model.visual.input_resolution 
        cfg_imsize = cfg.INPUT.SIZE[0]  # input image size from the config
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"  
       
        if ctx_init:
            ctx_init = ctx_init.replace("_", " ")  
            n_ctx = len(ctx_init.split(" "))  
            prompt = clip.tokenize(ctx_init)

            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)  
           
            ctx_vectors = embedding[0, 1: 1 + n_ctx, :]
            prompt_prefix = ctx_init

        else:
            if cfg.TRAINER.COOP.CSC: 
                print("Initializing class-specific contexts")
                ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype) 
            else:
                print("Initializing a generic context")  
                ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)  

            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")
        self.ctx = nn.Parameter(ctx_vectors)
        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]
        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])

        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)

        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :])  # CLS, EOS

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts  # torch.Tensor
        self.name_lens = name_lens
        self.class_token_position = cfg.TRAINER.COOP.CLASS_TOKEN_POSITION

    def forward(self):
        ctx = self.ctx
        if ctx.dim() == 2:
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        prefix = self.token_prefix
        suffix = self.token_suffix

        if self.class_token_position == "end":
            prompts = torch.cat(
                [
                    prefix,  # (n_cls, 1, dim) embedding of the SOS token
                    ctx,  # (n_cls, n_ctx, dim) learned context vectors
                    suffix,  # (n_cls, *, dim) embeddings of the class name and EOS/padding tokens
                ],
                dim=1,
            )

        elif self.class_token_position == "middle": 
            half_n_ctx = self.n_ctx // 2 
            prompts = []  
            for i in range(self.n_cls):
                name_len = self.name_lens[i]  
                prefix_i = prefix[i: i + 1, :, :]  
                class_i = suffix[i: i + 1, :name_len, :]  
                suffix_i = suffix[i: i + 1, name_len:, :]  
                ctx_i_half1 = ctx[i: i + 1, :half_n_ctx, :]  
                ctx_i_half2 = ctx[i: i + 1, half_n_ctx:, :]  
                prompt = torch.cat(
                    [
                        prefix_i,  # (1, 1, dim)
                        ctx_i_half1,  # (1, n_ctx//2, dim)
                        class_i,  # (1, name_len, dim)
                        ctx_i_half2,  # (1, n_ctx//2, dim)
                        suffix_i,  # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        elif self.class_token_position == "front":
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i: i + 1, :, :]
                class_i = suffix[i: i + 1, :name_len, :]
                suffix_i = suffix[i: i + 1, name_len:, :]
                ctx_i = ctx[i: i + 1, :, :]
                prompt = torch.cat(
                    [
                        prefix_i,  # (1, 1, dim)
                        class_i,  # (1, name_len, dim)
                        ctx_i,  # (1, n_ctx, dim)
                        suffix_i,  # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        else:
            raise ValueError

        return prompts
  • n_cls: number of classes in the dataset
  • n_ctx: number of context tokens (prompt length)
  • ctx_init: initialization string for the context
  • dtype: data type of clip_model
  • ctx_dim: dimension of the context vectors
  • CTX_INIT: "a photo of a"

3.1

ctx_dim = clip_model.ln_final.weight.shape[0]
  • ln_final is the final layer normalization of the CLIP text transformer; it normalizes the transformer output
  • Its weight attribute is a 1-D tensor whose length equals the normalized feature dimension (the transformer width)
  • .shape returns the tensor's shape
  • shape[0] is the first (and only) entry of that shape, i.e. the feature width
  • This width is used as the context-vector dimension ctx_dim (see the LayerNorm sketch below)
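
For intuition, a plain LayerNorm shows the same pattern (512 is an illustrative width; the actual value depends on the CLIP backbone):

import torch.nn as nn

ln = nn.LayerNorm(512)
print(ln.weight.shape)     # torch.Size([512])
print(ln.weight.shape[0])  # 512 -> this is what becomes ctx_dim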

3.2

clip_imsize = clip_model.visual.input_resolution  # input resolution expected by the visual encoder
cfg_imsize = cfg.INPUT.SIZE[0]  # input image size from the config
assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"
  • SIZE[0]: only the height is needed, since the input is square
  • assert: ensures the image size given in the config matches the input resolution the visual encoder expects

3.3

ctx_init = ctx_init.replace("_", " ")  
n_ctx = len(ctx_init.split(" "))  # number of words after splitting on spaces (4 for "a photo of a")
prompt = clip.tokenize(ctx_init)
  • .replace(a, b): replaces a with b
  • .replace("_", " "): replaces underscores with spaces
  • clip.tokenize(): converts the initialization text into a sequence of token ids using CLIP's tokenizer

3.4

with torch.no_grad():
    embedding = clip_model.token_embedding(prompt).type(dtype)
ctx_vectors = embedding[0, 1: 1 + n_ctx, :]
  • with torch.no_grad(): gradient computation is disabled during this initialization step
  • clip_model.token_embedding(prompt): converts the token ids into embedding vectors
  • The slice [1: 1 + n_ctx] keeps the n_ctx token embeddings that follow the start token (see the shape sketch below)
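
A shape-only sketch of this slicing, assuming the usual 77-token CLIP context length and an illustrative ctx_dim of 512:

import torch

embedding = torch.randn(1, 77, 512)         # stand-in for token_embedding(prompt)
n_ctx = 4                                   # "a photo of a" has 4 words
ctx_vectors = embedding[0, 1:1 + n_ctx, :]  # skip the SOS token, keep the 4 context tokens
print(ctx_vectors.shape)                    # torch.Size([4, 512])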

3.5

nn.init.normal_(ctx_vectors, std=0.02)
prompt_prefix = " ".join(["X"] * n_ctx)
  • nn.init.normal_() fills the context vectors with values drawn from a normal distribution with standard deviation 0.02
  • ["X"] * n_ctx creates a list of n_ctx strings, each being "X"
  • " ".join() joins the list elements with spaces
  • The resulting prompt_prefix is "X X X X ..." (n_ctx placeholders)
  • The "X" characters merely act as placeholders for the learnable context tokens (sketch below)
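
The random-initialization branch can be sketched in isolation (4 and 512 are illustrative values for n_ctx and ctx_dim):

import torch
import torch.nn as nn

ctx_vectors = torch.empty(4, 512)        # generic (shared) context
nn.init.normal_(ctx_vectors, std=0.02)   # small random values
prompt_prefix = " ".join(["X"] * 4)
print(prompt_prefix)                     # X X X X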

3.6

self.ctx = nn.Parameter(ctx_vectors)
  • nn.Parameter() wraps the context vectors ctx_vectors as a learnable parameter

3.7

classnames = [name.replace("_", " ") for name in classnames]
name_lens = [len(_tokenizer.encode(name)) for name in classnames]
prompts = [prompt_prefix + " " + name + "." for name in classnames]
tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
with torch.no_grad():
    embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
  • prompt_prefix + " " + name + ".":将初始化提示与类名进行拼接,形成完整的提示文本
  • 例如:a photo of a car
  • torch.cat():将多个张量拼接成一个新张量
  • tokenized_prompts包含了所有提示文本对应的token序列的张量
  • clip_model.token_embedding():将token序列转化为嵌入向量(embedding)
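
A small illustration of the prompt assembly with two made-up class names:

prompt_prefix = "a photo of a"
classnames = ["sports_car", "golden_retriever"]          # hypothetical class names
classnames = [name.replace("_", " ") for name in classnames]
prompts = [prompt_prefix + " " + name + "." for name in classnames]
print(prompts)  # ['a photo of a sports car.', 'a photo of a golden retriever.']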

3.8

self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :])  # CLS, EOS
  • register_buffer() registers the prefix and suffix parts of the embedding as buffers
  • They are stored in the model's state_dict but are not optimized
  • token_prefix: the SOS (start-of-sequence) token embedding
  • token_suffix: the class-name and EOS (end-of-sequence) token embeddings

3.9

if self.class_token_position == "end":  # is the class token placed at the end of the prompt?
    prompts = torch.cat(
        [
            prefix,  # (n_cls, 1, dim) embedding of the SOS token
            ctx,  # (n_cls, n_ctx, dim) learned context vectors
            suffix,  # (n_cls, *, dim) embeddings of the class name and EOS/padding tokens
        ],
        dim=1,
    )
  • Depending on where the class token is placed ("end", "middle", or "front"), the pieces of the prompt are concatenated to form the final input embeddings
  • torch.cat(): concatenation along the sequence dimension (dim=1)

4 CustomCLIP

The custom CLIP model, built from the modules above.
It encodes the image and the text prompts and computes the similarity between them.

    def forward(self, image):
        image_features = self.image_encoder(image.type(self.dtype))  
        prompts = self.prompt_learner()
        tokenized_prompts = self.tokenized_prompts
        text_features = self.text_encoder(prompts, tokenized_prompts)  
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        logit_scale = self.logit_scale.exp()  # exponentiate the learned log-scale parameter
        logits = logit_scale * image_features @ text_features.t()

        return logits

4.1

prompts = self.prompt_learner()
  • Calls PromptLearner's forward method to generate the prompt embeddings

4.2

image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
  • The image and text features are normalized so that each vector has unit norm
  • text_features.norm(dim=-1, keepdim=True): computes the L2 norm (length) of each feature vector
  • keepdim=True: keeps the reduced dimension so the division broadcasts correctly (sketch below)
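
A minimal sketch of this normalization with random features of shape (N, D):

import torch

feats = torch.randn(2, 512)
feats = feats / feats.norm(dim=-1, keepdim=True)
print(feats.norm(dim=-1))  # tensor([1., 1.]) -- every row now has unit length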

4.3

logit_scale = self.logit_scale.exp()  # exponentiate the learned log-scale parameter
logits = logit_scale * image_features @ text_features.t()
  • Computes the cosine similarity between the image features and the text features
  • text_features.t(): transposes the text feature matrix
  • The @ operator performs matrix multiplication
  • logit_scale: a scaling factor (temperature) applied to the similarity scores
  • The scaled similarity scores (logits) are returned (see the toy sketch below)
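
A toy end-to-end sketch of the similarity computation, with random features for 2 images and 3 class prompts; 100.0 stands in for the exponentiated logit scale:

import torch

image_features = torch.randn(2, 512)
text_features = torch.randn(3, 512)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
logit_scale = 100.0  # illustrative value of self.logit_scale.exp()
logits = logit_scale * image_features @ text_features.t()
print(logits.shape)  # torch.Size([2, 3]): one score per (image, class) pair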

5 CoOp

@TRAINER_REGISTRY.register()
class CoOp(TrainerX):

    # check the options in the config file
    def check_cfg(self, cfg):
        assert cfg.TRAINER.COOP.PREC in ["fp16", "fp32", "amp"]
        # PREC must be one of these three values

    # build the CoOp model
    def build_model(self):
        cfg = self.cfg
        classnames = self.dm.dataset.classnames

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        clip_model = load_clip_to_cpu(cfg)

        if cfg.TRAINER.COOP.PREC == "fp32" or cfg.TRAINER.COOP.PREC == "amp":
            clip_model.float()

        print("Building custom CLIP")
        self.model = CustomCLIP(cfg, classnames, clip_model)
        print("Turning off gradients in both the image and the text encoder")
        for name, param in self.model.named_parameters():
            if "prompt_learner" not in name:
                param.requires_grad_(False)

        if cfg.MODEL.INIT_WEIGHTS:
            load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS)

        self.model.to(self.device) 
        self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
        self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
        self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched)
        self.scaler = GradScaler() if cfg.TRAINER.COOP.PREC == "amp" else None  
        device_count = torch.cuda.device_count()
        if device_count > 1:
            print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
            self.model = nn.DataParallel(self.model)

   
    def forward_backward(self, batch):
        image, label = self.parse_batch_train(batch)

        prec = self.cfg.TRAINER.COOP.PREC
        if prec == "amp":  

            with autocast():  
                output = self.model(image)
                loss = F.cross_entropy(output, label)

            self.optim.zero_grad()  # clear gradients accumulated from the previous step
            self.scaler.scale(loss).backward()  # backpropagate the scaled loss
            self.scaler.step(self.optim)  # update the model parameters via the scaler
            self.scaler.update()  # update the gradient scaler

        else:  
            output = self.model(image)
            loss = F.cross_entropy(output, label)
            self.model_backward_and_update(loss)  # backward pass and parameter update

        loss_summary = {
            "loss": loss.item(),
            "acc": compute_accuracy(output, label)[0].item(),
        }
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary


    def parse_batch_train(self, batch):
        input = batch["img"]
        label = batch["label"]
        input = input.to(self.device)
        label = label.to(self.device)
        return input, label

    # load a trained model
    def load_model(self, directory, epoch=None):
        if not directory:
            print("Note that load_model() is skipped as no pretrained model is given")
            return

        names = self.get_model_names()

        # load the best model by default
        model_file = "model-best.pth.tar"

        if epoch is not None:
            model_file = "model.pth.tar-" + str(epoch)

        for name in names:
            model_path = osp.join(directory, name, model_file)

            if not osp.exists(model_path):
                raise FileNotFoundError('Model not found at "{}"'.format(model_path))

            checkpoint = load_checkpoint(model_path)
            state_dict = checkpoint["state_dict"]
            epoch = checkpoint["epoch"]

            # Ignore fixed token vectors
            if "token_prefix" in state_dict:
                del state_dict["token_prefix"]

            if "token_suffix" in state_dict:
                del state_dict["token_suffix"]

            print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch))
            self._models[name].load_state_dict(state_dict, strict=False)

5.1

    def check_cfg(self, cfg):
        assert cfg.TRAINER.COOP.PREC in ["fp16", "fp32", "amp"]
  • PREC must be one of these three values
  • fp32, single precision: a float is stored in 32 bits (4 bytes)
  • fp16, half precision: a float is stored in 16 bits (2 bytes)
  • amp, automatic mixed precision: fp16 and fp32 are mixed during training, reducing memory usage while speeding up training

5.2

 for name, param in self.model.named_parameters():
            if "prompt_learner" not in name:
                param.requires_grad_(False)
  • Disables gradient computation for every parameter except those belonging to prompt_learner
  • .named_parameters(): yields (name, param) tuples
  • requires_grad_(False): turns off gradient tracking for the parameter
  • Frozen parameters receive no gradients during backpropagation and are not updated by the optimizer, so everything except the prompt learner stays fixed (see the toy check below)
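
A self-contained toy check of this freezing pattern (the module names here are placeholders, not the real CustomCLIP submodules):

import torch.nn as nn

model = nn.ModuleDict({
    "image_encoder": nn.Linear(8, 8),
    "prompt_learner": nn.Linear(8, 8),
})
for name, param in model.named_parameters():
    if "prompt_learner" not in name:
        param.requires_grad_(False)
print([n for n, p in model.named_parameters() if p.requires_grad])
# ['prompt_learner.weight', 'prompt_learner.bias']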

5.3

self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched)
self.scaler = GradScaler() if cfg.TRAINER.COOP.PREC == "amp" else None  # set up the gradient scaler for mixed-precision training
  • build_optimizer(prompt_learner, config): builds the optimizer, covering only the prompt learner's parameters
  • build_lr_scheduler(optimizer, config): builds the learning-rate scheduler
  • register_model(): registers the prompt_learner together with its optimizer and scheduler in the trainer

5.4

 if prec == "amp":  # 使用混合精度进行前向传播和反向传播
    with autocast():  # 进入autocast上下文管理器
         output = self.model(image)
         loss = F.cross_entropy(output, label)
    self.optim.zero_grad()  # 清空优化器中之前计算的梯度
    self.scaler.scale(loss).backward()  # loss对模型进行反向传播
    self.scaler.step(self.optim)  # 更新模型参数
    self.scaler.update()  # 更新梯度缩放器
  • self.optim.zero_grad(): clears the gradients computed in the previous step
  • backward(): backpropagates the loss through the model