(7-3-01)CLIP 模型的增强训练与评估:项目介绍+定义数据集+创建模型

7.3  CLIP 模型的增强训练与评估

本项目是一个基于视觉感知预训练模型(如 LaCLIP/CLIP)和 LLAMA 的图像描述生成系统,包括模型训练、零样本评估以及对生成文本的重写功能,通过整合多种模型和技术,实现了对图像内容的理解和生成多样化描述的功能,为图像与文本之间的关联性建模提供了一个端到端的解决方案。

实例7-10使用TFT处理CSV文件中的缺失值(源码路径:codes/7/que.py

7.3.1  项目介绍

本项目实现了对预训练模型 LaCLIP/CLIP 的训练工作,以及 LLAMA 模型的训练和过程。通过这些模型,实现了对图像和文本之间语义关联的建模,可以进行零样本评估,生成了多样化和丰富化的图像描述,并提供了文本生成的多样化风格选项,拓展了文生图大模型在图像与文本处理领域的应用和功能。

本项目的主要功能模块如下所示:

1. 训练模块

  1. 包括模型训练的主要逻辑,包括模型初始化、损失函数定义、优化器设置、数据加载等。
  2. 支持分布式训练,使用了分布式数据并行和混合精度训练(AMP)技术。
  3. 提供了模型训练过程中的日志记录和模型参数保存功能。

2. 零样本评估模块(eval_zeroshot_imagenet.py)

  1. 实现了对图像零样本分类的评估功能,使用预训练的模型对图像进行分类,同时支持多种文本模板。
  2. 包括数据加载、模型预测和准确率计算等功能。

3. LLAMA 模型加载与应用模块

  1. 提供了 LLAMA 模型的加载和应用功能。
  2. 支持并行加载多个模型,用于多 GPU 或分布式环境下的推理。
  3. 实现了对给定文本的生成,支持不同的样式选择和文本风格的扩展。

4. 模型并行设置模块

实现了模型并行的设置,用于在多 GPU 或分布式环境下进行模型的初始化和加载,包括了分布式训练的初始化、设备分配等功能。

5. 辅助功能模块

  1. 包含了一些辅助函数,如平均计量器(AverageMeter)、进度条显示器(ProgressMeter)等,用于训练过程中的统计和显示。
  2. 实现了一些工具函数,用于模型的初始化、数据加载等。

上述功能模块共同构成了项目的主要功能,包括模型训练、零样本评估和 LLAMA 模型的加载与应用等。

7.3.2  定义数据集

文件data.py定义了两个数据集类 CsvDatasetAugCap 和 CsvDataset,分别用于从 CSV 文件中加载图像及其对应的文本描述。同时提供了函数 get_csv_dataset 和 get_data 用于获取训练和验证数据集。

class CsvDatasetAugCap(Dataset):

    def __init__(self, input_filename, transforms, tokenizer=None, root=None, augmented_caption_filelist=None):

        logging.debug(f'Loading csv data from {input_filename}.')

        self.images = []

        self.captions = []

        self.root = root

        assert input_filename.endswith('.csv')

        assert augmented_caption_filelist is not None, 'augmented_caption_filelist is None, use csvdataset instead'

        num_augcap = len(augmented_caption_filelist)

        augmented_captions = []

        file_length = []

        for f in augmented_caption_filelist:

            with open(f, 'r') as file:

                cur_captions = file.readlines()

                file_length.append(len(cur_captions))

                augmented_captions.append(cur_captions)

        assert len(augmented_captions) == num_augcap, 'number of augmented captions is not equal to num_augcap'

        for i in range(num_augcap):

            assert file_length[i] == file_length[0], 'number of captions in each file is not the same'

        num_samples = file_length[0]

        with open(input_filename, 'r') as csv_file:

            csv_reader = csv.reader(csv_file)

            row_index = 0

            for row in tqdm(csv_reader):

                image = row[0]

                prompt = row[1]

                if image.endswith(('.png', '.jpg', '.jpeg')):

                    image_path = os.path.join(self.root, image)

                    self.images.append(image_path)

                    if row_index < num_samples:

                        self.captions.append([prompt])

                        for augcap_idx in range(num_augcap):

                            self.captions[row_index].append(augmented_captions[augcap_idx][row_index].replace('\n',''))

                        assert len(self.captions[row_index]) == num_augcap + 1, 'number of captions is not equal to num_augcap + 1'

                    row_index += 1

            assert row_index % num_samples == 0, 'number of samples in csv is not equal to num_samples in new caption'

        self.num_samples = num_samples

        self.transforms = transforms

        logging.debug('Done loading data.')

        self.tokenizer = tokenizer

    def __len__(self):

        return len(self.images)

    def __getitem__(self, idx):

        images = self.transforms(Image.open(str(self.images[idx])))

        caption_list = self.captions[idx%self.num_samples]

        caption = random.choice(caption_list)

        if len(caption.split(' ')) < 2:

            caption = caption_list[0]

        texts = caption

        texts = self.tokenizer(str(texts))

        return images, texts

class CsvDataset(Dataset):

    def __init__(self, input_filename, transforms, tokenizer=None, root=None):

        logging.debug(f'Loading csv data from {input_filename}.')

        self.images = []

        self.captions = []

        self.root = root

        assert input_filename.endswith('.csv')

        with open(input_filename, 'r') as csv_file:

            csv_reader = csv.reader(csv_file)

            for row in tqdm(csv_reader):

                image = row[0]

                prompt = row[1]

                if image.endswith(('.png', '.jpg', '.jpeg')):

                    image_path = os.path.join(self.root, image)

                    self.images.append(image_path)

                    self.captions.append(prompt)

        self.transforms = transforms

        logging.debug('Done loading data.')

        self.tokenizer = tokenizer

    def __len__(self):

        return len(self.captions)

    def __getitem__(self, idx):

        images = self.transforms(Image.open(str(self.images[idx])))

        texts = self.tokenizer(str(self.captions[idx]))

        return images, texts

@dataclass

class DataInfo:

    dataloader: DataLoader

    sampler: DistributedSampler = None

    def set_epoch(self, epoch):

        if self.sampler is not None and isinstance(self.sampler, DistributedSampler):

            self.sampler.set_epoch(epoch)

def get_csv_dataset(args, preprocess_fn, is_train, tokenizer=None, aug_text=False):

    input_filename = args.train_data if is_train else args.val_data

    assert input_filename

    if args.aug_text:

        augmented_caption_filelist = args.augmented_caption_filelist

        dataset = CsvDatasetAugCap(

            input_filename,

            preprocess_fn,

            root=args.root,

            tokenizer=tokenizer,

            augmented_caption_filelist=augmented_caption_filelist,

        )

    else:

        dataset = CsvDataset(

            input_filename,

            preprocess_fn,

            root=args.root,

            tokenizer=tokenizer

        )

    num_samples = len(dataset)

    sampler = DistributedSampler(dataset) if args.distributed and is_train else None

    shuffle = is_train and sampler is None

    dataloader = DataLoader(

        dataset,

        batch_size=args.batch_size,

        shuffle=shuffle,

        num_workers=args.workers,

        pin_memory=True,

        sampler=sampler,

        drop_last=is_train,

    )

    dataloader.num_samples = num_samples

    dataloader.num_batches = len(dataloader)

    return DataInfo(dataloader, sampler)

def get_data(args, preprocess_fns, tokenizer=None):

    preprocess_train, preprocess_val = preprocess_fns

    data = {"train": get_csv_dataset(args, preprocess_train, is_train=True, tokenizer=tokenizer)}

    return data

对上述代码的具体说明如下所示:

(1)类CsvDatasetAugCap的功能是从 CSV 文件中加载图像及其对应的原始文本描述和增强文本描述。它会将原始文本描述和多个增强文本描述合并成一个列表,支持随机选择其中一个文本描述作为输出。

(2)类CsvDataset的功能是从 CSV 文件中加载图像及其对应的单个原始文本描述。它用于加载普通的图像数据集,每个图像只有一个文本描述。

(3)类DataInfo是一个数据类,用于存储数据加载器和分布式采样器。它提供了设置 epoch 的方法,用于分布式训练时更新采样器的 epoch。

(4)函数get_csv_dataset的功能是根据参数返回相应的数据集对象,支持训练和验证数据集的加载,并根据是否启用增强文本描述来选择使用 CsvDatasetAugCap 或 CsvDataset 类。它还设置了数据加载器,并根据是否分布式训练选择是否使用分布式采样器。

(5)函数get_data的功能是获取训练数据集,并返回一个字典,包含训练数据集的信息。

7.3.3  创建模型

文件models.py创建了“视觉-文本检索(CLIP)”模型的各个组件和不同规模变体的构建方法,包括自注意力机制、残差注意力块、Transformer 架构以及 Vision Transformer,并提供了不同规模的 CLIP 模型构建方法,以便在视觉和文本之间进行编码和检索。文件models.py的具体实现流程如下所示。

(1)实现自定义的类 LayerNorm,此类继承自 PyTorch 中的 nn.LayerNorm 类,用于处理 fp16 数据。在 forward 方法中,将输入张量转换为 float32 类型后再调用父类的 forward 方法进行处理,并最终将结果转换回原始的数据类型。

class LayerNorm(nn.LayerNorm):
    """子类化 torch 的 LayerNorm 以处理 fp16。"""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)

