DSSM Two-Tower Model in Practice

This article walks through using the DSSM (Deep Structured Semantic Model) two-tower architecture for vector-based recall: a pretrained BERT model serves as the encoder, and the cosine similarity between the query and doc vectors is the matching score. Training uses the AdamW optimizer with a linear learning-rate schedule, and evaluation reports accuracy, precision, recall, and F1. In the online serving environment, a vector retrieval index built with a tool such as FAISS enables efficient query recall.

I have recently been studying vector recall, which almost inevitably involves the DSSM two-tower model. The idea behind the two towers is simple: there is a query tower on the query side and a doc tower on the doc side. Each tower encodes its input upward into a latent vector, and the cosine similarity between the two vectors measures how close the query and doc are; training pushes the vectors of matching query/doc pairs closer together.
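For reference, the matching score is the standard cosine similarity between the two tower outputs q and d: cos(q, d) = (q · d) / (||q|| · ||d||), which lies in [-1, 1]; training pulls the score of matching query/doc pairs toward 1.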
The overall model structure is shown below. In the online serving setup, the doc latent vectors are computed and saved offline, and a vector retrieval index is built from them with a tool such as FAISS. At query time, the deployed query tower encodes the user's query into a latent vector in real time, and that vector is used to look up similar doc vectors in the FAISS index, achieving semantic recall for the query.
[Figure: overall structure of the DSSM two-tower model]
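
As a minimal sketch of this offline/online split (the flat inner-product index, the 768-dimensional embeddings, and the file name are illustrative assumptions, not from the original post):

import numpy as np
import faiss  # pip install faiss-cpu

# Offline: doc-tower embeddings precomputed and saved, shape (num_docs, 768).
doc_embeddings = np.load('doc_embeddings.npy').astype('float32')
faiss.normalize_L2(doc_embeddings)   # L2-normalize so inner product equals cosine similarity
index = faiss.IndexFlatIP(doc_embeddings.shape[1])
index.add(doc_embeddings)

# Online: the deployed query tower turns the user query into a vector in real time.
query_embedding = np.random.rand(1, 768).astype('float32')  # stand-in for the query-tower output
faiss.normalize_L2(query_embedding)
scores, doc_ids = index.search(query_embedding, 10)          # ids of the 10 most similar docs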

Model code: the encoder can be any NLP encoding model; for example, it can be implemented with a pretrained BERT.

import torch
import torch.nn as nn


class DSSM(nn.Module):
    """
    DSSM (Deep Structured Semantic Model). Scores query/doc pairs with cosine
    similarity: slightly lower accuracy, but fast to compute.
    Paper Reference: https://posenhuang.github.io/papers/cikm2013_DSSM_fullversion.pdf
    """

    def __init__(self, encoder, dropout=None):
        """
        Init func.

        Args:
            encoder (transformers.AutoModel): backbone encoder, e.g. ERNIE 3.0 or BERT.
            dropout (float): dropout rate; defaults to 0.1.
        """
        super().__init__()
        self.encoder = encoder
        self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass: encodes a batch of single sentences and returns their embeddings.

        Args:
            input_ids (torch.LongTensor): (batch, seq_len)
            token_type_ids (torch.LongTensor): (batch, seq_len)
            attention_mask (torch.LongTensor): (batch, seq_len)

        Returns:
            torch.Tensor: embedding -> (batch, hidden_size)
        """
        embedding = self.encoder(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask
        )["pooler_output"]                                  # (batch, hidden_size)
        return embedding

    def get_similarity(
        self,
        query_input_ids: torch.Tensor,
        query_token_type_ids: torch.Tensor,
        query_attention_mask: torch.Tensor,
        doc_input_ids: torch.Tensor,
        doc_token_type_ids: torch.Tensor,
        doc_attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Takes the tokenized query and doc and returns the cosine similarity
        between the two embeddings.

        Args:
            query_input_ids (torch.LongTensor): (batch, seq_len)
            query_token_type_ids (torch.LongTensor): (batch, seq_len)
            query_attention_mask (torch.LongTensor): (batch, seq_len)
            doc_input_ids (torch.LongTensor): (batch, seq_len)
            doc_token_type_ids (torch.LongTensor): (batch, seq_len)
            doc_attention_mask (torch.LongTensor): (batch, seq_len)

        Returns:
            torch.Tensor: (batch,)
        """
        query_embedding = self.encoder(
            input_ids=query_input_ids,
            token_type_ids=query_token_type_ids,
            attention_mask=query_attention_mask
        )["pooler_output"]                                  # (batch, hidden_size)
        query_embedding = self.dropout(query_embedding)

        doc_embedding = self.encoder(
            input_ids=doc_input_ids,
            token_type_ids=doc_token_type_ids,
            attention_mask=doc_attention_mask
        )["pooler_output"]                                  # (batch, hidden_size)
        doc_embedding = self.dropout(doc_embedding)

        similarity = nn.functional.cosine_similarity(query_embedding, doc_embedding)
        return similarity
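
A minimal usage sketch of the DSSM class (the 'bert-base-chinese' checkpoint and the sample sentence are illustrative assumptions):

from transformers import AutoTokenizer, AutoModel
import torch

tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese')
encoder = AutoModel.from_pretrained('bert-base-chinese')
model = DSSM(encoder)
model.eval()

# Encode a single sentence and fetch its (batch, hidden_size) embedding.
inputs = tokenizer('如何修改头像', return_tensors='pt')
with torch.no_grad():
    embedding = model(input_ids=inputs['input_ids'],
                      token_type_ids=inputs['token_type_ids'],
                      attention_mask=inputs['attention_mask'])
print(embedding.shape)  # torch.Size([1, 768])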

Training code:

#!/usr/bin/env python3

import os
import time
import argparse
from functools import partial

import torch
from torch.utils.data import DataLoader
import evaluate
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel, default_data_collator, get_scheduler

from model import DSSM
from utils import convert_dssm_example
from iTrainingLogger import iSummaryWriter


parser = argparse.ArgumentParser()
parser.add_argument("--model", default='bert-base-chinese', type=str, help="backbone of encoder.")
parser.add_argument("--train_path", default='./data/dssm_data/train.csv', type=str, help="The path of train set.")
parser.add_argument("--dev_path", default='./data/dssm_data/train.csv', type=str, help="The path of dev set.")
parser.add_argument("--save_dir", default="./checkpoints", type=str, required=False, help="The output directory where the model predictions and checkpoints will be written.")
parser.add_argument("--max_seq_len", default=512, type=int,help="The maximum total input sequence length after tokenization. Sequences longer "
    "than this will be truncated, sequences shorter will be padded.", )
parser.add_argument("--batch_size", default=16, type=int, help="Batch size per GPU/CPU for training.", )
parser.add_argument("--learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
parser.add_argument("--num_train_epochs", default=10, type=int, help="Total number of training epochs to perform.")
parser.add_argument("--warmup_ratio", default=0.06, type=float, help="Linear warmup over warmup_ratio * total_steps.")
parser.add_argument("--valid_steps", default=200, type=int, required=False, help="evaluate frequecny.")
parser.add_argument("--logging_steps", default=10, type=int, help="log interval.")
parser.add_argument('--device', default="cpu", help="Select which device to train the model on, defaults to cpu.")
parser.add_argument("--img_log_dir", default='logs', type=str, help="Logging image path.")
parser.add_argument("--img_log_name", default='Model Performance', type=str, help="Logging image file name.")
args = parser.parse_args()

writer = iSummaryWriter(log_path=args.img_log_dir, log_name=args.img_log_name)


def evaluate_model(model, metric, data_loader, global_step):
    """
    Evaluates the current model on the dev set.

    Args:
        model: the current model
        metric: the evaluation metric object
        data_loader: dataloader of the dev set
        global_step: current training step
    """
    model.eval()
    with torch.no_grad():
        for step, batch in enumerate(data_loader):
            logits = model.get_similarity(query_input_ids=batch['query_input_ids'].to(args.device),
                            query_token_type_ids=batch['query_token_type_ids'].to(args.device),
                            query_attention_mask=batch['query_attention_mask'].to(args.device),
                            doc_input_ids=batch['doc_input_ids'].to(args.device),
                            doc_token_type_ids=batch['doc_token_type_ids'].to(args.device),
                            doc_attention_mask=batch['doc_attention_mask'].to(args.device))
            logits[logits>=0.5] = 1
            logits[logits<0.5] = 0
            metric.add_batch(predictions=logits.to(torch.int), references=batch["labels"])
    eval_metric = metric.compute()
    model.train()
    return eval_metric['accuracy'], eval_metric['precision'], eval_metric['recall'], eval_metric['f1']


def train():
    encoder = AutoModel.from_pretrained(args.model)
    model = DSSM(encoder)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    dataset = load_dataset('text', data_files={'train': args.train_path,
                                                'dev': args.dev_path})    
    print(dataset)
    convert_func = partial(convert_dssm_example, tokenizer=tokenizer, max_seq_len=args.max_seq_len)
    dataset = dataset.map(convert_func, batched=True)
    
    train_dataset = dataset["train"]
    eval_dataset = dataset["dev"]
    train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=args.batch_size)
    eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=args.batch_size)

    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    model.to(args.device)

    # Compute the total number of training steps from the number of epochs,
    # so the scheduler can adjust the learning rate dynamically.
    num_update_steps_per_epoch = len(train_dataloader)
    max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    warm_steps = int(args.warmup_ratio * max_train_steps)
    lr_scheduler = get_scheduler(
        name='linear',
        optimizer=optimizer,
        num_warmup_steps=warm_steps,
        num_training_steps=max_train_steps,
    )

    loss_list = []
    metric = evaluate.combine(["accuracy", "f1", "precision", "recall"])
    criterion = torch.nn.MSELoss()
    tic_train = time.time()
    global_step, best_f1 = 0, 0
    for epoch in range(1, args.num_train_epochs+1):
        for batch in train_dataloader:
            logits = model.get_similarity(query_input_ids=batch['query_input_ids'].to(args.device),
                            query_token_type_ids=batch['query_token_type_ids'].to(args.device),
                            query_attention_mask=batch['query_attention_mask'].to(args.device),
                            doc_input_ids=batch['doc_input_ids'].to(args.device),
                            doc_token_type_ids=batch['doc_token_type_ids'].to(args.device),
                            doc_attention_mask=batch['doc_attention_mask'].to(args.device))
            labels = batch['labels'].to(torch.float).to(args.device)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            loss_list.append(float(loss.cpu().detach()))
            
            if global_step % args.logging_steps == 0:
                time_diff = time.time() - tic_train
                loss_avg = sum(loss_list) / len(loss_list)
                writer.add_scalar('train/train_loss', loss_avg, global_step)
                print("global step %d, epoch: %d, loss: %.5f, speed: %.2f step/s"
                        % (global_step, epoch, loss_avg, args.logging_steps / time_diff))
                tic_train = time.time()

            if global_step % args.valid_steps == 0:
                cur_save_dir = os.path.join(args.save_dir, "model_%d" % global_step)
                if not os.path.exists(cur_save_dir):
                    os.makedirs(cur_save_dir)
                torch.save(model, os.path.join(cur_save_dir, 'model.pt'))
                tokenizer.save_pretrained(cur_save_dir)

                acc, precision, recall, f1 = evaluate_model(model, metric, eval_dataloader, global_step)
                writer.add_scalar('eval/accuracy', acc, global_step)
                writer.add_scalar('eval/precision', precision, global_step)
                writer.add_scalar('eval/recall', recall, global_step)
                writer.add_scalar('eval/f1', f1, global_step)
                writer.record()

                print("Evaluation precision: %.5f, recall: %.5f, F1: %.5f" % (precision, recall, f1))
                if f1 > best_f1:
                    print(
                        f"best F1 performence has been updated: {best_f1:.5f} --> {f1:.5f}"
                    )
                    best_f1 = f1
                    cur_save_dir = os.path.join(args.save_dir, "model_best")
                    if not os.path.exists(cur_save_dir):
                        os.makedirs(cur_save_dir)
                    torch.save(model, os.path.join(cur_save_dir, 'model.pt'))
                    tokenizer.save_pretrained(cur_save_dir)
                tic_train = time.time()
            
            global_step += 1


if __name__ == '__main__':
    train()
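
The convert_dssm_example helper imported from utils is not shown in the post. The sketch below is an assumed implementation, reconstructed from how the training loop consumes its batches; the tab-separated "query\tdoc\tlabel" line format is a guess:

def convert_dssm_example(examples, tokenizer, max_seq_len):
    """Converts raw text lines into the tensors the DSSM training loop expects.

    Assumes each line of the dataset looks like: "query\tdoc\tlabel".
    """
    tokenized_output = {
        'query_input_ids': [], 'query_token_type_ids': [], 'query_attention_mask': [],
        'doc_input_ids': [], 'doc_token_type_ids': [], 'doc_attention_mask': [],
        'labels': []
    }
    for line in examples['text']:
        query, doc, label = line.strip().split('\t')
        # Tokenize the query and the doc independently, one tower each.
        for prefix, text in (('query', query), ('doc', doc)):
            encoded = tokenizer(text, truncation=True, max_length=max_seq_len,
                                padding='max_length')
            tokenized_output[f'{prefix}_input_ids'].append(encoded['input_ids'])
            tokenized_output[f'{prefix}_token_type_ids'].append(encoded['token_type_ids'])
            tokenized_output[f'{prefix}_attention_mask'].append(encoded['attention_mask'])
        tokenized_output['labels'].append(int(label))
    return tokenized_output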
