1. Overview
The main job of SequenceSummary is to extract a single vector from the hidden states of various pretrained models for classification. In BERT, for example, we usually classify with the vector of the first token of the sentence, i.e. [CLS], while in GPT2 we take the last token instead. In GPT2, though, we cannot simply grab the last position. Suppose the maximum sequence length is 6 and the input sentence's last real token is 好, sitting at index 3: we want to classify with the vector for 好, but after padding the sequence is 6 tokens long and the trailing positions hold padding tokens. We therefore have to pass cls_index=3, i.e. tell the model to classify with the vector at index 3. SequenceSummary is precisely the component that, given this cls_index, picks the right vector out of the hidden states for classification.
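To make this concrete, here is a minimal self-contained sketch of the idea (toy shapes and random values, not a real model): for each sample in a padded batch, pick out the hidden vector at a given index.

import torch

batch_size, seq_len, hidden_size = 2, 6, 4
hidden_states = torch.randn(batch_size, seq_len, hidden_size)

# index of the last real (non-padding) token for each sample
cls_index = torch.tensor([3, 1])

# reshape to [batch, 1, hidden] so gather can select one row per sample
idx = cls_index.unsqueeze(-1).unsqueeze(-1).expand(-1, 1, hidden_size)
summary = hidden_states.gather(1, idx).squeeze(1)
print(summary.shape)  # torch.Size([2, 4]) -- one vector per sample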
2. Source code
# imports and class statement added for completeness; exact import paths
# may vary across transformers versions
import torch
from torch import nn
from torch.nn import Identity
from typing import Callable, Optional

from transformers import PretrainedConfig
from transformers.activations import get_activation


class SequenceSummary(nn.Module):
    def __init__(self, config: PretrainedConfig):
        super().__init__()

        self.summary_type = getattr(config, "summary_type", "last")
        if self.summary_type == "attn":
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        self.summary = Identity()
        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = nn.Linear(config.hidden_size, num_classes)

        activation_string = getattr(config, "summary_activation", None)
        self.activation: Callable = get_activation(activation_string) if activation_string else Identity()

        self.first_dropout = Identity()
        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
            self.first_dropout = nn.Dropout(config.summary_first_dropout)

        self.last_dropout = Identity()
        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
            self.last_dropout = nn.Dropout(config.summary_last_dropout)
    def forward(
        self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
    ) -> torch.FloatTensor:
        """
        Compute a single vector summary of a sequence hidden states.

        Args:
            hidden_states (:obj:`torch.FloatTensor` of shape :obj:`[batch_size, seq_len, hidden_size]`):
                The hidden states of the last layer.
            cls_index (:obj:`torch.LongTensor` of shape :obj:`[batch_size]` or :obj:`[batch_size, ...]`
                where ... are optional leading dimensions of :obj:`hidden_states`, `optional`):
                Used if :obj:`summary_type == "cls_index"` and takes the last token of the sequence as
                classification token.

        Returns:
            :obj:`torch.FloatTensor`: The summary of the sequence hidden states.
        """
        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = hidden_states.mean(dim=1)
        elif self.summary_type == "cls_index":
            if cls_index is None:
                cls_index = torch.full_like(
                    hidden_states[..., :1, :],
                    hidden_states.shape[-2] - 1,
                    dtype=torch.long,
                )
            else:
                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError

        output = self.first_dropout(output)
        output = self.summary(output)
        output = self.activation(output)
        output = self.last_dropout(output)

        return output
From the source we can see that the value of summary_type controls which vector is extracted. The supported values are "last", "first", "mean" and "cls_index" (plus "attn", which simply raises NotImplementedError): take the last token, the first token, the mean over all tokens, or the token at an index we specify ourselves. If "cls_index" is chosen, the embedding at the given index is fetched with gather.
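As a quick sanity check of these options, the sketch below runs a SequenceSummary head under each supported summary_type. It assumes the transformers library is installed; note that the import path of SequenceSummary has moved between versions (it lives in transformers.modeling_utils in the version discussed here).

import torch
from transformers import GPT2Config
from transformers.modeling_utils import SequenceSummary

hidden_states = torch.randn(2, 6, 768)  # [batch, seq_len, hidden]

for summary_type in ["last", "first", "mean", "cls_index"]:
    # summary_use_proj=False keeps self.summary = Identity(), so the raw
    # 768-dim vector comes back unchanged
    config = GPT2Config(summary_type=summary_type, summary_use_proj=False)
    head = SequenceSummary(config)
    # cls_index is only consulted when summary_type == "cls_index"
    out = head(hidden_states, cls_index=torch.tensor([3, 1]))
    print(summary_type, out.shape)  # torch.Size([2, 768]) in every case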
Source excerpt:
output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
...
output = self.first_dropout(output)
output = self.summary(output)
output = self.activation(output)
output = self.last_dropout(output)
From this excerpt we can see the full flow: first the vector for the specified token is extracted into output; then dropout is applied; then output goes through summary(), which __init__() defines as self.summary = nn.Linear(config.hidden_size, num_classes), i.e. a fully connected classification layer; finally the activation function and a second dropout produce the output.
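Putting these pieces together, here is a hedged end-to-end sketch of that dropout → linear → activation → dropout pipeline. The configuration values are illustrative assumptions; only the attribute names come from the source above.

import torch
from transformers import GPT2Config
from transformers.modeling_utils import SequenceSummary

config = GPT2Config(
    summary_type="cls_index",
    summary_use_proj=True,        # enables self.summary = nn.Linear(...)
    summary_proj_to_labels=True,  # project to num_labels instead of hidden_size
    summary_first_dropout=0.1,    # dropout applied before the linear layer
)
config.num_labels = 14            # assumed number of classes

head = SequenceSummary(config)
hidden_states = torch.randn(4, 10, 768)  # [batch, seq_len, hidden]
logits = head(hidden_states, cls_index=torch.tensor([9, 4, 7, 2]))
print(logits.shape)  # torch.Size([4, 14]) -- one logit vector per sample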
3. Experiment
# imports added for completeness; exact paths may vary across transformers versions
import torch
from torch.nn import CrossEntropyLoss
from transformers import GPT2Model, GPT2PreTrainedModel
from transformers.modeling_utils import SequenceSummary


class GPT2ClsNews(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        config.num_labels = 14  # number of target classes
        self.transformer = GPT2Model(config)
        self.cls_head = SequenceSummary(config)

    def forward(self,
                input_ids=None,
                token_type_ids=None,
                position_ids=None,
                mc_token_ids=None,
                labels=None):
        transformer_outputs = self.transformer(input_ids)
        hidden_states = transformer_outputs[0]
        cls_logits = self.cls_head(hidden_states, mc_token_ids).squeeze(-1)
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(cls_logits, labels)
            return loss, cls_logits
        return cls_logits
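A hedged usage sketch for this class follows; the batch and sequence sizes mirror the shapes reported below, and the random input ids merely stand in for real tokenized text (the model here is randomly initialized, just for shape checking).

import torch
from transformers import GPT2Config

config = GPT2Config()  # GPT2Config's summary_* defaults already select "cls_index"
model = GPT2ClsNews(config)

input_ids = torch.randint(0, config.vocab_size, (4, 902))
mc_token_ids = torch.tensor([776, 901, 85, 627])  # index of the last real token per sample
labels = torch.randint(0, 14, (4,))

loss, cls_logits = model(input_ids, mc_token_ids=mc_token_ids, labels=labels)
print(loss.item(), cls_logits.shape)  # scalar loss, torch.Size([4, 14])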
As the class above shows, what goes into SequenceSummary is hidden_states together with mc_token_ids, the indices of the chosen tokens:
hidden_states.shape: [4, 902, 768]
mc_token_ids.shape: [4]
Inside SequenceSummary, cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) first reshapes cls_index, so mc_token_ids goes from shape [4] to shape [4, 1, 1]:
cls_index:
tensor([[[776]],
[[901]],
[[ 85]],
[[627]]], device='cuda:5')
That is, for each of the four samples we want to pull the single row at its index out of its [902, 768] hidden-state matrix and hand it to the classifier. Inside SequenceSummary, output = hidden_states.gather(-2, cls_index).squeeze(-2) does exactly that, and the extracted output.shape = [4, 768]: 4 samples, each represented by one 768-dimensional vector, which finally passes through the fully connected layer for classification.
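This shape walk-through can be reproduced in isolation with toy tensors:

import torch

hidden_states = torch.randn(4, 902, 768)
mc_token_ids = torch.tensor([776, 901, 85, 627])

cls_index = mc_token_ids.unsqueeze(-1).unsqueeze(-1)                  # [4, 1, 1]
cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (768,))  # [4, 1, 768]
output = hidden_states.gather(-2, cls_index).squeeze(-2)              # [4, 768]
print(output.shape)  # torch.Size([4, 768])

# sanity check: row i of output equals hidden_states[i, mc_token_ids[i]]
assert torch.equal(output[0], hidden_states[0, 776])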