CLIP-ReID代码解读六——model文件夹make_model.py和clip.py

make_model.py

定义了一个名为 build_transformer 的 PyTorch 神经网络模型类,以及一些辅助函数和类。该模型主要用于基于视觉Transformer和卷积神经网络(CNN)的任务。代码分为几个部分,包括权重初始化函数、Transformer模型类、以及加载预训练模型的函数。以下是每个部分的详细注释:

import torch
import torch.nn as nn
import numpy as np
from .clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
_tokenizer = _Tokenizer()
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

# 权重初始化函数,使用Kaiming初始化方法
def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
        nn.init.constant_(m.bias, 0.0)
    elif classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('BatchNorm') != -1:
        if m.affine:
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)

# 权重初始化函数,专门用于分类器
def weights_init_classifier(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight, std=0.001)
        if m.bias:
            nn.init.constant_(m.bias, 0.0)

# Transformer 模型类
class build_transformer(nn.Module):
    def __init__(self, num_classes, camera_num, view_num, cfg):
        super(build_transformer, self).__init__()
        self.model_name = cfg.MODEL.NAME
        self.cos_layer = cfg.MODEL.COS_LAYER
        self.neck = cfg.MODEL.NECK
        self.neck_feat = cfg.TEST.NECK_FEAT

        # 根据模型名称初始化输入通道数
        if self.model_name == 'ViT-B-16':
            self.in_planes = 768
            self.in_planes_proj = 512
        elif self.model_name == 'RN50':
            self.in_planes = 2048
            self.in_planes_proj = 1024

        self.num_classes = num_classes
        self.camera_num = camera_num
        self.view_num = view_num
        self.sie_coe = cfg.MODEL.SIE_COE

        # 定义分类器并应用权重初始化
        self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False)
        self.classifier.apply(weights_init_classifier)
        self.classifier_proj = nn.Linear(self.in_planes_proj, self.num_classes, bias=False)
        self.classifier_proj.apply(weights_init_classifier)

        # 初始化两个一维批量归一化层
        self.bottleneck = nn.BatchNorm1d(self.in_planes)
        self.bottleneck.bias.requires_grad_(False)
        self.bottleneck.apply(weights_init_kaiming)
        self.bottleneck_proj = nn.BatchNorm1d(self.in_planes_proj)
        self.bottleneck_proj.bias.requires_grad_(False)
        self.bottleneck_proj.apply(weights_init_kaiming)

        # 计算图像的分辨率和步幅大小
        self.h_resolution = int((cfg.INPUT.SIZE_TRAIN[0]-16)//cfg.MODEL.STRIDE_SIZE[0] + 1)  # 16
        self.w_resolution = int((cfg.INPUT.SIZE_TRAIN[1]-16)//cfg.MODEL.STRIDE_SIZE[1] + 1)  # 8
        self.vision_stride_size = cfg.MODEL.STRIDE_SIZE[0]
        
        # 加载CLIP模型
        clip_model = load_clip_to_cpu(self.model_name, self.h_resolution, self.w_resolution, self.vision_stride_size)
        clip_model.to("cuda")
        self.image_encoder = clip_model.visual

        # 定义相机和视图嵌入
        if cfg.MODEL.SIE_CAMERA and cfg.MODEL.SIE_VIEW:
            self.cv_embed = nn.Parameter(torch.zeros(camera_num * view_num, self.in_planes))
            trunc_normal_(self.cv_embed, std=.02)
            print('camera number is : {}'.format(camera_num))
        elif cfg.MODEL.SIE_CAMERA:
            self.cv_embed = nn.Parameter(torch.zeros(camera_num, self.in_planes))
            trunc_normal_(self.cv_embed, std=.02)
            print('camera number is : {}'.format(camera_num))
        elif cfg.MODEL.SIE_VIEW:
            self.cv_embed = nn.Parameter(torch.zeros(view_num, self.in_planes))
            trunc_normal_(self.cv_embed, std=.02)
            print('camera number is : {}'.format(view_num))

    def forward(self, x, label=None, cam_label= None, view_label=None):
        if self.model_name == 'RN50':
            image_features_last, image_features, image_features_proj = self.image_encoder(x) #B,512  B,128,512
            img_feature_last = nn.functional.avg_pool2d(image_features_last, image_features_last.shape[2:4]).view(x.shape[0], -1) 
            img_feature = nn.functional.avg_pool2d(image_features, image_features.shape[2:4]).view(x.shape[0], -1) 
            img_feature_proj = image_features_proj[0]

        elif self.model_name == 'ViT-B-16':
            if cam_label != None and view_label!=None:
                cv_embed = self.sie_coe * self.cv_embed[cam_label * self.view_num + view_label]
            elif cam_label != None:
                cv_embed = self.sie_coe * self.cv_embed[cam_label]
            elif view_label!=None:
                cv_embed = self.sie_coe * self.cv_embed[view_label]
            else:
                cv_embed = None
            image_features_last, image_features, image_features_proj = self.image_encoder(x, cv_embed) #B,512  B,128,512
            img_feature_last = image_features_last[:,0]
            img_feature = image_features[:,0]
            img_feature_proj = image_features_proj[:,0]

        feat = self.bottleneck(img_feature) 
        feat_proj = self.bottleneck_proj(img_feature_proj) 

        if self.training:
            cls_score = self.classifier(feat)
            cls_score_proj = self.classifier_proj(feat_proj)
            return [cls_score, cls_score_proj], [img_feature_last, img_feature, img_feature_proj]

        else:
            if self.neck_feat == 'after':
                # print("Test with feature after BN")
                return torch.cat([feat, feat_proj], dim=1)
            else:
                return torch.cat([img_feature, img_feature_proj], dim=1)

    # 加载预训练模型参数
    def load_param(self, trained_path):
        param_dict = torch.load(trained_path)
        for i in param_dict:
            self.state_dict()[i.replace('module.', '')].copy_(param_dict[i])
        print('Loading pretrained model from {}'.format(trained_path))

    # 加载预训练模型参数用于微调
    def load_param_finetune(self, model_path):
        param_dict = torch.load(model_path)
        for i in param_dict:
            self.state_dict()[i].copy_(param_dict[i])
        print('Loading pretrained model for finetuning from {}'.format(model_path))

# 创建模型实例
def make_model(cfg, num_class, camera_num, view_num):
    model = build_transformer(num_class, camera_num, view_num, cfg)
    return model

# 加载 CLIP 模型到 CPU
from .clip import clip
def load_clip_to_cpu(backbone_name, h_resolution, w_resolution, vision_stride_size):
    url = clip._MODELS[backbone_name]
    model_path = clip._download(url)

    try:
        # 加载 JIT 存档
        model = torch.jit.load(model_path, map_location="cpu").eval()
        state_dict = None
    except RuntimeError:
        state_dict = torch.load(model_path, map_location="cpu")

    model = clip.build_model(state_dict or model.state_dict(), h_resolution, w_resolution, vision_stride_size)
    return model

上述代码的主要功能是定义一个基于视觉Transformer的神经网络模型,包括初始化权重、前向传播、加载预训练参数等操作。模型使用了CLIP作为视觉编码器,并根据配置文件中的参数进行相应的初始化和设置,得到baseline模型。

其中加载clip模型的代码部分,torch.jit.load部分解释如下:
加载一个使用TorchScript保存的JIT(Just-In-Time编译)模型,并将其设置为评估模式(evaluation mode)。

具体解释如下:

# 从给定路径加载一个TorchScript模型并将其映射到CPU
model = torch.jit.load(model_path, map_location="cpu").eval()
# 由于模型已经通过torch.jit.load加载,因此将state_dict设置为None
state_dict = None
  1. torch.jit.load(model_path, map_location="cpu"):

    • torch.jit.load 是 PyTorch 中用于加载TorchScript模型的函数。TorchScript是一种中间表示,可以通过JIT编译来优化模型,并使其可以独立于Python环境运行。
    • model_path 是模型文件的路径,这个文件是通过TorchScript保存的。
    • map_location="cpu" 指定将模型加载到CPU上。如果模型原本保存时在GPU上,但你希望在CPU上进行推理或训练,可以使用这个参数将模型映射到CPU。
  2. .eval():

    • eval() 方法将模型设置为评估模式。评估模式主要影响一些特定层(如Dropout和BatchNorm),它们在训练模式和评估模式下的行为是不同的。在评估模式下,这些层的行为会固定,以确保推理的一致性。
  3. state_dict = None:

    • 这行代码将 state_dict 变量设置为 Nonestate_dict 通常用于保存或加载模型的权重,但在这里因为模型是通过 torch.jit.load 加载的,所以不需要显式地加载 state_dict

综上,这行代码的作用是加载一个TorchScript保存的JIT模型,并将其设置为评估模式,确保在推理时模型行为的一致性。此时,模型权重和架构已经通过 torch.jit.load 加载,因此 state_dict 被设置为 None

这样加载的模型可以直接用于推理,且由于设置为评估模式,Dropout和BatchNorm等层将不会在推理时引入不确定性。

clip.py

实现了一个用于下载、加载和预处理CLIP模型的Python模块。CLIP模型是一种用于图像和文本对齐的模型。代码分为多个部分,包括模型下载、模型加载、图像预处理和文本标记化。以下是每个部分的详细注释:

import hashlib
import os
import urllib
import warnings
from typing import Union, List

import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm

from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

# 检查PyTorch版本是否符合要求
if torch.__version__.split(".") < ["1", "7", "1"]:
    warnings.warn("PyTorch version 1.7.1 or higher is recommended")

__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()

# 模型的URL地址
_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "ViT-B-32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B-16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
}

# 下载模型文件
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    # 检查文件是否已经存在并且SHA256校验和匹配
    if os.path.isfile(download_target):
        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    # 下载文件并显示进度条
    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    # 检查下载文件的SHA256校验和是否匹配
    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")

    return download_target

# 图像预处理变换
def _transform(n_px):
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        lambda image: image.convert("RGB"),
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])

