SequenceSummary in transformers: a handy helper for classification

1. Overview

The main job of SequenceSummary is to extract a single vector from a pretrained model's hidden states for classification. In BERT we usually take the vector of the first token, i.e. [CLS], while in GPT-2 we take the last token. With GPT-2, however, you cannot blindly take the last position. Suppose the maximum sequence length is 6 and the input sentence is 你好: the vector we want is the one for 好, but after padding the input becomes [你, 好, <pad>, <pad>, <pad>, <pad>], so the last position is a padding token. Instead we must point at the position of 好 explicitly, e.g. cls_index=1. SequenceSummary then selects, for each sample, the hidden-state vector at the given cls_index and feeds it to the classification head.
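
To make the selection step concrete, here is a minimal sketch (with made-up shapes) of how one token vector per sample can be picked out of a padded batch with torch.gather, which is the mechanism SequenceSummary uses for summary_type="cls_index":

import torch

batch_size, seq_len, hidden_size = 2, 6, 4
hidden_states = torch.randn(batch_size, seq_len, hidden_size)

# Per-sample index of the token we want to classify with
# (e.g. the last non-padding token).
cls_index = torch.tensor([1, 3])                            # shape [2]
idx = cls_index.view(-1, 1, 1).expand(-1, 1, hidden_size)   # shape [2, 1, 4]
summary = hidden_states.gather(1, idx).squeeze(1)           # shape [2, 4]

assert torch.equal(summary[0], hidden_states[0, 1])
assert torch.equal(summary[1], hidden_states[1, 3])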

2. Source code

from typing import Callable, Optional

import torch
from torch import nn

from transformers.activations import get_activation
from transformers.configuration_utils import PretrainedConfig


class SequenceSummary(nn.Module):
    def __init__(self, config: PretrainedConfig):
        super().__init__()

        self.summary_type = getattr(config, "summary_type", "last")
        if self.summary_type == "attn":
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        # Optional projection: to num_labels for classification, else back to hidden_size.
        self.summary = nn.Identity()
        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = nn.Linear(config.hidden_size, num_classes)

        activation_string = getattr(config, "summary_activation", None)
        self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()

        self.first_dropout = nn.Identity()
        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
            self.first_dropout = nn.Dropout(config.summary_first_dropout)

        self.last_dropout = nn.Identity()
        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
            self.last_dropout = nn.Dropout(config.summary_last_dropout)

    def forward(
        self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
    ) -> torch.FloatTensor:
        """
        Compute a single vector summary of a sequence hidden states.

        Args:
            hidden_states (:obj:`torch.FloatTensor` of shape :obj:`[batch_size, seq_len, hidden_size]`):
                The hidden states of the last layer.
            cls_index (:obj:`torch.LongTensor` of shape :obj:`[batch_size]` or :obj:`[batch_size, ...]` where ... are optional leading dimensions of :obj:`hidden_states`, `optional`):
                Used if :obj:`summary_type == "cls_index"` and takes the last token of the sequence as classification
                token.

        Returns:
            :obj:`torch.FloatTensor`: The summary of the sequence hidden states.
        """
        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = hidden_states.mean(dim=1)
        elif self.summary_type == "cls_index":
            if cls_index is None:
                # Default to the last position of every sequence.
                cls_index = torch.full_like(
                    hidden_states[..., :1, :],
                    hidden_states.shape[-2] - 1,
                    dtype=torch.long,
                )
            else:
                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError

        output = self.first_dropout(output)
        output = self.summary(output)
        output = self.activation(output)
        output = self.last_dropout(output)

        return output

From the source we can see that summary_type controls which vector is extracted. The options are ["last", "first", "mean", "cls_index"] ("attn" is declared but not implemented): take the last token, the first token, the mean over all tokens, or a per-sample token at an explicitly given index. With "cls_index", gather is used to pull out the embedding at the specified index.
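
A small usage sketch (assuming a transformers version where SequenceSummary still lives in transformers.modeling_utils; it was removed from there in recent releases). GPT2Config sets summary_type="cls_index" by default, so the cls_index argument is honored:

import torch
from transformers import GPT2Config
from transformers.modeling_utils import SequenceSummary

config = GPT2Config(num_labels=14)   # summary_use_proj / summary_proj_to_labels default to True
head = SequenceSummary(config)

hidden_states = torch.randn(4, 10, config.hidden_size)   # [bsz, seq_len, hidden_size]
cls_index = torch.tensor([9, 5, 2, 7])                   # one token index per sample
logits = head(hidden_states, cls_index)
print(logits.shape)                                      # torch.Size([4, 14])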

Excerpt from the source:

output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
...
output = self.first_dropout(output)
output = self.summary(output)
output = self.activation(output)
output = self.last_dropout(output)

This excerpt shows the whole flow: first the vector of the specified token is pulled out into output, then dropout is applied, then the summary() projection. summary is defined in __init__() as self.summary = nn.Linear(config.hidden_size, num_classes), i.e. a fully connected classification layer. Finally the activation function and a second dropout produce the output.
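
Written out by hand, the same four steps look like this (a sketch using GPT-2's default summary settings, not the library code itself):

import torch
from torch import nn

hidden_size, num_labels = 768, 14
first_dropout = nn.Dropout(0.1)               # summary_first_dropout = 0.1 in GPT2Config
summary = nn.Linear(hidden_size, num_labels)  # the "summary" projection
activation = nn.Identity()                    # summary_activation = None in GPT2Config
last_dropout = nn.Identity()                  # GPT2Config defines no summary_last_dropout

token_vec = torch.randn(4, hidden_size)       # vectors already picked out by gather
logits = last_dropout(activation(summary(first_dropout(token_vec))))
print(logits.shape)                           # torch.Size([4, 14])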

3. Experiment

import torch
from torch.nn import CrossEntropyLoss
from transformers import GPT2Model, GPT2PreTrainedModel
from transformers.modeling_utils import SequenceSummary


class GPT2ClsNews(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        config.num_labels = 14                    # 14 news categories
        self.transformer = GPT2Model(config)
        self.cls_head = SequenceSummary(config)   # GPT2Config uses summary_type="cls_index" by default
        self.init_weights()

    def forward(self,
                input_ids=None,
                token_type_ids=None,
                position_ids=None,
                mc_token_ids=None,
                labels=None):
        transformer_outputs = self.transformer(input_ids)
        hidden_states = transformer_outputs[0]    # [batch_size, seq_len, hidden_size]
        cls_logits = self.cls_head(hidden_states, mc_token_ids).squeeze(-1)
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(cls_logits, labels)
            return loss, cls_logits
        return cls_logits
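
A quick smoke test of this model (hypothetical: randomly initialized weights and random inputs, just to check shapes):

import torch
from transformers import GPT2Config

config = GPT2Config()
model = GPT2ClsNews(config)

input_ids = torch.randint(0, config.vocab_size, (4, 902))
mc_token_ids = torch.tensor([776, 901, 85, 627])   # index of the chosen token per sample
labels = torch.randint(0, 14, (4,))

loss, logits = model(input_ids=input_ids, mc_token_ids=mc_token_ids, labels=labels)
print(loss.item(), logits.shape)                   # logits: torch.Size([4, 14])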

As shown above, what we feed into SequenceSummary is hidden_states plus mc_token_ids, the per-sample index of the token to classify with:

hidden_states.shape: [4, 902, 768]
mc_token_ids.shape: [4]
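
In practice mc_token_ids is usually the position of the last non-padding token of each sample, which can be derived from the attention mask (a sketch, assuming right padding):

import torch

attention_mask = torch.tensor([
    [1, 1, 1, 0, 0, 0],   # real length 3 -> last real token at index 2
    [1, 1, 1, 1, 1, 1],   # real length 6 -> last real token at index 5
])
mc_token_ids = attention_mask.sum(dim=1) - 1      # tensor([2, 5])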

Inside SequenceSummary, cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) first reshapes cls_index from [4] to [4, 1, 1] (the printed value below is this reshaped mc_token_ids); the expand on the next line of the source then broadcasts it to [4, 1, 768] so it can index into the hidden dimension.

cls_index:
tensor([[[776]],
        [[901]],
        [[ 85]],
        [[627]]], device='cuda:5')

[Figure: the hidden-state matrix of each of the 4 samples, with the row selected by cls_index highlighted in yellow]
As the figure shows, for each sample we pull out the vector in the highlighted (yellow) row and send it to the classifier. Inside SequenceSummary, output = hidden_states.gather(-2, cls_index).squeeze(-2) produces output.shape = [4, 768], i.e. each of the 4 samples is represented by a single 768-dimensional vector, which finally goes through the fully connected layer for classification.
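
The same shape walk-through with random data, mirroring the two relevant lines of the source (a standalone sketch):

import torch

hidden_states = torch.randn(4, 902, 768)
cls_index = torch.tensor([776, 901, 85, 627])

idx = cls_index.unsqueeze(-1).unsqueeze(-1)           # [4, 1, 1]
idx = idx.expand(-1, -1, hidden_states.size(-1))      # [4, 1, 768]
output = hidden_states.gather(-2, idx).squeeze(-2)    # [4, 768]
assert torch.equal(output[1], hidden_states[1, 901])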