(2)在类QuickGELU中定义了一个快速的 GELU(Gaussian Error Linear Unit)激活函数,对输入张量 x 进行 GELU 操作,即,用于神经网络的非线性变换。

class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)

(3)类ResidualAttentionBlock定义了一个残差注意力块(Residual Attention Block),其中包括多头注意力机制和前馈神经网络(MLP)。该块首先对输入进行层归一化,然后通过多头注意力机制处理输入,接着再进行层归一化,并通过一个包含快速GELU激活函数的前馈神经网络处理。最终将这些处理后的结果与输入进行残差连接并返回,增强了模型的表示能力和训练稳定性。

class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

(4)类Transformer定义了一个Transformer模块,包括多个残差注意力块(Residual Attention Block)。每个残差注意力块包括多头注意力机制和前馈神经网络。该模块初始化时,指定的层数、宽度和头数用于构建这些残差注意力块。在前向传递中,输入将依次通过这些残差注意力块进行处理,最终返回处理后的结果。

class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)

(5)类CLIP定义了一个CLIP模型,用于将视觉和文本信息编码为嵌入向量。模型包括视觉模型和文本模型,使用了Transformer机制来处理文本序列,并通过注意力机制进行处理。模型初始化时设置了各种参数和嵌入层,在前向传递过程中分别编码图像和文本,并返回它们的嵌入向量和一个用于缩放对数几率的参数。