# 获取可用模型名称
def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())

# 加载CLIP模型
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name])
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    try:
        # 尝试加载JIT存档
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # 加载保存的状态字典
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            model.float()
        return model, _transform(model.visual.input_resolution)

    # 修补设备名称
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # 修补CPU上的dtype为float32
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())

# 文本标记化
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
   

 """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder[""]
    eot_token = _tokenizer.encoder[""]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result

这段代码的主要功能是提供CLIP模型的下载、加载、图像预处理和文本标记化功能,确保模型和数据能够正确地用于训练和推理。

  • 其中,检查指定的文件是否存在并验证其SHA256校验和是否匹配的代码详细解释如下:
    如果文件存在且校验和匹配,则返回文件路径;否则发出警告,表示文件存在但校验和不匹配,建议重新下载文件。

具体来说:

  1. if os.path.isfile(download_target):

    • 检查 download_target 指定的文件是否存在。
  2. if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:

    • 如果文件存在,计算文件的SHA256校验和,并将其与 expected_sha256 进行比较。
  3. return download_target

    • 如果校验和匹配,则返回文件路径 download_target
  4. else:

    • 如果校验和不匹配,则执行以下代码:
  5. warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    • 发出警告,表示文件存在但校验和不匹配,建议重新下载文件。

这段代码确保下载的文件完整且未被篡改。具体代码如下:

import os
import hashlib
import warnings

def check_and_verify_file(download_target, expected_sha256):
    if os.path.isfile(download_target):
        with open(download_target, "rb") as f:
            file_hash = hashlib.sha256(f.read()).hexdigest()
        if file_hash == expected_sha256:
            return download_target
        else:
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
            # 这里可以加入重新下载文件的代码逻辑

这段代码定义了一个函数 check_and_verify_file,用于检查文件是否存在并验证其SHA256校验和。

  • 此外,从指定的URL下载文件,并在下载过程中显示一个进度条的代码详细解释:
  1. with urllib.request.urlopen(url) as source:

    • 打开URL,并将其作为 source 进行处理。urllib.request.urlopen(url) 返回一个类文件对象,可以用来读取URL的数据。
  2. open(download_target, "wb") as output:

    • 打开一个文件进行写操作,文件路径为 download_target,模式为 “wb”(写二进制文件)。output 是一个文件对象,用来将下载的数据写入本地文件。
  3. with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:

    • 使用 tqdm 库来创建一个进度条,total 参数设置进度条的总长度,这里从 source.info().get("Content-Length") 获取文件的总字节数并转换为整数。ncols=80 设置进度条的宽度为80个字符,unit='iB'unit_scale=True 设置进度条的单位为字节并自动缩放单位(如KB、MB等)。
  4. while True::

    • 开始一个无限循环,用于逐块读取并写入文件,直到所有数据都被下载。
  5. buffer = source.read(8192):

    • source 中读取8192字节的数据块。如果到达文件末尾,read 方法会返回空字节串 b''
  6. if not buffer: break:

    • 如果读取到的数据块为空(即到达文件末尾),跳出循环。
  7. output.write(buffer):

    • 将读取到的数据块写入本地文件 output
  8. loop.update(len(buffer)):

    • 更新进度条,len(buffer) 是刚刚读取的数据块的字节数。tqdm 会根据这个值更新进度条的状态。

通过这段代码,文件会被从URL下载到本地指定的路径,并且下载过程中的进度会以进度条的形式显示在控制台上。

完整的代码如下:

with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
    with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
        while True:
            buffer = source.read(8192)
            if not buffer:
                break
            output.write(buffer)
            loop.update(len(buffer))

该代码块确保在下载过程中显示下载进度,并且每次读取一块数据写入本地文件,直到整个文件下载完成。

  • 18
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

yiruzhao

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值