OpenPrompt关于软提示(softTemplate) 篇一

1.首先看看softTemplate的一个执行过程。

当创建softTemplate对象实例后,首先会初始化参数。他的init方法如下:

    def __init__(self,
                 model: PreTrainedModel,
                 tokenizer: PreTrainedTokenizer,
                 text: Optional[str] = None,
                 soft_embeds: Optional[torch.FloatTensor] = None,
                 num_tokens: int=20,
                 initialize_from_vocab: Optional[bool] = True,
                 random_range: Optional[float] = 0.5,
                 placeholder_mapping: dict = {'<text_a>':'text_a','<text_b>':'text_b'},
                ):
        super().__init__(tokenizer=tokenizer,
                         placeholder_mapping=placeholder_mapping)
        self.raw_embedding = model.get_input_embeddings()#获取预训练模型嵌入层的权重参数
        self.raw_embedding.requires_grad_(False)#提示不需要更新
        self.model_is_encoder_decoder = model.config.is_encoder_decoder#用于表示模型是否为编码器-解码器结构
        self.random_range = random_range
        self.num_tokens = num_tokens
        self.initialize_from_vocab = initialize_from_vocab

        self.text = text#加载提示和文本,如果没有指定,可能会从from_file加载
        # self.default_text1 = {"placeholder<text_a> <mask>"
        # self.default_text2 = "<text_a> <text_b> <mask>".split()

        if soft_embeds is not None:
            self.soft_embeds = soft_embeds
            self.num_tokens = len(soft_embeds)
        else:#如果没有预先存储的已经得到的软提示初始化参数,并且在软提示前的token个数不为0
            if self.num_tokens>0:
                self.generate_parameters()

这个类的init上面有一段说明:

this is the implementation of `The Power of Scale for Parameter-Efficient
Prompt Tuning <https://arxiv.org/pdf/2104.08691v1.pdf>`_ . Similar to :obj:`PrefixTuningTemplate`,
This template also does not need any textual template. Addition tokens are directly
concatenated into the input ids. There are two initializations of the new tokens.
(1). random initialization. (2) initialize with the tokens of the plm (We simply take
the first n_tokens similar to their implementation).

进入到generate_parameters()方法,这个方法会创建这两种软提示的初始化参数:

    def generate_parameters(self) -> None:
        """
        generate parameters needed for soft tokens embedding in soft-prompt
        for soft tokens, use a new embedding layer which is initialized with their corresponding embedding of hard tokens
        """
        if self.initialize_from_vocab:
            soft_embeds = self.raw_embedding.weight[:self.num_tokens].clone().detach()#这段代码的最终作用是获取模型输入嵌入层中的前self.num_tokens个嵌入向量权重,用这些权重作为softprompt的初始化参数
        else:
            soft_embeds = torch.FloatTensor(self.num_tokens, self.raw_embedding.weight.size(1)).uniform_(-self.random_range, self.random_range)#这段代码的作用是创建一个大小为self.num_tokens行、768列的浮点型张量,并对其进行均匀分布的随机初始化,取值范围为(-0.5,0.5)
        self.soft_embeds = nn.Parameter(soft_embeds, requires_grad=True)

之后就要进入到具体的子任务中去了,以分类任务为例子,他会在类别PromptForClassification设置提示模板。之后在forward中处理模板,具体方法为process_batch:

 def forward(self, batch: Union[Dict, InputFeatures]) -> torch.Tensor:
        r"""
        This is a forward method to make wrapped input data go through the model, and return the output logits.
        Typically, this function aims to predict the ``<mask>`` position.

        Args:
            batch (:obj:`Union[Dict, InputFeatures]`): The input features of batchified data sequences.
        """
        batch = self.template.process_batch(batch)
        input_batch = {key: batch[key] for key in batch if key in self.forward_keys}
        outputs = self.plm(**input_batch, output_hidden_states=True)
        outputs = self.template.post_processing_outputs(outputs)
        return outputs
进入process_batch方法,发现是个抽象方法,所以会由他的子类实现,即在SoftPrompt中可以看到:
    def process_batch(self, batch: Union[Dict, InputFeatures]) -> Union[Dict, InputFeatures]:
        """
        Convert input_ids to inputs_embeds
        for normal tokens, use the embedding layer of PLM
        for soft tokens, use a new embedding layer which is initialized with their corresponding embedding of hard tokens
        """
        inputs_embeds = self.raw_embedding(batch['input_ids'])
        batch_size = inputs_embeds.size(0)
        if self.num_tokens>0:
            soft_embeds = self.soft_embeds.repeat(batch_size, 1, 1)