class CLIP(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # 视觉
                 vision_width: int,
                 vision_model: nn.Module,
                 # 文本
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int,
                 **kwargs,
                 ):
        super().__init__()

        self.context_length = context_length
        self.vision_width = vision_width

        self.visual = vision_model

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(),
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim))
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        nn.init.normal_(self.image_projection, std=self.vision_width ** -0.5)
        nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # 延迟创建因果注意掩码,视觉令牌之间的完全注意力
        # pytorch 使用加性注意掩码;填充为 -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # 清除下对角线
        return mask

    def encode_image(self, image):
        x = self.visual(image)
        x = x @ self.image_projection

        return x

    def encode_text(self, text):
        x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # 从 eot 嵌入中提取特征(eot_token 是每个序列中最大的数)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

    def forward(self, image, text):
        image_embed = self.encode_image(image)
        text_embed = self.encode_text(text)

        return {'image_embed': image_embed,
                'text_embed': text_embed,
                'logit_scale': self.logit_scale.exp()}

(6)类Attention定义了一个注意力机制模块,用于计算输入张量的自注意力。通过查询(query)、键(key)、值(value)向量的线性变换和归一化,模块计算注意力权重并应用于值向量,最终输出经过投影和丢弃层处理的结果。该模块支持多头注意力,并可以选择性地在查询和键上应用LayerNorm以提高训练稳定性。

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_norm=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # todo: 添加 q 和 k 的归一化以提高训练稳定性
        self.q_norm = nn.LayerNorm(head_dim, eps=1e-6) if qk_norm else nn.Identity()
        self.k_norm = nn.LayerNorm(head_dim, eps=1e-6) if qk_norm else nn.Identity()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)   # 使 torchscript 满意(不能将张量用作元组)

        # todo: 对查询和键应用归一化
        q = self.q_norm(q)
        k = self.k_norm(k)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

