Automatically Finding Prompts
Experimental version; many of the parameters are tunable.
import os
import torch
import logging
import datasets
import transformers
import numpy as np
import torch.nn as nn
from sklearn import metrics
from datasets import Dataset
from torch.nn import CrossEntropyLoss
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import Trainer, TrainingArguments, BertTokenizer, BertForMaskedLM
from transformers.modeling_outputs import MaskedLMOutput
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
transformers.set_seed(1)
logging.basicConfig(level=logging.INFO)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
prp_len = 2  # number of pseudo prompt tokens
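As a quick shape check before the model definition: the class below follows the P-tuning idea of generating the prompt embeddings with a small encoder (an Embedding table, a bidirectional LSTM, then an MLP). The sketch here is an illustrative assumption about how the declared modules fit together; the helper names and the two-layer MLP composition are my own (the class actually declares three Linear layers), so this is not the exact forward pass.
# Illustrative sketch only: how prp_len pseudo-token ids could be turned into
# 768-d prompt embeddings with modules shaped like the ones MyModel declares below.
_ids = torch.arange(prp_len)                      # pseudo prompt token ids: tensor([0, 1])
_emb = nn.Embedding(prp_len + 1, 384)             # same shape as MyModel.emb
_lstm = nn.LSTM(384, 384, 2, bidirectional=True)  # bidirectional: output dim 2 * 384 = 768
_mlp = nn.Sequential(nn.Linear(768, 768), nn.ReLU(), nn.Linear(768, 768))  # assumed head
_h, _ = _lstm(_emb(_ids).unsqueeze(1))            # nn.LSTM expects (seq, batch, feat) -> (2, 1, 768)
_prompt = _mlp(_h.squeeze(1))                     # (prp_len, 768): ready to splice in front of e(input_ids)
print(_prompt.shape)                              # torch.Size([2, 768]), matches BERT's hidden size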
# Find the prompt embeddings via an LSTM
class MyModel(BertForMaskedLM):
    def __init__(self, config):
        super().__init__(config)
        self.dim = 384
        # embedding table for the prp_len pseudo prompt tokens (one extra row is allocated)
        self.emb = nn.Embedding(prp_len + 1, self.dim)
        # 2-layer bidirectional LSTM: output size is 2 * 384 = 768, matching BERT's hidden size
        self.bi_lstm = nn.LSTM(self.dim, self.dim, 2, bidirectional=True)
        # BERT's own word-embedding table, used to embed the real input tokens
        self.b_emb = self.get_input_embeddings()
        self.line1 = nn.Linear(768, 768)
        self.line2 = nn.Linear(768, 768)
        self.line3 = nn.Linear(768, 768)
        self.relu = nn.ReLU()
    def forward(
        self,
        input_ids=None,   # [CLS] e(p) e(p) [MASK] e(input_ids): two pseudo prompt tokens and a [MASK] precede the text
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,      # [CLS] -100 -100 label e(input_ids): only the [MASK] slot holds a real label id; -100 positions are ignored by the loss
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        p = self  # alias self as p for brevity in the rest of forward