A Brief Analysis of Prototype Learning and a Code Implementation | Prototypical Verbalizer for Prompt-based Few-shot Tuning

Preliminaries

After reading a few blog posts on the topic [1][2][3], here is a brief summary of what characterizes prototype learning:

  1. In few-shot settings, networks trained this way tend to outperform networks trained with a plain cross-entropy loss.
  2. The core idea of prototype learning is embedding-based: a single prototypical representation is created for each class, and points cluster around the prototype of their class (a toy sketch follows this list).
    • To achieve this, a neural network learns a non-linear mapping from the input into an embedding space, and the prototype of a class is taken to be the mean of its support set in that space. An embedded query point is then classified by finding the nearest class prototype.
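To make this concrete, here is a minimal toy sketch (my own example, not taken from either paper) of nearest-prototype classification, assuming the embeddings have already been produced by some encoder: each class prototype is the mean of that class's support embeddings, and a query is assigned to the class of the nearest prototype.

import torch

def nearest_prototype_classify(support_emb, support_labels, query_emb, num_classes):
    # prototype of each class = mean of its support embeddings, shape (C, D)
    protos = torch.stack([support_emb[support_labels == c].mean(0) for c in range(num_classes)])
    # Euclidean distance from every query to every prototype, shape (Q, C)
    dists = torch.cdist(query_emb, protos)
    return dists.argmin(dim=1)  # (Q,) predicted class indices

# toy usage: 3-way 5-shot support set, 4 queries, 16-dim embeddings
support_emb = torch.randn(15, 16)
support_labels = torch.arange(3).repeat_interleave(5)
query_emb = torch.randn(4, 16)
print(nearest_prototype_classify(support_emb, support_labels, query_emb, num_classes=3))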

This post is based on two papers: Prototypical Verbalizer for Prompt-based Few-shot Tuning and Enhanced Implicit Sentiment Understanding With Prototype Learning and Demonstration for Aspect-Based Sentiment Analysis. The discussion below covers both the papers themselves and the code.

Prototype Learning

Both papers use two learning objectives: instance-prototype and instance-instance.

Essentially, we need to learn a prototypical representation for each class, which means we first have to define the prototypes. Assume the classification task has classes $\mathcal{C}=\{c_1,\dots,c_N\}$. The training objective is based on InfoNCE, so we also need a similarity measure between two vectors in the embedding space.

Cosine similarity is used here: $S(\mathbf{v}_i,\mathbf{v}_j)=\frac{\mathbf{v}_i}{\lVert\mathbf{v}_i\rVert} \cdot \frac{\mathbf{v}_j}{\lVert\mathbf{v}_j\rVert}$

# x and y are assumed to be L2-normalized beforehand; depending on your needs you can
# normalize inside this method instead, e.g. with torch.norm or F.normalize.
def similarity(self, x, y):
    return torch.matmul(x, y.T)
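As a quick standalone sanity check (a toy example of mine, not from the original code): once both inputs are L2-normalized, the plain dot product is exactly the cosine similarity.

import torch
import torch.nn.functional as F

x = F.normalize(torch.randn(4, 64), dim=-1)  # e.g. a batch of instance representations
y = F.normalize(torch.randn(3, 64), dim=-1)  # e.g. the class prototypes
sims = torch.matmul(x, y.T)                  # (4, 3), every entry lies in [-1, 1]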

instance-prototype

Let's look at the instance-prototype objective first; it is straightforward:

$$\mathcal{L}_{proto} = \frac{-1}{N^2K}\sum_{i,n}\log\frac{\exp S(\mathbf{v}_i^n,\mathbf{c}_n)}{\sum_{n'}\exp S(\mathbf{v}_i^n,\mathbf{c}_{n'})}$$

where $\mathbf{v}_i^n$ denotes that the representation $\mathbf{v}_i$ belongs to class $n$. In code:

# instance-prototype loss
proto_logits = self.similarity(features, self.proto)  # (batch_size, num_classes)
loss = 0.
assert torch.any(proto_logits.sum(1) != 0), "Zero division error"
logits = - torch.gather(F.log_softmax(proto_logits, dim=1), 1, labels.view(batch_size, 1))  # (batch_size, 1)
loss += logits.sum() / (batch_size * self.num_classes * self.num_classes)

Here $N$ is the number of classes in the task and $K$ is the number of samples per class (the N-way K-shot setting).

In the source code, however, $K$ effectively corresponds to batch_size and $N$ to the number of classes, which is worth keeping in mind.

The vector $\mathbf{v}_i$ is the token representation taken from the sentence (the [MASK] position), and $\mathbf{c}_n$ is the prototype representation of class $n$. In code, the prototypes are defined as:

w = torch.empty((self.num_classes, self.mid_dim))
nn.init.xavier_uniform_(w)
self.proto = nn.Parameter(w, requires_grad=True)

instance-instance

This part is similar to instance-prototype; let's go straight to the formula:

$$\mathcal{L}_{ins} = \frac{-1}{N^2K^2}\sum_n^N\sum_{i,j}^K\log\frac{\exp S(\mathbf{v}_i^n,\mathbf{v}_j^n)}{\sum_{n',j'}\exp S(\mathbf{v}_i^n,\mathbf{v}^{n'}_{j'})}$$

The pair $(\mathbf{v}_i^n,\mathbf{v}_j^n)$ in the numerator is an instance pair from the same class $n$.

In code:

# build the masks
labels = labels.view(-1, 1)
mask = torch.eq(labels, labels.T).float().to(device)  # (batch_size, batch_size), 1 for same class, 0 otherwise
pos_mask = mask * (torch.ones_like(mask, device=device) - torch.eye(batch_size, device=device))  # same class, excluding self
neg_mask = torch.ones_like(mask, device=device) - pos_mask - torch.eye(batch_size, device=device)  # different class, excluding self

# compute the logits
anchor_dot_contrast = self.similarity(features, features)  # (batch_size, batch_size)
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
logits = anchor_dot_contrast - logits_max.detach()  # subtract the row max for numerical stability; it cancels in the ratio below
# compute the loss
exp_logits = torch.exp(logits)
# mask after the log so that only same-class, non-self pairs contribute, matching the formula
log_prob = - pos_mask * torch.log(exp_logits / (exp_logits * neg_mask).sum(1, keepdim=True) + 1e-30)
loss += log_prob.sum() / (batch_size * self.num_classes * batch_size * self.num_classes)

Final implementation of ProtoLoss

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Union, Sequence, Mapping


class ProtoLoss(nn.Module):
    r"""
    Adapted from the prototypical verbalizer in `Prototypical Verbalizer for Prompt-based Few-shot Tuning <https://arxiv.org/pdf/2104.08691v1.pdf>`_ (the original class inherits from :obj:`Verbalizer`; this version is a standalone :obj:`nn.Module`).
    Args:
        num_labels (:obj:`int`): The number of classes (or labels) of the current task.
        hidden_dim (:obj:`int`): The hidden dimension of the PLM representations fed into the head.
        mid_dim (:obj:`int`, optional): The dimension of prototype embeddings. https://github.com/thunlp/OpenPrompt/blob/f6fb080ef755c37c01b7959e7560d007049510e8/openprompt/prompts/prototypical_verbalizer.py#L21
        label_words (:obj:`Union[List[str], List[List[str]], Dict[List[str]]]`, optional): The label words that are projected by the labels.
    """
    def __init__(self, num_labels, hidden_dim,
                 mid_dim: Optional[int] = 64,
                 label_words: Optional[Union[Sequence[str], Mapping[str, str]]] = None,
                 ):
        
        super().__init__()
        self.num_classes = num_labels
        self.hidden_dims = hidden_dim
        self.mid_dim = mid_dim
        self.head = nn.Linear(self.hidden_dims, self.mid_dim, bias=False)
        if label_words is not None: # use label words as an initialization
            self.label_words = label_words
        w = torch.empty((self.num_classes, self.mid_dim))
        nn.init.xavier_uniform_(w)
        self.proto = nn.Parameter(w, requires_grad=True)

    def similarity(self, x, y):
        return torch.matmul(x, y.T)
    
    def my_pcl_loss(self, features, labels=None):
        """
        features: (batch_size, hidden_dim)
        labels: (batch_size,)
        """
        device = features.device
        batch_size = features.shape[0]
        features = F.normalize(features,dim=-1)

        # instance-prototype loss
        proto_logits = self.similarity(features, self.proto) # (batch_size, num_classes)
        loss = 0.
        assert torch.any(proto_logits.sum(1)!=0), "Zero division error"
        logits = - torch.gather(F.log_softmax(proto_logits, dim=1), 1, labels.view(batch_size,1)) # (batch_size, 1)
        loss += logits.sum() / (batch_size * self.num_classes * self.num_classes) 

        # instance-instance loss
        # build the masks
        labels = labels.view(-1, 1)
        mask = torch.eq(labels, labels.T).float().to(device)  # (batch_size, batch_size), 1 for same class, 0 otherwise
        pos_mask = mask * (torch.ones_like(mask, device=device) - torch.eye(batch_size, device=device))  # same class, excluding self
        neg_mask = torch.ones_like(mask, device=device) - pos_mask - torch.eye(batch_size, device=device)  # different class, excluding self

        # compute the logits
        anchor_dot_contrast = self.similarity(features, features)  # (batch_size, batch_size)
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()  # subtract the row max for numerical stability; it cancels in the ratio below
        # compute the loss
        exp_logits = torch.exp(logits)
        # mask after the log so that only same-class, non-self pairs contribute, matching the formula
        log_prob = - pos_mask * torch.log(exp_logits / (exp_logits * neg_mask).sum(1, keepdim=True) + 1e-30)
        loss += log_prob.sum() / (batch_size * self.num_classes * batch_size * self.num_classes)

        return loss, proto_logits

    def forward(self, features, **kwargs):
        """
        features: (batch_size, max_seq_length, hidden_dim)
        kwargs: must contain "loss_ids" (marking the <mask> positions) and "labels".
        Returns:
            the prototype loss and the instance-prototype logits over classes, shape (batch_size, num_classes).
        """
        outputs = self.extract_at_mask(features, kwargs["loss_ids"]) # (batch_size, hidden_dim)
        outputs = self.head(outputs)
        loss, logits_proto = self.my_pcl_loss(outputs, kwargs["labels"])
        return loss, logits_proto
    
    def extract_at_mask(self,
                       outputs: torch.Tensor,
                       mask_idx: torch.Tensor,):
        r"""Get outputs at all <mask> token
        E.g., project the logits of shape
        (``batch_size``, ``max_seq_length``, ``vocab_size/hidden_dim``)
        into logits of shape (if num_mask_token > 1)
        (``batch_size``, ``num_mask_token``, ``vocab_size/hidden_dim``)
        or into logits of shape (if ``num_mask_token`` = 1)
        (``batch_size``, ``vocab_size/hidden_dim``).

        Args:
            outputs (:obj:`torch.Tensor`): The original outputs (maybe process by verbalizer's
                 `gather_outputs` before) etc. of the whole sequence.
            batch (:obj:`Union[Dict, InputFeatures]`): The original batch

        Returns:
            :obj:`torch.Tensor`: The extracted outputs of ``<mask>`` tokens.
        https://github.com/thunlp/OpenPrompt/blame/f6fb080ef755c37c01b7959e7560d007049510e8/openprompt/pipeline_base.py#L259
        """
        outputs = outputs[torch.where(mask_idx>0)] #
        outputs = outputs.view(outputs.shape[0], -1, outputs.shape[1])
        if outputs.shape[1] == 1:
            outputs = outputs.squeeze(1)
        return outputs
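For completeness, here is a hypothetical usage sketch with random tensors (the shapes and the assumption of exactly one <mask> position per example are mine, not from the paper):

import torch

batch_size, seq_len, hidden_dim, num_labels = 8, 32, 768, 4
criterion = ProtoLoss(num_labels=num_labels, hidden_dim=hidden_dim, mid_dim=64)

features = torch.randn(batch_size, seq_len, hidden_dim)        # PLM last-layer hidden states
loss_ids = torch.zeros(batch_size, seq_len, dtype=torch.long)  # 1 marks the <mask> position
loss_ids[:, 5] = 1                                             # assume the mask sits at position 5
labels = torch.randint(0, num_labels, (batch_size,))

loss, proto_logits = criterion(features, loss_ids=loss_ids, labels=labels)
print(loss.item(), proto_logits.shape)                         # proto_logits: (batch_size, num_labels)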

Comparison with the original source code

Since the paper Enhanced Implicit Sentiment Understanding With Prototype Learning and Demonstration for Aspect-Based Sentiment Analysis has not released its source code, I could only attempt a reproduction based on Prototypical Verbalizer for Prompt-based Few-shot Tuning. Note that my implementation deviates from the original: I tried to follow the formulas and consulted the source code, but there were parts of the source that I did not fully understand, so I modified them according to my own reasoning.

That paper is built on the OpenPrompt framework, so the first step is to locate the program's entry point. After searching for a while, it appears to be experiments/cli.py, which reads a configuration file; the configuration for the prototypical verbalizer is experiments/classification_proto_verbalizer.yaml.

From there you can follow the program flow and debug it in your head...

Tracing the execution mentally, I located the relevant logic in the train_proto method of openprompt/prompts/prototypical_verbalizer.py, which is called from the fit method of openprompt/protoverb_trainer.py. Here is the definition of train_proto:

 def train_proto(self, model, dataloader, device):
        model.eval()
        embeds = [[] for _ in range(self.num_classes)]
        with torch.no_grad():
            for i, batch in enumerate(dataloader):
                batch = batch.to("cuda:{}".format(device)).to_dict()
                outputs = model.prompt_model(batch)
                hidden, _ = self.gather_outputs(outputs)
                outputs_at_mask = model.extract_at_mask(hidden, batch)
                for j in range(len(outputs_at_mask)):
                    label = batch['label'][j]
                    embeds[label].append(outputs_at_mask[j])
        embeds = [torch.stack(e) for e in embeds]
        embeds = torch.stack(embeds)

        instance_mean = embeds.mean(1)
        loss = 0.
        for epoch in range(self.epochs):
            x = self.head(embeds)
            self.optimizer.zero_grad()
            loss = self.pcl_loss(x)
            loss.backward()
            self.optimizer.step()
        logger.info("Total epoch: {}. ProtoVerb loss: {}".format(self.epochs, loss))
        self.trained = True

You can see that it calls the pcl_loss method; its concrete implementation is shown below. Since I did not actually step through the whole program with a debugger, my guess is that v_ins is three-dimensional, with shape (num_classes, batch_size, hidden_dim). My annotations and some points of confusion are included below.

First, the instance-prototype loss:

# instance-prototype loss
sim_mat = torch.exp(self.sim(v_ins, self.proto))  # expected shape: (num_classes, batch_size, num_classes)
num = sim_mat.shape[1]  # batch_size
loss = 0.
for i in range(num):
    pos_score = torch.diag(sim_mat[:, i, :])  # for each class n, the similarity of its i-th instance to prototype n
    neg_score = (sim_mat[:, i, :].sum(1) - pos_score)
    loss += - torch.log(pos_score / (pos_score + neg_score)).sum()
loss = loss / (num * self.num_classes * self.num_classes)

Points of confusion

  1. Doesn't this implementation of the loss simply reduce to loss += - torch.log(pos_score / (sim_mat[:,i,:].sum(1))).sum()? (See the quick check after this list.)

  2. pos_score = torch.diag(sim_mat[:,i,:]) seems to give a score for sample i under every class. But according to the formula, $\mathbf{v}_i^n$ means that the representation $\mathbf{v}_i$ belongs to class $n$, so shouldn't we only need to consider the class that sample $i$ actually belongs to? Or is this treated as a multi-label problem?

    How the instance-prototype loss is implemented and what $v_i^n$ means:

    • The loss is an InfoNCE-style estimator. For the instance-prototype loss, the goal is to pull each prototype vector towards the center of the instances of its class: the similarity between a prototype and instances of the same class is contrasted against its similarity to instances of other classes.

    • The loss function is
      $$\mathcal{L}_{proto}=\frac{-1}{N^{2} K} \sum_{i, n} \log \frac{\exp S\left(v_{i}^{n}, c_{n}\right)}{\sum_{n'} \exp S\left(v_{i}^{n}, c_{n'}\right)}$$

    where $v_{i}^{n}$ is an instance representation belonging to class $n$ and $c_{n}$ is the prototype vector of class $n$. Minimizing this loss makes the similarity between the prototype $c_n$ and the instances $v_i^n$ of class $n$ higher, and the similarity to the prototypes $c_{n'}$ of other classes lower.

    • $v_{i}^{n}$: the vector representation of the $i$-th instance belonging to class $n$. In the paper, given a training text $x$ wrapped by the template, the last-layer hidden state of the [MASK] token, $h_{[MASK]}$, is taken as the initial representation and passed through a linear encoder to obtain the instance representation $v = E(x) = W h_{[MASK]}$; this $v$ is the $v$ in $v_i^n$, with $i$ and $n$ indexing the instance and its class.
  3. And where exactly does the exp from the paper's formula show up in the code?
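Regarding question 1: the two forms are indeed equal, because neg_score = sim_mat[:,i,:].sum(1) - pos_score, so pos_score + neg_score is exactly sim_mat[:,i,:].sum(1). A quick numerical check on toy tensors of my own, just to confirm the algebra:

import torch

sim_mat = torch.exp(torch.randn(5, 3, 5))  # pretend shape (num_classes, batch_size, num_classes)
i = 0
pos_score = torch.diag(sim_mat[:, i, :])
neg_score = sim_mat[:, i, :].sum(1) - pos_score

a = -torch.log(pos_score / (pos_score + neg_score)).sum()
b = -torch.log(pos_score / sim_mat[:, i, :].sum(1)).sum()
print(torch.allclose(a, b))  # True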

With the above in mind, I rewrote this part of the code based on the paper's formula:

# my vectorized version, without the for loop
proto_logits = self.similarity(features, self.proto)  # (batch_size, num_classes)
loss = 0.
assert torch.any(proto_logits.sum(1) != 0), "Zero division error"
# F.log_softmax(s)[y] = log(exp(s_y) / sum_c exp(s_c)), so the exp from the formula lives inside log_softmax
logits = - torch.gather(F.log_softmax(proto_logits, dim=1), 1, labels.view(batch_size, 1))  # (batch_size, 1)
loss += logits.sum() / (batch_size * self.num_classes * self.num_classes)

Next, the instance-instance loss:

# instance-instance loss
loss_ins = 0.
for i in range(v_ins.shape[0]):  # iterate over classes: the outermost sum in Eq. 5 of the paper
    sim_instance = torch.exp(self.sim(v_ins, v_ins[i]))  # instance-instance similarities in Eq. 5; expected shape (num_classes, batch_size, batch_size)
    pos_ins = sim_instance[i]  # numerator exp S(v^n_i, v^n_j) in Eq. 5, i.e. instances of the same class i
    neg_ins = (sim_instance.sum(0) - pos_ins).sum(0)
    loss_ins += - torch.log(pos_ins / (pos_ins + neg_ins)).sum()
loss_ins = loss_ins / (num * self.num_classes * num * self.num_classes)
loss = loss + loss_ins

return loss

Here I switched to matrix operations and dropped the for-loop form.

Below is my reformulation of $\mathcal{L}_{ins}$:

$$
\begin{aligned}
\mathcal{L}_{ins} &= \frac{-1}{N^2K^2}\sum_n^N\sum_{i,j}^K\log\frac{\exp S(\mathbf{v}_i^n,\mathbf{v}_j^n)}{\sum_{n',j'}\exp S(\mathbf{v}_i^n,\mathbf{v}^{n'}_{j'})} \\
&= \frac{-1}{N^2K^2}\sum_{i,j}^K\sum_n^N\log\frac{\exp S(\mathbf{v}_i^n,\mathbf{v}_j^n)}{\sum_{n',j'}\exp S(\mathbf{v}_i^n,\mathbf{v}^{n'}_{j'})} \\
&\Rightarrow \frac{-1}{B^2K^2}\sum_{k \in B}\ \sum_{\substack{h \in B,\, h \neq k,\, y_h=y_k}}\log\frac{\exp S(\mathbf{v}_k,\mathbf{v}_h)}{\sum_{h' \in B,\, y_{h'} \neq y_k}\exp S(\mathbf{v}_k,\mathbf{v}_{h'})}
\end{aligned}
$$
Since the pair $(\mathbf{v}_i^n,\mathbf{v}_j^n)$ in the numerator denotes instance pairs of the same class, and a batch $B$ here contains $N \times K$ samples, each belonging to exactly one class, I rewrote the formula in the batch-indexed form above and implemented the code accordingly.

# instance-instance loss
# build the masks
labels = labels.view(-1, 1)
mask = torch.eq(labels, labels.T).float().to(device)  # (batch_size, batch_size), 1 for same class, 0 otherwise
pos_mask = mask * (torch.ones_like(mask, device=device) - torch.eye(batch_size, device=device))  # same class, excluding self
neg_mask = torch.ones_like(mask, device=device) - pos_mask - torch.eye(batch_size, device=device)  # different class, excluding self

# compute the logits
anchor_dot_contrast = self.similarity(features, features)  # (batch_size, batch_size)
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
logits = anchor_dot_contrast - logits_max.detach()  # subtract the row max for numerical stability; it cancels in the ratio below
# compute the loss
exp_logits = torch.exp(logits)
# mask after the log so that only same-class, non-self pairs contribute, matching the formula
log_prob = - pos_mask * torch.log(exp_logits / (exp_logits * neg_mask).sum(1, keepdim=True) + 1e-30)
loss += log_prob.sum() / (batch_size * self.num_classes * batch_size * self.num_classes)

Analysis

The main difference between my code and the original lies in the use of v_ins, which is three-dimensional; I am not sure how this 3D tensor (my guess: (num_classes, batch_size, hidden_dim)) is constructed.

So far I can only trace that, inside the train_proto method, v_ins essentially comes from embeds.

def train_proto(self, model, dataloader, device):
        model.eval()
        embeds = [[] for _ in range(self.num_classes)]
        with torch.no_grad():
            for i, batch in enumerate(dataloader):
                batch = batch.to("cuda:{}".format(device)).to_dict()
                outputs = model.prompt_model(batch)
                hidden, _ = self.gather_outputs(outputs)
                outputs_at_mask = model.extract_at_mask(hidden, batch)
                for j in range(len(outputs_at_mask)):
                    label = batch['label'][j]
                    embeds[label].append(outputs_at_mask[j])
        embeds = [torch.stack(e) for e in embeds]
        embeds = torch.stack(embeds) # how does this end up being three-dimensional????

        instance_mean = embeds.mean(1)
        loss = 0.
        for epoch in range(self.epochs):
            x = self.head(embeds)
            self.optimizer.zero_grad()
            loss = self.pcl_loss(x)
            loss.backward()
            self.optimizer.step()
        logger.info("Total epoch: {}. ProtoVerb loss: {}".format(self.epochs, loss))
        self.trained = True
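For reference, here is a minimal sketch of my guess about what the two torch.stack calls produce, under the assumption that every class ends up with the same number of [MASK] embeddings (otherwise the outer torch.stack would fail with a size mismatch):

import torch

num_classes, shots_per_class, hidden_dim = 3, 4, 768

# embeds starts out as a per-class list of (hidden_dim,) vectors
embeds = [[torch.randn(hidden_dim) for _ in range(shots_per_class)] for _ in range(num_classes)]

embeds = [torch.stack(e) for e in embeds]  # each entry: (shots_per_class, hidden_dim)
embeds = torch.stack(embeds)               # (num_classes, shots_per_class, hidden_dim)
print(embeds.shape)                        # torch.Size([3, 4, 768])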

Still, my understanding here is limited and I am not certain how this shape actually comes about.
If there are any mistakes, please feel free to point them out. Many thanks.

References


  1. 学习笔记–原型网络(Prototypical Network) - Huang的文章 - 知乎
  2. 元学习——原型网络(Prototypical Networks)
  3. prototype-based learning algorithm(原型学习)