(7)类Block定义了一个 Transformer 块,每个块由一个归一化层、一个注意力层、一个 MLP 层(包含隐藏层和激活函数),以及一个可选的随机深度的 drop path 组成。在前向传播中,输入张量先经过注意力层和 MLP 层的处理,并分别添加 skip connection,以提高网络的稳定性和性能。

class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_norm=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm,
                              attn_drop=attn_drop, proj_drop=drop)
        # 注意: 对于随机深度的drop path, 我们将看到这是否比dropout更好
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

(8)类VisionTransformer定义了一个支持全局平均池化的视觉 Transformer。如果启用了 qk_norm(query 和 key 的归一化),它会重新初始化 Transformer 的块(blocks),使用提供的参数(如 embed_dim、num_heads、mlp_ratio 等)创建一系列的 Block 对象,并将它们组织成一个顺序容器。

class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
    def __init__(self, qk_norm=False, **kwargs):
        super(VisionTransformer, self).__init__(**kwargs)
        if qk_norm:
            del self.blocks
            embed_dim = kwargs['embed_dim']
            num_heads = kwargs['num_heads']
            mlp_ratio = kwargs['mlp_ratio']
            qkv_bias = kwargs['qkv_bias']
            depth = kwargs['depth']
            drop_rate = 0.
            attn_drop_rate = 0.
            drop_path_rate = 0.
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
            act_layer = nn.GELU
            self.blocks = nn.Sequential(*[
                Block(
                    dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_norm=qk_norm,
                    drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                    act_layer=act_layer)
                for i in range(depth)])

(9)函数vit_small_patch16_224定义了一个小型的 Vision Transformer 模型,其特征包括 16x16 的图像块大小、384 的嵌入维度、12 层深度、12 个注意力头和 4 的 MLP 比例。这个模型使用偏置的 qkv(query、key、value)以及带有 epsilon 值为 1e-6 的 LayerNorm 归一化层。该函数通过将这些参数传递给 VisionTransformer 类并返回实例化的模型来创建和返回 Vision Transformer 模型。