#repeat是PyTorch库中的一个张量操作方法,用于复制并扩展张量的某一维度。这里的soft_embeds.repeat(batch_size, 1, 1)表示将soft_embeds在第一个维度上复制batch_size次,而在第二个和第三个维度(假设它们分别代表词嵌入的长度和嵌入向量的维度)上保持不变。目的是为了将我们之前的软提示嵌入复制batch_size次,拼接在每个样本前面
            inputs_embeds = torch.cat([soft_embeds, inputs_embeds], 1)

        batch['input_ids'] = None
        batch['inputs_embeds'] = inputs_embeds
        if 'attention_mask' in batch and self.num_tokens>0:
            am = batch['attention_mask']
            batch['attention_mask'] = torch.cat([torch.ones((batch_size,self.num_tokens), dtype = am.dtype,device=am.device), am], dim=-1)
        return batch

这里还有一个很重要的地方需要注明,即提示的长度和文本的长度超过模型最大长度该如何处理。

在openprompt中,他们两个是分开处理,之后在刚刚的process_batch中组装起来的。所以说,如果你设置软提示的长度为n,那么你的输入文本就只有max_sentence_length - n的长度。这里处理文本的长度主要是使用truncate_method。

流程我们可以看到第一步:

train_dataloader = PromptDataLoader(dataset=dataset["train"], template=mytemplate, tokenizer=tokenizer,
                                    tokenizer_wrapper_class=WrapperClass, max_seq_length=max_length,
                                    batch_size=batch_size, shuffle=True, teacher_forcing=False, predict_eos_token=False,
                                    truncate_method="head")

这里设置了truncate的不同方式,我们进入这个类里面去查看,可以看到init中定义了这样一段代码:

    tokenizer_wrapper_init_keys = signature(tokenizer_wrapper_class.__init__).args
            prepare_kwargs = {
                "max_seq_length" : max_seq_length,
                "truncate_method" : truncate_method,
                "decoder_max_length" : decoder_max_length,
                "predict_eos_token" : predict_eos_token,
                "tokenizer" : tokenizer,
                **kwargs,
            }

            to_pass_kwargs = {key: prepare_kwargs[key] for key in prepare_kwargs if key in tokenizer_wrapper_init_keys}
            self.tokenizer_wrapper = tokenizer_wrapper_class(**to_pass_kwargs)

说明最终的处理是在tokenizer_wrapper_class类中处理的,这个类可以在我们加载模型的配置上看到为MLMTokenizerWrapper类,这个继承了TokenizerWrapper类,这个里面就有truncate的处理逻辑。

    def truncate(self, encoder_inputs):
        total_tokens = sum([len(part) for part in encoder_inputs['input_ids']])
        num_specials = self.num_special_tokens_to_add
        num_tokens_to_truncate = total_tokens - self.max_seq_length + num_specials#可以很清楚的看到截取的时候并没有考虑softTemplate的部分
        self.total_passed_sentences+=1
        if num_tokens_to_truncate>0:
            self.num_truncated_sentences += 1
            encoder_inputs = self.truncate_fct(input_dict=encoder_inputs,
                          num_tokens_to_truncate=num_tokens_to_truncate)
        return encoder_inputs

这个方法的调用是在PromptDataloader中的tokenize中,这个方面里面又调用了takenize_one_example方法,这个方法的实现是在MLMTokenizerWrapper类中。

最后我们回到分类的之前的forward方法中,里面有这样一行代码:

outputs = self.template.post_processing_outputs(outputs)

点进这个方法可以看到:

    def post_processing_outputs(self, outputs: torch.Tensor):
        r"""Post processing the outputs of language models according
        to the need of template. Most templates don't need post processing,
        The template like SoftTemplate, which appends soft template as a module
        (rather than a sequence of input tokens) to the input,
        should remove the outputs on these positions to keep the seq_len the same
        """
        if not self.model_is_encoder_decoder:
            outputs.logits = outputs.logits[:, self.num_tokens:,: ]
        return outputs

在这里模型将前n个token输出移除,去掉软提示以便于之后的映射器部分进行正确分类,至此已经将softTemplate全部讲完了。

  • 9
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值