基于Prompt的MLM文本分类-v2

本文研究了基于Prompt的MLM(Masked Language Modeling)文本分类方法,通过自动寻找Prompt并调整多种参数,如LSTM层数和学习率。实验结果显示,即使在小样本情况下,该方法相比直接使用BERT进行分类仍能取得一定优势,提供了一种新的文本分类策略。
摘要由CSDN通过智能技术生成

自动寻找Prompt

实验版本好多参数可调

import os
import torch
import logging
import datasets
import transformers
import numpy as np
import torch.nn as nn
from sklearn import metrics
from datasets import Dataset
from torch.nn import CrossEntropyLoss
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import Trainer, TrainingArguments, BertTokenizer, BertForMaskedLM
from transformers.modeling_outputs import MaskedLMOutput

os.environ['CUDA_VISIBLE_DEVICES'] = '1'
transformers.set_seed(1)
logging.basicConfig(level=logging.INFO)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
prp_len = 2 #prompt token长度

# 通过LSTM寻找prompt的embedding
class MyModel(BertForMaskedLM):
    def __init__(self, config):
        super().__init__(config)
        self.dim = 384
        self.emb = nn.Embedding(prp_len+1, self.dim)
        self.bi_lstm = nn.LSTM(self.dim, self.dim, 2, bidirectional=True)
        self.b_emb = self.get_input_embeddings()
        self.line1 = nn.Linear(768, 768)
        self.line2 = nn.Linear(768, 768)
        self.line3 = nn.Linear(768, 768)
        self.relu = nn.ReLU()
        
    def forward(
        self,
        input_ids=None,  # [CLS] e(p) e(p) [MASK] e(input_ids)
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,   # [CLS] -100 -100 label e(input_ids)
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        
        p = self
  • 6
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 19
    评论
Prompt Learning是一种用于文本分类的方法,它将不同的自然语言处理任务转化为文本分类问题。与传统的BERT fine-tuning方法不同,Prompt Learning使用了一些特殊的技术和策略来提高分类效果。 在Prompt Learning中,有几个重要的概念,包括Template、Verbalizer和PromptModel。Template是一种用于构造提示的模板,它定义了输入文本和输出标签之间的关系。Verbalizer是一种将标签映射到自然语言描述的方法,它用于生成提示中的掩码。PromptModel是一个基于提示学习的文本分类模型,它使用预训练语言模型学习到的特征和标签文本来初始化分类器参数。 Prompt Learning适用于标注成本高、标注样本较少的文本分类场景,尤其在小样本场景中表现出更好的效果。它能够充分利用预训练语言模型学习到的特征和标签文本,从而降低样本量需求。此外,PaddleNLP还集成了一些前沿策略,如R-Drop和RGL,以帮助提升模型效果。 总之,Prompt Learning是一种用于文本分类的方法,通过构造提示和利用预训练语言模型的特征来提高分类效果,特别适用于标注成本高、标注样本较少的场景。 #### 引用[.reference_title] - *1* [Prompt-Learning](https://blog.csdn.net/weixin_42223207/article/details/122954172)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down1,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* *3* [效果超强!基于Prompt Learning、检索思路实现文本分类,开源数据增强、可信增强技术](https://blog.csdn.net/PaddlePaddle/article/details/126968241)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down1,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]
评论 19
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值