def vit_small_patch16_224(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model

(10)函数vit_base_patch16_224定义了一个基础版的 Vision Transformer 模型,其特征包括 16x16 的图像块大小、768 的嵌入维度、12 层深度、12 个注意力头和 4 的 MLP 比例。这个模型使用偏置的 qkv(query、key、value)以及带有 epsilon 值为 1e-6 的 LayerNorm 归一化层。该函数通过将这些参数传递给 VisionTransformer 类并返回实例化的模型来创建和返回 Vision Transformer 模型。

def vit_base_patch16_224(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model

(11)函数vit_large_patch16_224定义了一个大型的 Vision Transformer 模型,其特征包括 16x16 的图像块大小、1024 的嵌入维度、24 层深度、16 个注意力头和 4 的 MLP 比例。这个模型使用偏置的 qkv(query、key、value)以及带有 epsilon 值为 1e-6 的 LayerNorm 归一化层。该函数通过将这些参数传递给 VisionTransformer 类并返回实例化的模型来创建和返回 Vision Transformer 模型。

def vit_large_patch16_224(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model

(12)函数CLIP_VITS16定义了一个 CLIP 模型,使用了一个小型的 Vision Transformer (ViT) 模型作为视觉编码器。具体来说,函数首先创建一个具有 qk_norm 的 vit_small_patch16_224 模型,设置 num_classes=0。然后,这个 ViT 模型被传递给 CLIP 类进行实例化,并配置了嵌入维度、视觉宽度、上下文长度、词汇表大小、Transformer 的宽度、头数和层数等参数。函数最后返回这个配置好的 CLIP 模型。

def CLIP_VITS16(**kwargs):
    vision_model = vit_small_patch16_224(qk_norm=True, num_classes=0)
    model = CLIP(embed_dim=512, vision_width=384, vision_model=vision_model, context_length=77, vocab_size=49408,
        transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs)

    return model

(13)函数CLIP_VITB16定义了一个 CLIP 模型,使用了一个基础规模的 Vision Transformer (ViT) 模型作为视觉编码器。具体来说,函数CLIP_VITB16首先创建一个具有 qk_norm 的 vit_base_patch16_224 模型,设置 num_classes=0。然后,这个 ViT 模型被传递给 CLIP 类进行实例化,并配置了嵌入维度、视觉宽度、上下文长度、词汇表大小、Transformer 的宽度、头数和层数等参数。函数最后返回这个配置好的 CLIP 模型。

def CLIP_VITB16(**kwargs):
    vision_model = vit_base_patch16_224(qk_norm=True, num_classes=0)
    model = CLIP(embed_dim=512, vision_width=768, vision_model=vision_model, context_length=77, vocab_size=49408,
        transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs)

    return model

(14)函数CLIP_VITL16定义了一个 CLIP 模型,使用了一个大规模的 Vision Transformer (ViT) 模型作为视觉编码器。具体来说,函数CLIP_VITL16首先创建了一个具有 qk_norm 的 vit_large_patch16_224 模型,设置 num_classes=0。然后,这个 ViT 模型被传递给 CLIP 类进行实例化,并配置了嵌入维度、视觉宽度、上下文长度、词汇表大小、Transformer 的宽度、头数和层数等参数。函数最后返回这个配置好的 CLIP 模型。

def CLIP_VITL16(**kwargs):
    vision_model = vit_large_patch16_224(qk_norm=True, num_classes=0)
    model = CLIP(embed_dim=512, vision_width=1024, vision_model=vision_model, context_length=77, vocab_size=49408,
        transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs)

    return model

以下是使用bert-base-chinese训练实体识别模型的代码示例: 1. 准备数据集 首先,需要准备实体识别任务的数据集数据集应该包含标记好的实体标签,例如“B-PER”和“I-PER”表示人名实体的开始和内部标记。 2. 定义模型 定义一个使用bert-base-chinese预训练模型的实体识别模型,可以使用PyTorch和Transformers库: ```python import torch from transformers import BertForTokenClassification, BertTokenizer model = BertForTokenClassification.from_pretrained('bert-base-chinese', num_labels=5) tokenizer = BertTokenizer.from_pretrained('bert-base-chinese') ``` 在这里,我们使用“num_labels”参数指定模型输出的标签数,即实体类别数。 3. 定义训练循环 接下来,我们定义训练循环来训练我们的模型: ```python from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from transformers import AdamW, get_linear_schedule_with_warmup device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) train_data = ... # 加载训练数据集 train_sampler = RandomSampler(train_data) train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=32) optimizer = AdamW(model.parameters(), lr=5e-5, eps=1e-8) total_steps = len(train_dataloader) * 3 scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps) for epoch in range(3): for batch in train_dataloader: model.train() batch = tuple(t.to(device) for t in batch) inputs = {'input_ids': batch[0], 'attention_mask': batch[1], 'labels': batch[3]} outputs = model(**inputs) loss = outputs[0] loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() scheduler.step() optimizer.zero_grad() ``` 在这里,我们使用随机采样器将训练数据集的数据随机分成小批次。我们使用AdamW优化器和带有线性学习率调度程序的预热来训练模型。在每个时期内,我们遍历每个小批次并计算损失和梯度,然后更新模型参数。 4. 评估模型 训练完成后,我们可以使用测试数据集评估模型的性能: ```python test_data = ... # 加载测试数据集 test_sampler = SequentialSampler(test_data) test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=32) model.eval() predictions = [] true_labels = [] for batch in test_dataloader: batch = tuple(t.to(device) for t in batch) inputs = {'input_ids': batch[0], 'attention_mask': batch[1], 'labels': batch[3]} with torch.no_grad(): outputs = model(**inputs) logits = outputs[1].detach().cpu().numpy() label_ids = inputs['labels'].cpu().numpy() predictions.extend([list(p) for p in np.argmax(logits, axis=2)]) true_labels.extend(label_ids) from sklearn.metrics import f1_score f1_score = f1_score(true_labels, predictions, average='weighted') print("F1 score:", f1_score) ``` 在这里,我们将测试数据集分成小批次,并使用模型的“eval”方法来计算预测标签。然后,我们使用f1_score度量来评估模型性能。 这就是使用bert-base-chinese训练实体识别模型的代码示例。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

码农三叔

感谢鼓励

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值