Contrastive Search Decoding——一种对比搜索解码文本生成算法

目录

一、contrastive search decoding

二、代码实现理解和实验

1、代码走读

2、生成效果展示

3、方案的缺陷


最近在做文本生成相关的任务,调研的时候刷到一篇文本生成的论文:

《A Contrastive Framework for Neural Text Generation》

它认为GPT2生成模型再生成的token具有各异向性,使得token之间的相似性非常接近没有很好的区分度,最后解码的时候造成了文本重复——text degeneration;因此论文提出了一种新的训练策略(SimCTG)+解码算法(contrastive search),在多语言任务和实际的工业场景中进行人工评测,很显著的提升了文本生成的质量。关于该论文提出的text degeneration的原因知乎上有很多大佬和论文作者进行讨论和剖析,最后得出的结论是text degeneration的原因并不是SIMCTG提出的Contrastive Training,它并不能保证表征各向同质性,之所以在文本生成的质量上(少无意义的重复)有实实在在的提升,完全来自于新提出的解码策略——contrastive search decoding。既然这么有效的解码策略,是应该好好学习一下。

一、contrastive search decoding

这是一种非topK、topP以及BeamSearch的解码策略,感觉非常有意思。其核心思想就是对比——把当前要生成的token和已经生成的所有token做相似度计算,得到最大的相似度值;然后使得该token的概率与最大的相似度值的差值最大化的那个token就是我们要生成的token;具体的公式如下:

 V(k)是指token在模型输出的分布中top_k个最可能的结果,论文中提出K值通常设置3~10。看完公式觉得思想很简单,一下子就理解了公式要表达的思想,但是这里还是有几个值得注意的地方:

1、如何高效的得到当前token的embedding,也就是hv;以及如何得到h1,.....ht-1(已经生成的token的embedding)

2、如何高效的计算当前token的embedding和之前所有文本的embedding的相似度的最大值

3、如何计算整体上的最大值得到V(k)最佳的v

在问题1已经解决的情况下,2和3问题比较好解决,直接采用矩阵计算使用GPU并行计算,就可以很好的解决计算的效率问题;第一个问题理解起来有点点难,对于不太熟悉GPT2模型的人来说,确实不太好理解。本人再阅读起实现源码后,和作者沟通后,再加上对GPT2生成流程的理解后,才完全理解到底应该怎么求hv的。

 contrastive search decoding大体上的解码流程如上图所示,当前轮次文本输入gpt2模型,使用hm得到新的k个候选生成tokens;然后把这些tokens和之前的文本拼接起来输入到下一轮模型,得到hm+1。这里的hm+1就是前面说的上一轮应该生成的token的embedding,通过解码公式的计算,选出最佳的hm+1也就得到了tm+1——当前轮最佳的那个token。按照上述流程循坏下去就可以得到生成一个句子了。

二、代码实现理解和实验

1、代码走读

上面的核心思想简单的分析了,下面看看如何具体的使用代码实现,先上整体的实现代码,然后再慢慢解析:

def contrastive_search_decode(curr_input_tensor,attention_mask,tokenizer):
    """
    对比搜索解码策略
    """
    alpha = 0.5
    beam_width = 5
    generated = [item for item in curr_input_tensor.tolist()]
    past_key_values = None

    max_length = 64 + curr_input_tensor.shape[1]
    stop = False

    with torch.no_grad():
        for index in range(max_length):
            if index == 0:
                inputs = prepare_inputs_for_generation(curr_input_tensor, attention_mask, past=past_key_values)
                output = model(**inputs,return_dict = True,use_cache=True,output_hidden_states=True)
                past_key_values = output.past_key_values
                last_hidden_states = output.hidden_states[-1]  # [B, S, E]
                logit_for_next_step = output.logits[:, -1, :]  # [B, V]

            bsz, seqlen, embed_dim = last_hidden_states.size()

            next_probs = F.softmax(logit_for_next_step, dim=-1)
            _, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=beam_width) # [B, K]
            top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids)  # [B, K]

            # compute new hidden
            past_key_values = enlarge_past_key_values(past_key_values, beam_width)
            output = model(
                input_ids=top_k_ids.view(-1, 1),
                attention_mask=torch.ones_like(top_k_ids.view(-1, 1)),
                past_key_values=past_key_values,
                output_hidden_states=True,
                use_cache=True,
            )
            # past_key_values是一个二维list;里层list元素是tensor
            past_key_values = output.past_key_values
            logits = output.logits[:, -1, :]  # [B*K, V]
            next_hidden = output.hidden_states[-1]  # [B*K, 1, E]
            context_hidden = last_hidden_states.unsqueeze(1).expand(-1, beam_width, -1, -1).reshape(bsz * beam_width,seqlen,embed_dim)  # [B*K, S, E]

            selected_idx = ranking_fast(
                context_hidden,
                next_hidden,
                top_k_probs,  # [B, K]
                alpha,
                beam_width,
            )  # [B]

            # prepare for the next step
            next_id = top_k_ids[range(len(top_k_ids)), selected_idx].unsqueeze(-1)  # [B, 1]
            temp = torch.split(next_hidden.squeeze(dim=1), beam_width)
            next_hidden = torch.stack(temp)  # [B, K, E]
            next_hidden = next_hidden[range(bsz), selected_idx, :]  # [B, E]
            last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)  # [B, S+1, E]
            past_key_values = select_past_key_values(past_key_values, beam_width, selected_idx)
            temp = torch.split(logits, beam_width)
            logit_for_next_step = torch.stack(temp)[range(bsz), selected_idx, :]  # [B, V]

            tokens = next_id.squeeze(dim=-1).tolist()
            for idx, t in enumerate(tokens):
                generated[idx].append(t)

            for token in tokens:
                if token == 102:
                    stop = True
                    break
            if stop:
                break

    res = tokenizer.batch_decode(generated, skip_special_tokens=True)

说说几个细节

a、past_key_values扩充和压缩

由于每次需要传入past_key_values加快模型的推理速度,并且要在top_k中得到最佳的那个token,因此需要把K个token都要纳入计算中,为了能够矩阵计算需要把每次输入都扩充K倍:

past_key_values扩充

def enlarge_past_key_values(past_key_values, beam_width):
    # from [B, num_head, seq_len, esz] to [B*K, num_head, seq_len, esz]
    new_key_values = []
    for layer in past_key_values:
        items = []
        for item in layer:
            # item is the key and value matrix
            bsz, num_head, seq_len, esz = item.size()
            item = item.unsqueeze(1).expand(-1, beam_width, -1, -1, -1).reshape(bsz*beam_width, num_head, seq_len, esz)    # [bsz*beam, num_head, seq_len, esz]
            items.append(item)
        new_key_values.append(items)
    return new_key_values

past_key_values中每个tensor的维度变化[B, num_head, seq_len, esz] to [B*K, num_head, seq_len, esz]

past_key_values压缩

def select_past_key_values(past_key_values, beam_width, selected_idx):
    '''select_idx: [B]'''
    new_key_values = []
    for layer in past_key_values:
        items = []
        for item in layer:
            bsz_and_beam, num_head, seq_len, esz = item.size()
            bsz = int(bsz_and_beam//beam_width)
            temp = torch.split(item, beam_width, dim=0)
            item = torch.stack(temp)    # [B, K, num_head, seq_len, esz]
            item = item[range(bsz), selected_idx, :, :, :]   # [B, num_head, seq_len, esz]
            items.append(item)
        new_key_values.append(items)
    return new_key_values

past_key_values中每个tensor的维度从[B*K, num_head, seq_len, esz]变回到[B, num_head, seq_len, esz]

b、当前token和之前所有token的相似度并行计算

def ranking_fast(context_hidden, next_hidden, next_top_k_probs, alpha, beam_width):
    '''
        context_hidden: bsz*beam x seqlen x embed_dim
        next_hidden: bsz*beam x 1 x embed_dim
        next_top_k_probs: bsz x beam
    '''
    _, context_len, embed_dim = context_hidden.size()
    norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
    norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
    cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1,2)).squeeze(-1)    # [B*K, S]
    scores, _ = torch.max(cosine_matrix, dim=-1)    # [B*K]
    next_top_k_probs = next_top_k_probs.view(-1)    # [B*K]
    scores = (1.0 - alpha) * next_top_k_probs - alpha * scores
    temp = torch.split(scores, beam_width)
    scores = torch.stack(temp)    # [B, K]
    selected_idx = scores.max(dim=-1)[1]    # [B]
    return selected_idx

