7.3 CLIP 模型的增强训练与评估
本项目是一个基于视觉感知预训练模型(如 LaCLIP/CLIP)和 LLAMA 的图像描述生成系统,包括模型训练、零样本评估以及对生成文本的重写功能,通过整合多种模型和技术,实现了对图像内容的理解和生成多样化描述的功能,为图像与文本之间的关联性建模提供了一个端到端的解决方案。
实例7-10:使用TFT处理CSV文件中的缺失值(源码路径:codes/7/que.py)
7.3.1 项目介绍
本项目实现了对预训练模型 LaCLIP/CLIP 的训练工作,以及 LLAMA 模型的训练和过程。通过这些模型,实现了对图像和文本之间语义关联的建模,可以进行零样本评估,生成了多样化和丰富化的图像描述,并提供了文本生成的多样化风格选项,拓展了文生图大模型在图像与文本处理领域的应用和功能。
本项目的主要功能模块如下所示:
1. 训练模块
- 包括模型训练的主要逻辑,包括模型初始化、损失函数定义、优化器设置、数据加载等。
- 支持分布式训练,使用了分布式数据并行和混合精度训练(AMP)技术。
- 提供了模型训练过程中的日志记录和模型参数保存功能。
2. 零样本评估模块(eval_zeroshot_imagenet.py)
- 实现了对图像零样本分类的评估功能,使用预训练的模型对图像进行分类,同时支持多种文本模板。
- 包括数据加载、模型预测和准确率计算等功能。
3. LLAMA 模型加载与应用模块
- 提供了 LLAMA 模型的加载和应用功能。
- 支持并行加载多个模型,用于多 GPU 或分布式环境下的推理。
- 实现了对给定文本的生成,支持不同的样式选择和文本风格的扩展。
4. 模型并行设置模块
实现了模型并行的设置,用于在多 GPU 或分布式环境下进行模型的初始化和加载,包括了分布式训练的初始化、设备分配等功能。
5. 辅助功能模块
- 包含了一些辅助函数,如平均计量器(AverageMeter)、进度条显示器(ProgressMeter)等,用于训练过程中的统计和显示。
- 实现了一些工具函数,用于模型的初始化、数据加载等。
上述功能模块共同构成了项目的主要功能,包括模型训练、零样本评估和 LLAMA 模型的加载与应用等。
7.3.2 定义数据集
文件data.py定义了两个数据集类 CsvDatasetAugCap 和 CsvDataset,分别用于从 CSV 文件中加载图像及其对应的文本描述。同时提供了函数 get_csv_dataset 和 get_data 用于获取训练和验证数据集。
class CsvDatasetAugCap(Dataset):
def __init__(self, input_filename, transforms, tokenizer=None, root=None, augmented_caption_filelist=None):
logging.debug(f'Loading csv data from {input_filename}.')
self.images = []
self.captions = []
self.root = root
assert input_filename.endswith('.csv')
assert augmented_caption_filelist is not None, 'augmented_caption_filelist is None, use csvdataset instead'
num_augcap = len(augmented_caption_filelist)
augmented_captions = []
file_length = []
for f in augmented_caption_filelist:
with open(f, 'r') as file:
cur_captions = file.readlines()
file_length.append(len(cur_captions))
augmented_captions.append(cur_captions)
assert len(augmented_captions) == num_augcap, 'number of augmented captions is not equal to num_augcap'
for i in range(num_augcap):
assert file_length[i] == file_length[0], 'number of captions in each file is not the same'
num_samples = file_length[0]
with open(input_filename, 'r') as csv_file:
csv_reader = csv.reader(csv_file)
row_index = 0
for row in tqdm(csv_reader):
image = row[0]
prompt = row[1]
if image.endswith(('.png', '.jpg', '.jpeg')):
image_path = os.path.join(self.root, image)
self.images.append(image_path)
if row_index < num_samples:
self.captions.append([prompt])
for augcap_idx in range(num_augcap):
self.captions[row_index].append(augmented_captions[augcap_idx][row_index].replace('\n',''))
assert len(self.captions[row_index]) == num_augcap + 1, 'number of captions is not equal to num_augcap + 1'
row_index += 1
assert row_index % num_samples == 0, 'number of samples in csv is not equal to num_samples in new caption'
self.num_samples = num_samples
self.transforms = transforms
logging.debug('Done loading data.')
self.tokenizer = tokenizer
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
images = self.transforms(Image.open(str(self.images[idx])))
caption_list = self.captions[idx%self.num_samples]
caption = random.choice(caption_list)
if len(caption.split(' ')) < 2:
caption = caption_list[0]
texts = caption
texts = self.tokenizer(str(texts))
return images, texts
class CsvDataset(Dataset):
def __init__(self, input_filename, transforms, tokenizer=None, root=None):
logging.debug(f'Loading csv data from {input_filename}.')
self.images = []
self.captions = []
self.root = root
assert input_filename.endswith('.csv')
with open(input_filename, 'r') as csv_file:
csv_reader = csv.reader(csv_file)
for row in tqdm(csv_reader):
image = row[0]
prompt = row[1]
if image.endswith(('.png', '.jpg', '.jpeg')):
image_path = os.path.join(self.root, image)
self.images.append(image_path)
self.captions.append(prompt)
self.transforms = transforms
logging.debug('Done loading data.')
self.tokenizer = tokenizer
def __len__(self):
return len(self.captions)
def __getitem__(self, idx):
images = self.transforms(Image.open(str(self.images[idx])))
texts = self.tokenizer(str(self.captions[idx]))
return images, texts
@dataclass
class DataInfo:
dataloader: DataLoader
sampler: DistributedSampler = None
def set_epoch(self, epoch):
if self.sampler is not None and isinstance(self.sampler, DistributedSampler):
self.sampler.set_epoch(epoch)
def get_csv_dataset(args, preprocess_fn, is_train, tokenizer=None, aug_text=False):
input_filename = args.train_data if is_train else args.val_data
assert input_filename
if args.aug_text:
augmented_caption_filelist = args.augmented_caption_filelist
dataset = CsvDatasetAugCap(
input_filename,
preprocess_fn,
root=args.root,
tokenizer=tokenizer,
augmented_caption_filelist=augmented_caption_filelist,
)
else:
dataset = CsvDataset(
input_filename,
preprocess_fn,
root=args.root,
tokenizer=tokenizer
)
num_samples = len(dataset)
sampler = DistributedSampler(dataset) if args.distributed and is_train else None
shuffle = is_train and sampler is None
dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
shuffle=shuffle,
num_workers=args.workers,
pin_memory=True,
sampler=sampler,
drop_last=is_train,
)
dataloader.num_samples = num_samples
dataloader.num_batches = len(dataloader)
return DataInfo(dataloader, sampler)
def get_data(args, preprocess_fns, tokenizer=None):
preprocess_train, preprocess_val = preprocess_fns
data = {"train": get_csv_dataset(args, preprocess_train, is_train=True, tokenizer=tokenizer)}
return data
对上述代码的具体说明如下所示:
(1)类CsvDatasetAugCap的功能是从 CSV 文件中加载图像及其对应的原始文本描述和增强文本描述。它会将原始文本描述和多个增强文本描述合并成一个列表,支持随机选择其中一个文本描述作为输出。
(2)类CsvDataset的功能是从 CSV 文件中加载图像及其对应的单个原始文本描述。它用于加载普通的图像数据集,每个图像只有一个文本描述。
(3)类DataInfo是一个数据类,用于存储数据加载器和分布式采样器。它提供了设置 epoch 的方法,用于分布式训练时更新采样器的 epoch。
(4)函数get_csv_dataset的功能是根据参数返回相应的数据集对象,支持训练和验证数据集的加载,并根据是否启用增强文本描述来选择使用 CsvDatasetAugCap 或 CsvDataset 类。它还设置了数据加载器,并根据是否分布式训练选择是否使用分布式采样器。
(5)函数get_data的功能是获取训练数据集,并返回一个字典,包含训练数据集的信息。
7.3.3 创建模型
文件models.py创建了“视觉-文本检索(CLIP)”模型的各个组件和不同规模变体的构建方法,包括自注意力机制、残差注意力块、Transformer 架构以及 Vision Transformer,并提供了不同规模的 CLIP 模型构建方法,以便在视觉和文本之间进行编码和检索。文件models.py的具体实现流程如下所示。
(1)实现自定义的类 LayerNorm,此类继承自 PyTorch 中的 nn.LayerNorm 类,用于处理 fp16 数据。在 forward 方法中,将输入张量转换为 float32 类型后再调用父类的 forward 方法进行处理,并最终将结果转换回原始的数据类型。
class LayerNorm(nn.LayerNorm):
"""子类化 torch 的 LayerNorm 以处理 fp16。"""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
(2)在类QuickGELU中定义了一个快速的 GELU(Gaussian Error Linear Unit)激活函数,对输入张量 x 进行 GELU 操作,即,用于神经网络的非线性变换。
class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
(3)类ResidualAttentionBlock定义了一个残差注意力块(Residual Attention Block),其中包括多头注意力机制和前馈神经网络(MLP)。该块首先对输入进行层归一化,然后通过多头注意力机制处理输入,接着再进行层归一化,并通过一个包含快速GELU激活函数的前馈神经网络处理。最终将这些处理后的结果与输入进行残差连接并返回,增强了模型的表示能力和训练稳定性。
class ResidualAttentionBlock(nn.Module):
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, d_model * 4)),
("gelu", QuickGELU()),
("c_proj", nn.Linear(d_model * 4, d_model))
]))
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask
def attention(self, x: torch.Tensor):
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
(4)类Transformer定义了一个Transformer模块,包括多个残差注意力块(Residual Attention Block)。每个残差注意力块包括多头注意力机制和前馈神经网络。该模块初始化时,指定的层数、宽度和头数用于构建这些残差注意力块。在前向传递中,输入将依次通过这些残差注意力块进行处理,最终返回处理后的结果。
class Transformer(nn.Module):
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
def forward(self, x: torch.Tensor):
return self.resblocks(x)
(5)类CLIP定义了一个CLIP模型,用于将视觉和文本信息编码为嵌入向量。模型包括视觉模型和文本模型,使用了Transformer机制来处理文本序列,并通过注意力机制进行处理。模型初始化时设置了各种参数和嵌入层,在前向传递过程中分别编码图像和文本,并返回它们的嵌入向量和一个用于缩放对数几率的参数。
class CLIP(nn.Module):
def __init__(self,
embed_dim: int,
# 视觉
vision_width: int,
vision_model: nn.Module,
# 文本
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int,
**kwargs,
):
super().__init__()
self.context_length = context_length
self.vision_width = vision_width
self.visual = vision_model
self.transformer = Transformer(
width=transformer_width,
layers=transformer_layers,
heads=transformer_heads,
attn_mask=self.build_attention_mask(),
)
self.vocab_size = vocab_size
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
self.ln_final = LayerNorm(transformer_width)
self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim))
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.initialize_parameters()
def initialize_parameters(self):
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
attn_std = self.transformer.width ** -0.5
fc_std = (2 * self.transformer.width) ** -0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
nn.init.normal_(self.image_projection, std=self.vision_width ** -0.5)
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
def build_attention_mask(self):
# 延迟创建因果注意掩码,视觉令牌之间的完全注意力
# pytorch 使用加性注意掩码;填充为 -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float("-inf"))
mask.triu_(1) # 清除下对角线
return mask
def encode_image(self, image):
x = self.visual(image)
x = x @ self.image_projection
return x
def encode_text(self, text):
x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x)
# x.shape = [batch_size, n_ctx, transformer.width]
# 从 eot 嵌入中提取特征(eot_token 是每个序列中最大的数)
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return x
def forward(self, image, text):
image_embed = self.encode_image(image)
text_embed = self.encode_text(text)
return {'image_embed': image_embed,
'text_embed': text_embed,
'logit_scale': self.logit_scale.exp()}
(6)类Attention定义了一个注意力机制模块,用于计算输入张量的自注意力。通过查询(query)、键(key)、值(value)向量的线性变换和归一化,模块计算注意力权重并应用于值向量,最终输出经过投影和丢弃层处理的结果。该模块支持多头注意力,并可以选择性地在查询和键上应用LayerNorm以提高训练稳定性。
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_norm=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
# todo: 添加 q 和 k 的归一化以提高训练稳定性
self.q_norm = nn.LayerNorm(head_dim, eps=1e-6) if qk_norm else nn.Identity()
self.k_norm = nn.LayerNorm(head_dim, eps=1e-6) if qk_norm else nn.Identity()
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # 使 torchscript 满意(不能将张量用作元组)
# todo: 对查询和键应用归一化
q = self.q_norm(q)
k = self.k_norm(k)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
(7)类Block定义了一个 Transformer 块,每个块由一个归一化层、一个注意力层、一个 MLP 层(包含隐藏层和激活函数),以及一个可选的随机深度的 drop path 组成。在前向传播中,输入张量先经过注意力层和 MLP 层的处理,并分别添加 skip connection,以提高网络的稳定性和性能。
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_norm=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm,
attn_drop=attn_drop, proj_drop=drop)
# 注意: 对于随机深度的drop path, 我们将看到这是否比dropout更好
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
(8)类VisionTransformer定义了一个支持全局平均池化的视觉 Transformer。如果启用了 qk_norm(query 和 key 的归一化),它会重新初始化 Transformer 的块(blocks),使用提供的参数(如 embed_dim、num_heads、mlp_ratio 等)创建一系列的 Block 对象,并将它们组织成一个顺序容器。
class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
def __init__(self, qk_norm=False, **kwargs):
super(VisionTransformer, self).__init__(**kwargs)
if qk_norm:
del self.blocks
embed_dim = kwargs['embed_dim']
num_heads = kwargs['num_heads']
mlp_ratio = kwargs['mlp_ratio']
qkv_bias = kwargs['qkv_bias']
depth = kwargs['depth']
drop_rate = 0.
attn_drop_rate = 0.
drop_path_rate = 0.
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
norm_layer = partial(nn.LayerNorm, eps=1e-6)
act_layer = nn.GELU
self.blocks = nn.Sequential(*[
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_norm=qk_norm,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
act_layer=act_layer)
for i in range(depth)])
(9)函数vit_small_patch16_224定义了一个小型的 Vision Transformer 模型,其特征包括 16x16 的图像块大小、384 的嵌入维度、12 层深度、12 个注意力头和 4 的 MLP 比例。这个模型使用偏置的 qkv(query、key、value)以及带有 epsilon 值为 1e-6 的 LayerNorm 归一化层。该函数通过将这些参数传递给 VisionTransformer 类并返回实例化的模型来创建和返回 Vision Transformer 模型。
def vit_small_patch16_224(**kwargs):
model = VisionTransformer(
patch_size=16, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model
(10)函数vit_base_patch16_224定义了一个基础版的 Vision Transformer 模型,其特征包括 16x16 的图像块大小、768 的嵌入维度、12 层深度、12 个注意力头和 4 的 MLP 比例。这个模型使用偏置的 qkv(query、key、value)以及带有 epsilon 值为 1e-6 的 LayerNorm 归一化层。该函数通过将这些参数传递给 VisionTransformer 类并返回实例化的模型来创建和返回 Vision Transformer 模型。
def vit_base_patch16_224(**kwargs):
model = VisionTransformer(
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model
(11)函数vit_large_patch16_224定义了一个大型的 Vision Transformer 模型,其特征包括 16x16 的图像块大小、1024 的嵌入维度、24 层深度、16 个注意力头和 4 的 MLP 比例。这个模型使用偏置的 qkv(query、key、value)以及带有 epsilon 值为 1e-6 的 LayerNorm 归一化层。该函数通过将这些参数传递给 VisionTransformer 类并返回实例化的模型来创建和返回 Vision Transformer 模型。
def vit_large_patch16_224(**kwargs):
model = VisionTransformer(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model
(12)函数CLIP_VITS16定义了一个 CLIP 模型,使用了一个小型的 Vision Transformer (ViT) 模型作为视觉编码器。具体来说,函数首先创建一个具有 qk_norm 的 vit_small_patch16_224 模型,设置 num_classes=0。然后,这个 ViT 模型被传递给 CLIP 类进行实例化,并配置了嵌入维度、视觉宽度、上下文长度、词汇表大小、Transformer 的宽度、头数和层数等参数。函数最后返回这个配置好的 CLIP 模型。
def CLIP_VITS16(**kwargs):
vision_model = vit_small_patch16_224(qk_norm=True, num_classes=0)
model = CLIP(embed_dim=512, vision_width=384, vision_model=vision_model, context_length=77, vocab_size=49408,
transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs)
return model
(13)函数CLIP_VITB16定义了一个 CLIP 模型,使用了一个基础规模的 Vision Transformer (ViT) 模型作为视觉编码器。具体来说,函数CLIP_VITB16首先创建一个具有 qk_norm 的 vit_base_patch16_224 模型,设置 num_classes=0。然后,这个 ViT 模型被传递给 CLIP 类进行实例化,并配置了嵌入维度、视觉宽度、上下文长度、词汇表大小、Transformer 的宽度、头数和层数等参数。函数最后返回这个配置好的 CLIP 模型。
def CLIP_VITB16(**kwargs):
vision_model = vit_base_patch16_224(qk_norm=True, num_classes=0)
model = CLIP(embed_dim=512, vision_width=768, vision_model=vision_model, context_length=77, vocab_size=49408,
transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs)
return model
(14)函数CLIP_VITL16定义了一个 CLIP 模型,使用了一个大规模的 Vision Transformer (ViT) 模型作为视觉编码器。具体来说,函数CLIP_VITL16首先创建了一个具有 qk_norm 的 vit_large_patch16_224 模型,设置 num_classes=0。然后,这个 ViT 模型被传递给 CLIP 类进行实例化,并配置了嵌入维度、视觉宽度、上下文长度、词汇表大小、Transformer 的宽度、头数和层数等参数。函数最后返回这个配置好的 CLIP 模型。
def CLIP_VITL16(**kwargs):
vision_model = vit_large_patch16_224(qk_norm=True, num_classes=0)
model = CLIP(embed_dim=512, vision_width=1024, vision_model=vision_model, context_length=77, vocab_size=49408,
transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs)
return model