`make_model.py` 定义了一个名为 `build_transformer` 的 PyTorch 神经网络模型类,以及若干辅助函数。该模型以 CLIP 的视觉编码器为骨干网络,支持视觉Transformer(ViT-B-16)和卷积神经网络(RN50)两种结构。代码分为几个部分:权重初始化函数、Transformer 模型类,以及加载预训练模型的函数。以下是每个部分的详细注释:
import torch
import torch.nn as nn
import numpy as np
from .clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
_tokenizer = _Tokenizer()
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
# 权重初始化函数,使用Kaiming初始化方法
def weights_init_kaiming(m):
    """Kaiming-initialise a single module; intended for use with ``Module.apply``.

    Dispatch is by substring match on the module's class name:

    - ``*Linear*``: Kaiming-normal weight (``mode='fan_out'``), zero bias.
    - ``*Conv*``: Kaiming-normal weight (``mode='fan_in'``), zero bias.
    - ``*BatchNorm*`` with ``affine=True``: weight 1, bias 0.

    Modules of any other class are left untouched.

    Args:
        m: the module to initialise; modified in place.
    """
    classname = m.__class__.__name__
    if 'Linear' in classname:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
        # ``nn.Linear(..., bias=False)`` stores ``bias=None``; guard it like the
        # Conv branch instead of crashing inside ``constant_``.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif 'Conv' in classname:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif 'BatchNorm' in classname:
        # Non-affine BatchNorm has no learnable weight/bias to initialise.
        if m.affine:
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)
# 权重初始化函数,专门用于分类器
def weights_init_classifier(m):
    """Initialise a classifier layer; intended for use with ``Module.apply``.

    ``*Linear*`` modules (matched by class-name substring) get weights drawn
    from N(0, 0.001^2) and, when they have one, a zero bias. Modules of any
    other class are left untouched.

    Args:
        m: the module to initialise; modified in place.
    """
    classname = m.__class__.__name__
    if 'Linear' in classname:
        nn.init.normal_(m.weight, std=0.001)
        # ``if m.bias:`` would call Tensor.__bool__, which raises for a bias
        # with more than one element; test for presence explicitly.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
# Transformer 模型类
class build_transformer(nn.Module):
    """CLIP-backbone re-identification baseline model.

    Wraps the visual encoder of a CLIP model (``'ViT-B-16'`` or ``'RN50'``).
    Both the backbone feature and the projected feature get a BatchNorm1d
    "bottleneck" (with frozen bias), followed by a bias-free linear identity
    classifier. For the ViT backbone, an optional learnable side-information
    embedding (SIE) can be passed to the image encoder together with the
    images. The embedding is indexed by camera and/or view id.

    Args:
        num_classes: number of identity classes (classifier output size).
        camera_num: number of cameras, used to size the SIE table.
        view_num: number of views, used to size the SIE table.
        cfg: config node. Reads MODEL.NAME, MODEL.COS_LAYER, MODEL.NECK,
            MODEL.SIE_COE, MODEL.SIE_CAMERA, MODEL.SIE_VIEW,
            MODEL.STRIDE_SIZE, TEST.NECK_FEAT and INPUT.SIZE_TRAIN.

    Raises:
        ValueError: if ``cfg.MODEL.NAME`` is not a supported backbone.
    """

    # (backbone feature width, projected feature width) per backbone name.
    _FEATURE_DIMS = {
        'ViT-B-16': (768, 512),
        'RN50': (2048, 1024),
    }

    def __init__(self, num_classes, camera_num, view_num, cfg):
        super(build_transformer, self).__init__()
        self.model_name = cfg.MODEL.NAME
        self.cos_layer = cfg.MODEL.COS_LAYER
        self.neck = cfg.MODEL.NECK
        self.neck_feat = cfg.TEST.NECK_FEAT

        # Feature widths depend on the backbone. Fail fast with a clear message
        # instead of an AttributeError on ``self.in_planes`` further down.
        if self.model_name not in self._FEATURE_DIMS:
            raise ValueError('Unsupported MODEL.NAME {!r}; expected one of {}'.format(
                self.model_name, sorted(self._FEATURE_DIMS)))
        self.in_planes, self.in_planes_proj = self._FEATURE_DIMS[self.model_name]

        self.num_classes = num_classes
        self.camera_num = camera_num
        self.view_num = view_num
        self.sie_coe = cfg.MODEL.SIE_COE

        # Bias-free identity classifiers on the backbone and projected features.
        self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False)
        self.classifier.apply(weights_init_classifier)
        self.classifier_proj = nn.Linear(self.in_planes_proj, self.num_classes, bias=False)
        self.classifier_proj.apply(weights_init_classifier)

        # BatchNorm1d "bottleneck" necks; their bias (shift) is frozen.
        self.bottleneck = nn.BatchNorm1d(self.in_planes)
        self.bottleneck.bias.requires_grad_(False)
        self.bottleneck.apply(weights_init_kaiming)
        self.bottleneck_proj = nn.BatchNorm1d(self.in_planes_proj)
        self.bottleneck_proj.bias.requires_grad_(False)
        self.bottleneck_proj.apply(weights_init_kaiming)

        # Token-grid size along each axis. This assumes a hard-coded 16-pixel
        # patch and the configured stride; for example, a 256 x 128 input with
        # stride 16 gives 16 x 8.
        self.h_resolution = int((cfg.INPUT.SIZE_TRAIN[0] - 16) // cfg.MODEL.STRIDE_SIZE[0] + 1)
        self.w_resolution = int((cfg.INPUT.SIZE_TRAIN[1] - 16) // cfg.MODEL.STRIDE_SIZE[1] + 1)
        self.vision_stride_size = cfg.MODEL.STRIDE_SIZE[0]

        # Build CLIP for this grid; only its visual encoder is kept.
        clip_model = load_clip_to_cpu(self.model_name, self.h_resolution, self.w_resolution, self.vision_stride_size)
        # Use the GPU when present; an unconditional "cuda" crashes on CPU-only hosts.
        clip_model.to("cuda" if torch.cuda.is_available() else "cpu")
        self.image_encoder = clip_model.visual

        # Side-information embedding table: one learnable row per
        # (camera, view) pair, per camera, or per view.
        if cfg.MODEL.SIE_CAMERA and cfg.MODEL.SIE_VIEW:
            self.cv_embed = nn.Parameter(torch.zeros(camera_num * view_num, self.in_planes))
            trunc_normal_(self.cv_embed, std=.02)
            print('camera number is : {}'.format(camera_num))
        elif cfg.MODEL.SIE_CAMERA:
            self.cv_embed = nn.Parameter(torch.zeros(camera_num, self.in_planes))
            trunc_normal_(self.cv_embed, std=.02)
            print('camera number is : {}'.format(camera_num))
        elif cfg.MODEL.SIE_VIEW:
            self.cv_embed = nn.Parameter(torch.zeros(view_num, self.in_planes))
            trunc_normal_(self.cv_embed, std=.02)
            # This table is sized by views, so report it as the view count.
            print('viewpoint number is : {}'.format(view_num))

    def _side_info_embedding(self, cam_label, view_label):
        """Return the scaled SIE rows selected by the labels, or None if neither label is given."""
        if cam_label is not None and view_label is not None:
            # Row-major index into the (camera, view) table.
            index = cam_label * self.view_num + view_label
        elif cam_label is not None:
            index = cam_label
        elif view_label is not None:
            index = view_label
        else:
            return None
        return self.sie_coe * self.cv_embed[index]

    def forward(self, x, label=None, cam_label=None, view_label=None):
        """Encode a batch of images.

        Args:
            x: input batch, passed unchanged to the image encoder.
            label: unused; kept for call-site compatibility.
            cam_label, view_label: optional camera / view indices that select
                SIE rows. They are used by the ViT backbone only, and the
                matching MODEL.SIE_* table must have been enabled.

        Returns:
            In training mode: ``([cls_score, cls_score_proj],
            [img_feature_last, img_feature, img_feature_proj])``.
            In eval mode: the backbone and projected features concatenated
            along dim 1. They are taken after the BN necks when
            ``TEST.NECK_FEAT == 'after'`` and before them otherwise.
        """
        if self.model_name == 'RN50':
            image_features_last, image_features, image_features_proj = self.image_encoder(x)
            # Global average pool over the last two dims -> one vector per sample.
            img_feature_last = nn.functional.avg_pool2d(
                image_features_last, image_features_last.shape[2:4]).view(x.shape[0], -1)
            img_feature = nn.functional.avg_pool2d(
                image_features, image_features.shape[2:4]).view(x.shape[0], -1)
            img_feature_proj = image_features_proj[0]
        elif self.model_name == 'ViT-B-16':
            cv_embed = self._side_info_embedding(cam_label, view_label)
            image_features_last, image_features, image_features_proj = self.image_encoder(x, cv_embed)
            # Token 0 of each sequence (presumably the class token -- TODO confirm).
            img_feature_last = image_features_last[:, 0]
            img_feature = image_features[:, 0]
            img_feature_proj = image_features_proj[:, 0]

        feat = self.bottleneck(img_feature)
        feat_proj = self.bottleneck_proj(img_feature_proj)

        if self.training:
            cls_score = self.classifier(feat)
            cls_score_proj = self.classifier_proj(feat_proj)
            return [cls_score, cls_score_proj], [img_feature_last, img_feature, img_feature_proj]
        if self.neck_feat == 'after':
            return torch.cat([feat, feat_proj], dim=1)
        return torch.cat([img_feature, img_feature_proj], dim=1)

    def load_param(self, trained_path):
        """Copy weights from the checkpoint at ``trained_path`` into this model.

        Every occurrence of ``'module.'`` (as added by ``nn.DataParallel``) is
        removed from each checkpoint key. Every resulting key must exist in
        this model's state dict.
        """
        # Load onto the CPU so GPU-saved checkpoints also open on CPU-only
        # hosts; ``copy_`` then moves each tensor to the parameter's device.
        param_dict = torch.load(trained_path, map_location='cpu')
        # Build the state dict once. Its tensors share storage with the
        # model's parameters/buffers, so ``copy_`` writes into the model.
        own_state = self.state_dict()
        for key in param_dict:
            own_state[key.replace('module.', '')].copy_(param_dict[key])
        print('Loading pretrained model from {}'.format(trained_path))

    def load_param_finetune(self, model_path):
        """Copy weights from ``model_path`` into this model, keys taken verbatim (for fine-tuning)."""
        param_dict = torch.load(model_path, map_location='cpu')
        own_state = self.state_dict()
        for key in param_dict:
            own_state[key].copy_(param_dict[key])
        print('Loading pretrained model for finetuning from {}'.format(model_path))
# 创建模型实例
def make_model(cfg, num_class, camera_num, view_num):
    """Factory for the CLIP re-ID baseline.

    Returns a new ``build_transformer`` built for ``num_class`` identities,
    ``camera_num`` cameras and ``view_num`` views, configured by ``cfg``.
    """
    return build_transformer(
        num_classes=num_class,
        camera_num=camera_num,
        view_num=view_num,
        cfg=cfg,
    )
# 加载 CLIP 模型到 CPU
from .clip import clip
def load_clip_to_cpu(backbone_name, h_resolution, w_resolution, vision_stride_size):
    """Fetch the CLIP checkpoint for ``backbone_name`` and rebuild it for a custom token grid.

    The checkpoint is downloaded (or reused from the cache) via
    ``clip._download``. It is first opened as a TorchScript archive, whose
    ``state_dict()`` then supplies the weights. If that raises
    ``RuntimeError``, it is read as a plain state dict instead. Both reads use
    ``map_location="cpu"``. The weights are passed to ``clip.build_model``
    together with the grid size and stride, so the returned object is the
    rebuilt model, not the JIT archive.

    Args:
        backbone_name: key of ``clip._MODELS``, e.g. ``'ViT-B-16'`` or ``'RN50'``.
        h_resolution: token-grid height passed to ``clip.build_model``.
        w_resolution: token-grid width passed to ``clip.build_model``.
        vision_stride_size: patch stride passed to ``clip.build_model``.

    Returns:
        Whatever ``clip.build_model`` returns.
    """
    checkpoint_file = clip._download(clip._MODELS[backbone_name])
    try:
        # First attempt: TorchScript (JIT) archive.
        scripted = torch.jit.load(checkpoint_file, map_location="cpu").eval()
        weights = None
    except RuntimeError:
        # Not a JIT archive: read it as a pickled state dict.
        weights = torch.load(checkpoint_file, map_location="cpu")
    if not weights:
        weights = scripted.state_dict()
    return clip.build_model(weights, h_resolution, w_resolution, vision_stride_size)
上述代码的主要功能是定义一个以 CLIP 视觉编码器(ViT-B-16 或 RN50)为骨干网络的神经网络模型,包括权重初始化、前向传播、加载预训练参数等操作。模型根据配置文件中的参数进行相应的初始化和设置,得到 baseline 模型。
其中,加载 CLIP 模型代码中 `torch.jit.load` 这一部分的作用是:加载一个以 TorchScript 格式保存的 JIT(Just-In-Time 编译)模型,并将其设置为评估模式(evaluation mode)。具体解释如下:
# Load a TorchScript model from the given path, mapping its tensors to the CPU,
# and switch it to evaluation mode.
model = torch.jit.load(model_path, map_location="cpu").eval()
# The model (architecture and weights) was loaded by torch.jit.load, so no
# separate state dict is read. In load_clip_to_cpu, this None makes the later
# `state_dict or model.state_dict()` fall back to the JIT model's own weights.
state_dict = None
- `torch.jit.load(model_path, map_location="cpu")`:`torch.jit.load` 是 PyTorch 中用于加载 TorchScript 模型的函数。TorchScript 是一种中间表示,可以通过 JIT 编译来优化模型,并使其可以独立于 Python 环境运行。`model_path` 是模型文件的路径,这个文件是通过 TorchScript 保存的。`map_location="cpu"` 指定将模型加载到 CPU 上:如果模型保存时位于 GPU 上,而你希望在 CPU 上进行推理或训练,可以用这个参数将模型映射到 CPU。
- `.eval()`:`eval()` 方法将模型设置为评估模式。评估模式主要影响一些特定层(如 Dropout 和 BatchNorm),它们在训练模式和评估模式下的行为不同;在评估模式下,这些层的行为是固定的,以确保推理结果的一致性。
- `state_dict = None`:这行代码将 `state_dict` 变量设置为 `None`。`state_dict` 通常用于保存或加载模型的权重,但这里模型已经通过 `torch.jit.load` 加载,无需再单独读取权重,因此将其设为 `None`。在 `load_clip_to_cpu` 中,随后的 `state_dict or model.state_dict()` 会因此改用 JIT 模型自身的权重;只有当 `torch.jit.load` 抛出 `RuntimeError`(文件不是 JIT 存档)时,才会改用 `torch.load` 读取普通的 state dict。
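综上,这段代码的作用是加载一个以 TorchScript 保存的 JIT 模型,并将其设置为评估模式。此时,模型权重和架构已经通过 `torch.jit.load` 加载,因此 `state_dict` 被设置为 `None`。
需要注意的是,在 `load_clip_to_cpu` 中,这个 JIT 模型只作为权重来源:它的 `state_dict()` 会连同 `h_resolution`、`w_resolution`、`vision_stride_size` 一起传给 `clip.build_model`,由后者重新构建出最终返回的模型。若单独使用,这样加载的 JIT 模型可以直接用于推理;由于处于评估模式,Dropout 和 BatchNorm 等层在推理时不会引入不确定性。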
`clip.py` 实现了一个用于下载、加载和预处理 CLIP 模型的 Python 模块。CLIP 是一种用于图像与文本对齐的模型。代码分为多个部分,包括模型下载、模型加载、图像预处理和文本标记化。以下是每个部分的详细注释:
import hashlib
import os
import urllib
import warnings
from typing import Union, List
import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm
from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
try:
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
except ImportError:
BICUBIC = Image.BICUBIC
# Warn when PyTorch is older than 1.7.1. Components are compared as integers:
# comparing lists of strings treats "10" < "7" and wrongly warns on e.g. 1.10.
def _parse_version(version):
    """Return the leading numeric ``(major, minor, patch)`` of a version string.

    Local/pre-release suffixes are ignored and missing components count as 0,
    e.g. ``"2.1.0+cu118" -> (2, 1, 0)``, ``"1.7.0a0" -> (1, 7, 0)``,
    ``"2.3" -> (2, 3, 0)``.
    """
    import re

    numbers = []
    for piece in version.split("+", 1)[0].split(".")[:3]:
        match = re.match(r"\d+", piece)
        numbers.append(int(match.group()) if match else 0)
    numbers += [0] * (3 - len(numbers))
    return tuple(numbers)


if _parse_version(torch.__version__) < (1, 7, 1):
    warnings.warn("PyTorch version 1.7.1 or higher is recommended")
# Public API of this module.
__all__ = ["available_models", "load", "tokenize"]
# Module-level BPE tokenizer instance used by `tokenize`.
_tokenizer = _Tokenizer()
# Model name -> download URL of the released checkpoint. The second-to-last
# path segment of each URL is the file's expected SHA256 hex digest, which
# `_download` extracts and checks.
_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "ViT-B-32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B-16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
}
# 下载模型文件
def _sha256_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of the file at ``path``, hashing it in chunks.

    Streaming keeps memory use flat for large checkpoint files, and the
    ``with`` block guarantees the file handle is closed.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
    """Download ``url`` into ``root`` unless a verified copy is cached; return the local path.

    The expected SHA256 digest is the second-to-last path segment of ``url``,
    which is the layout of the ``_MODELS`` URLs. A cached file with a matching
    digest is returned without any network access. A cached file with a
    different digest triggers a warning and is then re-downloaded.

    Args:
        url: checkpoint URL whose parent path segment is its SHA256 hex digest.
        root: cache directory; created if missing.

    Returns:
        Path of the verified local file.

    Raises:
        RuntimeError: if the target path exists but is not a regular file, or
            if the downloaded file does not match the expected digest.
    """
    # ``import urllib`` alone does not load the ``urllib.request`` submodule.
    import urllib.request

    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)
    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    # Reuse a cached file only if its checksum matches.
    if os.path.isfile(download_target):
        if _sha256_of_file(download_target) == expected_sha256:
            return download_target
        warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        # Content-Length may be absent (e.g. chunked transfer). Pass None so
        # tqdm shows a plain byte counter instead of int(None) raising TypeError.
        content_length = source.info().get("Content-Length")
        total = int(content_length) if content_length else None
        with tqdm(total=total, ncols=80, unit='iB', unit_scale=True) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break
                output.write(buffer)
                loop.update(len(buffer))

    if _sha256_of_file(download_target) != expected_sha256:
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target
# 图像预处理变换
def _transform(n_px):
    """Build the CLIP image-preprocessing pipeline for ``n_px``-pixel inputs.

    The steps are: bicubic ``Resize(n_px)``, ``CenterCrop(n_px)``, conversion
    of the PIL image to RGB, ``ToTensor()``, and per-channel normalisation
    with the fixed mean/std below.
    """
    def to_rgb(image):
        return image.convert("RGB")

    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)
    pipeline = [
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        to_rgb,
        ToTensor(),
        Normalize(channel_mean, channel_std),
    ]
    return Compose(pipeline)
# 获取可用模型名称
def available_models() -> List[str]:
    """Return a fresh list of the model names known to ``_MODELS``, in definition order."""
    names = [model_name for model_name in _MODELS]
    return names
# 加载CLIP模型
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model. The default is evaluated once,
        when this function is defined (at import time), from
        `torch.cuda.is_available()`.

    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input

    Raises
    ------
    RuntimeError
        If `name` is neither a key of `_MODELS` nor an existing file.
    """
    # Resolve `name` to a local checkpoint file (downloading known models).
    if name in _MODELS:
        model_path = _download(_MODELS[name])
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    try:
        # Try the file as a JIT archive; it is mapped to `device` only when the
        # JIT model itself will be returned, otherwise to the CPU.
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # Not a JIT archive: load it as a saved state dict (forcing non-JIT mode).
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        # Rebuild a regular nn.Module from the weights; use float32 on CPU.
        # NOTE(review): load_clip_to_cpu in make_model.py calls
        # `clip.build_model(state_dict, h_resolution, w_resolution, vision_stride_size)`,
        # but this call passes only the state dict. Unless those extra
        # parameters have defaults in model.py, this path raises TypeError --
        # confirm.
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            model.float()
        return model, _transform(model.visual.input_resolution)

    # JIT path: rewrite device constants baked into the archive's graphs. A tiny
    # traced lambda yields a prim::Constant node holding the target device.
    # Its attributes are copied onto every constant whose value starts with "cuda".
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        # Some modules expose no usable `graph` (accessing it may raise).
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    # apply() visits submodules; the exported methods' graphs are patched explicitly.
    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # On CPU, also rewrite dtype arguments of aten::to() calls. Every dtype
    # constant equal to 5 (presumably the ScalarType code of float16 -- verify)
    # is replaced with the float32 constant from a traced `.float()` call.
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    # The JIT archive exposes its input resolution as a tensor-like attribute
    # (hence `.item()`) -- assumption based on this call; confirm.
    return model, _transform(model.input_resolution.item())
# 文本标记化
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
    Each row is zero-padded on the right.

    Raises
    ------
    RuntimeError
        If an input's encoding (including the start/end tokens) is longer than
        `context_length` and `truncate` is False.
    """
    if isinstance(texts, str):
        texts = [texts]

    # Every sequence is wrapped in CLIP's start-of-text / end-of-text special
    # tokens, which the BPE vocabulary stores under these literal keys. The
    # empty-string key "" does not name either special token.
    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                # Keep the first context_length tokens but force the last one to be EOT.
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
这段代码的主要功能是提供 CLIP 模型的下载、加载、图像预处理和文本标记化功能,确保模型和数据能够正确地用于训练和推理。
- 其中,检查指定文件是否存在并验证其 SHA256 校验和的代码详细解释如下:
如果文件存在且校验和匹配,则直接返回文件路径;否则发出警告,说明文件存在但校验和不匹配,随后代码会自动重新下载该文件。
具体来说:
- `if os.path.isfile(download_target):`:检查 `download_target` 指定的文件是否存在。
- `if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:`:如果文件存在,计算文件的 SHA256 校验和,并将其与 `expected_sha256`(取自 URL 路径的倒数第二段)进行比较。
- `return download_target`:如果校验和匹配,则返回文件路径 `download_target`。
- `else:`:如果校验和不匹配,则执行以下代码:
  - `warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")`:发出警告,说明文件存在但校验和不匹配;随后 `_download` 继续执行下面的下载逻辑,重新下载该文件。
这段校验确保本地文件完整、未被损坏或篡改(前提是 URL 中给出的校验和本身可信)。下面把这段逻辑整理成一个独立的示例函数:
import os
import hashlib
import warnings
def check_and_verify_file(download_target, expected_sha256):
    """Return ``download_target`` if it exists and its SHA256 digest matches, else ``None``.

    When the file exists but its digest differs, a warning is emitted and
    ``None`` is returned. A caller would then re-download the file; that is
    not implemented here.

    Args:
        download_target: path of the local file to check.
        expected_sha256: expected SHA256 hex digest. It is compared exactly
            against ``hexdigest()``, which is lowercase.

    Returns:
        ``download_target`` on a match, otherwise ``None``.
    """
    if not os.path.isfile(download_target):
        return None

    digest = hashlib.sha256()
    with open(download_target, "rb") as f:
        # Hash in 1 MiB chunks so large model files are never fully held in memory.
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)

    if digest.hexdigest() == expected_sha256:
        return download_target
    warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
    # Re-download logic can be added here.
    return None
这段代码定义了一个函数 `check_and_verify_file`,用于检查文件是否存在并验证其 SHA256 校验和。
- 此外,从指定 URL 下载文件并在下载过程中显示进度条的代码详细解释如下:
- `with urllib.request.urlopen(url) as source`:打开 URL,并将其作为 `source` 进行处理。`urllib.request.urlopen(url)` 返回一个类文件对象,可以用来读取 URL 的数据。
- `open(download_target, "wb") as output`:打开一个文件进行写操作,文件路径为 `download_target`,模式为 `"wb"`(以二进制方式写入)。`output` 是一个文件对象,用来将下载的数据写入本地文件。
- `with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop`:使用 `tqdm` 库创建一个进度条。`total` 参数设置进度条的总长度,这里从 `source.info().get("Content-Length")` 获取文件的总字节数并转换为整数。注意:如果服务器没有返回 `Content-Length` 头,`get` 会返回 `None`,`int(None)` 将抛出 `TypeError`。`ncols=80` 设置进度条的宽度为 80 个字符;`unit='iB'` 和 `unit_scale=True` 将进度条的单位设为字节,并自动换算单位(如 KB、MB 等)。
- `while True:`:开始一个无限循环,逐块读取并写入文件,直到所有数据都下载完毕。
- `buffer = source.read(8192)`:从 `source` 中读取最多 8192 字节的数据块。到达文件末尾时,`read` 方法返回空字节串 `b''`。
- `if not buffer: break`:如果读取到的数据块为空(即到达文件末尾),跳出循环。
- `output.write(buffer)`:将读取到的数据块写入本地文件 `output`。
- `loop.update(len(buffer))`:更新进度条。`len(buffer)` 是刚刚读取的数据块的字节数,`tqdm` 会据此更新进度条的状态。

通过这段代码,文件会从 URL 下载到本地指定路径,并且下载进度会以进度条的形式显示在控制台上。
完整的代码如下:
# Stream `url` into `download_target` in 8192-byte chunks, advancing a tqdm
# progress bar by the number of bytes written.
# NOTE(review): `int(source.info().get("Content-Length"))` raises TypeError when
# the server omits Content-Length (`get` returns None). Also, `urllib.request`
# must be imported explicitly; `import urllib` alone does not load it.
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
    with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
        while True:
            buffer = source.read(8192)  # b"" once the stream is exhausted
            if not buffer:
                break
            output.write(buffer)
            loop.update(len(buffer))
该代码块确保在下载过程中显示下载进度,并且每次读取一块数据写入本地文件,直到整个文件下载完成。