需要注意到这里的torch.matmul()的计算

context_hidden:[B*K,S,D]

next_hidden:[B*K,1,D]

需要计算batch中每一条数据(每个token的embedding)和之前所有token的embedding的cos相似度

torch.matmul([B*K,S,D],B*K,1,D].T(2,1))=torch.matmul([B*K,S,D],B*K,D,1])=[B*K,S,1]

然后再求最大的那个score的index即可

2、生成效果展示

 

 生成的语句还是比较流畅的,重复性得到改善,逻辑性这个是模型本身的问题;但是具体比之前采用beamsearch + sample效果具体能好多少,这边我没有做太多的验证,需要上线使用机器人聊一段时间才知道,不过beamsearch + sample在实际使用的时候就算加上了重复惩罚系数,生成的时候也会有部分重复的,生成例子:

现在财务下班了,财务下班了,明天下午到账

不是,我们不是一个公司的,不是一个公司的

好的,那我给您改一下。那我这边给您改一下

[让我看看][让我看看][让我看看][让我看看]

代理点:506经办200019经办200019经办200019经办

2000块钱,2000块钱,2000块,2000块钱,20002000块钱,2000200020

真实的contrastive search decoding效果,还有待观察,不过目前简单的测试几条来看生成还可以。

3、方案的缺陷

一般而言,我们都要求生成的句子具有多样性——有不同的生成,contrastive search decoding是一个确定性方案,每次只能生成固定的结果。这里作者有提出一个比较合适的方法:

就是先使用beamsearch + sample等方法生成部分句子,然后再使用contrastive search decoding对生成的句子进行补齐。

具体的实现不是特别困难,这里就不实现了。

还有一种方法,实现上比较麻烦,我也提一下思想:就是那个公式中选择v的时候,不选最大的那一个,多选择几个,但是要小于K值。

公式中的argmax 换成 top_n,n取2、3、4这种比K/2小的值感觉比较合适。

参考文章:

如何评价剑桥,腾讯, DeepMind以及港大团队新作 SimCTG ? - 王琰的回答 - 知乎

2022 - A Contrastive Framework for Neural Text Generation

  • 3
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
自监督对比学习是一种无监督学习方法,旨在通过将数据样本与其在相同任务下的变体进行比较来进行特征学习。其核心思想是将一个样本与自身的不同变体进行对比,以推动特征的区分度增加。 在自监督对比学习中,通常使用一种转换函数对输入样本进行变换,生成多个变体。这些变换可以是图像旋转、裁剪、亮度调整等,也可以是对文本数据进行掩码、重排等操作。对于每个输入样本及其变体,模型将利用一个对比损失函数来度量它们之间的相似性。 通过自监督对比学习,模型会学习到一组鲁棒的特征表示。这些特征不仅能够区分同一样本与其变体,还能够区分不同样本之间的差异。通过不同样本之间的对比学习,模型可以学习到更加丰富的语义信息,提高数据的表征能力。 自监督对比学习在计算机视觉和自然语言处理等领域得到了广泛的应用。例如,在图像领域,可以利用自监督对比学习来学习图像中的局部特征、形状和纹理等信息。而在自然语言处理领域,可以通过对文本进行掩码、重排等方式来进行自监督对比学习,以学习词语、句子和文档的语义表示。 自监督对比学习的窥探给了我们一个更好的方式,通过无监督学习方法来解决许多现实世界中的问题。它为我们提供了一种从大规模数据中学习有用表示的方式,提高了学习算法的效率和泛化性能。通过进一步的研究和发展,自监督对比学习注定将在更多的领域中发挥重要的作用。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值