import copy
import os
import sys
dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, dir_path)
import contextlib
import torch.utils.checkpoint
from torch.nn import LayerNorm
from torch import nn
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from modeling_perceive_sampler import BertConfig, BertLMHeadModel
from transformers.utils import logging
logger = logging.get_logger(__name__)
import transformers
from transformers import PreTrainedModel, AutoTokenizer, AutoModelForMaskedLM,AutoModel,BertTokenizer,GPT2LMHeadModel,PretrainedConfig,GPT2Model,GPT2Tokenizer,LongformerTokenizer, LongformerModel
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # use the first GPU
import argparse
import math
class RALLM(nn.Module):
    """Retrieval-Augmented LLM.

    Encodes long context passages with a frozen text encoder, scores and
    (optionally) compresses their tokens, refines the kept tokens with a
    small 2-layer Q-Former, projects them into the backbone LLM's embedding
    space, and injects them into the prompt between start/end flag vectors.
    The backbone is a Baichuan2 causal LM.
    """

    def __init__(self, args):
        super(RALLM, self).__init__()
        self.is_compress = args.is_compress  # whether to drop low-scoring context tokens
        self.use_lora = args.use_lora        # LoRA wrapping changes the embed_tokens path
        print('Init LLM ... ')
        # Select the backbone checkpoint and its hidden width.
        if args.LLM_model == "Baichuan2_13B":
            self.LLM_model_name = "Baichuan2-13B-Chat"
            self.LLM_hidden_size = 5120
        elif args.LLM_model == "Baichuan2_7B":
            self.LLM_model_name = "baichuan2_7B"
            self.LLM_hidden_size = 4096
        self.LLM_model = transformers.AutoModelForCausalLM.from_pretrained(
            self.LLM_model_name,
            device_map=f"cuda:{args.local_rank}",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            # cache_dir=training_args.cache_dir,
        )
        self.LLM_tokenizer = transformers.AutoTokenizer.from_pretrained(
            self.LLM_model_name,
            use_fast=False,
            trust_remote_code=True,
            model_max_length=4096,
            # cache_dir=training_args.cache_dir,
        )
        # Frozen marker embeddings placed around the injected context.
        self.flag_context_start = nn.Parameter(torch.zeros([1, 1, self.LLM_hidden_size]))  # .to(self.device)
        self.flag_context_end = nn.Parameter(torch.zeros([1, 1, self.LLM_hidden_size]))  # .to(self.device)
        self.flag_context_start.requires_grad = False
        self.flag_context_end.requires_grad = False
        self.device = self.LLM_model.device
        # Baichuan2 special ids: 195 = user turn, 196 = assistant turn, 2 = eos.
        self.user_token = self.LLM_tokenizer._convert_id_to_token(195)
        self.assisent_token = self.LLM_tokenizer._convert_id_to_token(196)
        self.eoa = self.LLM_tokenizer._convert_id_to_token(2)
        print("user_token:", self.user_token, "assisent_token:", self.assisent_token, "eoa:", self.eoa)
        print('Done')
        print('Init context encoder ... ')
        self.init_context_encoder(args)
        print('Done')

    def init_Qformer(self, num_query_token, num_features):
        """Build the Q-Former and strip sub-modules unused in this setup
        (input embeddings, per-layer text FFN, LM head)."""
        self.Qformer = self.init_qformer(num_query_token, num_features, cross_attention_freq=1)
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None
        self.Qformer.cls = None

    @classmethod
    def init_qformer(cls,
                     num_query_token,
                     vision_width,
                     cross_attention_freq=2,
                     pretrain=True):
        """Create a 2-layer BertLMHeadModel configured as a querying
        transformer whose cross-attention reads `vision_width`-dim states.

        NOTE(review): `pretrain` is accepted but unused here.
        """
        encoder_config = BertConfig()
        encoder_config.num_hidden_layers = 2
        encoder_config.hidden_size = vision_width
        encoder_config.encoder_width = vision_width
        encoder_config.num_attention_heads = vision_width // 64
        # insert cross-attention layer every other block
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)
        return Qformer

    def init_context_encoder(self, args):
        """Load the context encoder/tokenizer selected by `args.encoder`
        and build the scoring head, projections, LayerNorm, and Q-Former."""
        num_query_token = args.query_tokens = 0
        if args.encoder == "bert_base":
            self.context_tokenizer = AutoTokenizer.from_pretrained("bert_base_chinese")
            self.context_encoder = AutoModelForMaskedLM.from_pretrained("bert_base_chinese", output_hidden_states=True)
            num_features = 768
        if args.encoder == "bert_large":
            self.context_tokenizer = AutoTokenizer.from_pretrained("bert_large_chinese", max_length=2000)
            self.context_encoder = AutoModelForMaskedLM.from_pretrained("bert_large_chinese", output_hidden_states=True)
            num_features = 1024
        if args.encoder == "gpt2_xlarge":
            self.context_tokenizer = BertTokenizer.from_pretrained("gpt2_chinese_xlarge")
            self.context_encoder = GPT2LMHeadModel.from_pretrained("gpt2_chinese_xlarge")
            num_features = 1600
        if args.encoder == "gpt2_large":
            self.context_tokenizer = BertTokenizer.from_pretrained("gpt2_chinese_large")
            self.context_encoder = GPT2LMHeadModel.from_pretrained("gpt2_chinese_large")
            num_features = 1280
        if args.encoder == "gpt2_large_en":
            self.context_tokenizer = GPT2Tokenizer.from_pretrained("/data2/xinyuuliu/Baichuan2_qformer_bert/gpt2-large-EN")
            # GPT-2 has no pad token by default; add one for batch padding.
            self.context_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            self.context_encoder = GPT2Model.from_pretrained("/data2/xinyuuliu/Baichuan2_qformer_bert/gpt2-large-EN")
            num_features = 1280
        if args.encoder == "longformer":
            self.context_tokenizer = LongformerTokenizer.from_pretrained('longformer')
            self.context_encoder = LongformerModel.from_pretrained('longformer')
            num_features = 768
        if args.encoder == "longformer_large":
            self.context_tokenizer = LongformerTokenizer.from_pretrained('longformer-large')
            self.context_encoder = LongformerModel.from_pretrained('longformer-large')
            num_features = 1024
        # bert_tokenizer = AutoTokenizer.from_pretrained("bert_base_chinese", max_length=2000)
        # bert_encoder = AutoModelForMaskedLM.from_pretrained("longformer_zh", output_hidden_states=True)  # .to(device)
        self.context_encoder = self.context_encoder.to(self.device)
        # Token scoring head: encoder hidden dim -> 64 -> 1 score per token.
        self.context_score = torch.nn.ModuleList([
            torch.nn.Linear(num_features, 64),
            torch.nn.Tanh(),
            torch.nn.Linear(64, 1),
        ])
        # Projections from encoder hidden dim into the LLM embedding space.
        self.context2llm_proj = torch.nn.Linear(num_features, self.LLM_hidden_size)
        self.llm_proj = torch.nn.Linear(num_features, self.LLM_hidden_size)
        # model.embed2qformer_proj = torch.nn.Linear(num_features, 768)
        self.ln_features = LayerNorm(num_features)
        self.init_Qformer(num_query_token, num_features)
        # del model.internlm_proj
        # del model.Qformer
        # torch.cuda.empty_cache()  # free GPU memory
        # if device:
        #     model = self.model.to(self.device)

    def encode_text(self, text, add_special_tokens=False):
        """Tokenize `text` with the LLM tokenizer and return its input
        embeddings, shape (1, seq, LLM_hidden_size).

        NOTE(review): `add_special_tokens` is accepted but not forwarded
        to `encode` — confirm whether that is intended.
        """
        input_ids = self.LLM_tokenizer.encode(text)
        input_ids = torch.LongTensor([input_ids]).to(self.device)
        if self.use_lora:
            # LoRA wrapper nests the original model one level deeper.
            text_embeds = self.LLM_model.base_model.model.model.embed_tokens(input_ids)
        else:
            text_embeds = self.LLM_model.model.embed_tokens(input_ids)
        return text_embeds

    def calculate_compressibility(self, x, k=0):
        """Fraction of tokens to keep for a sequence of length `x`.

        Non-linear schedule: short sequences are barely compressed, long
        ones heavily. With the default k=0 this reduces to
        111.111 / (x + 111.111).
        """
        return (x * k * (9 / 1000) + 1) * 111.111 / (x + 111.111)

    # Batch-encode a list of sentences into one padded id tensor.
    def batch_input_sentences(self, sentences):
        """Tokenize each sentence (padded/truncated to 2500), then pad all
        to the batch max length and stack into (batch, max_len)."""
        input_ids_list = [self.context_tokenizer.encode(sentence, return_tensors="pt", padding='max_length', max_length=2500, truncation=True) for sentence in sentences]
        max_length = max(len(input_ids[0]) for input_ids in input_ids_list)
        # Zero-pad on the right up to the longest sequence (pad id assumed 0 —
        # TODO confirm for every supported tokenizer).
        input_ids_padded = [torch.cat([input_ids, torch.zeros(1, max_length - input_ids.size(1), dtype=torch.long)], dim=1) for input_ids in input_ids_list]
        input_ids_tensor = torch.cat(input_ids_padded, dim=0)
        return input_ids_tensor

    def encode_context(self, text_list):
        """Encode each context (a list of sentences) into LLM-space embeddings.

        Per context: score every token, keep the top-K (K = full length
        unless `is_compress`), feed the kept hidden states as Q-Former
        queries over the full encoder states, project to the LLM hidden
        size, and frame with the start/end flag embeddings.

        Returns (inputs_LLMs, input_atts): per-context flattened embedding
        tensors and attention masks.
        """
        if text_list is None:
            return None
        inputs_LLMs = []
        input_atts = []
        # print(text_list)
        for text in text_list:
            # input_ids = self.context_tokenizer.encode(text, add_special_tokens=True, return_tensors="pt")
            # Encode the sentence batch with fixed-length padding.
            # encoded_ids = self.context_tokenizer(text, padding=True, return_tensors="pt", truncation=True)
            input_ids = self.batch_input_sentences(text)
            input_ids = input_ids.to(self.device)
            # input_ids = encoded_ids.data["input_ids"].to(self.device)
            # attention_mask = encoded_ids.data["attention_mask"].to(self.device)
            outputs = self.context_encoder(input_ids, output_hidden_states=True)
            # First (embedding layer) and last hidden states of the encoder.
            embedding, last_hidden_state = outputs.hidden_states[0], outputs.hidden_states[-1]  # outputs.logits
            # Apply the scoring head sequentially (ModuleList, not Sequential).
            x = last_hidden_state
            for layer in self.context_score:
                x = layer(x)
            output = x  # per-token importance scores, shape (batch, seq, 1)
            # output = self.context_score(last_hidden_state)
            batch, seq_len, ebd_dim = last_hidden_state.size()
            # compressibility = -0.0009 * seq_len + 1
            #   (rejected linear schedule: x*f(x) is not monotonically decreasing)
            # compressibility = 111.111 / (seq_len + 111.111)
            #   (non-linear schedule: almost no compression below ~10 tokens)
            if self.is_compress:
                compressibility = self.calculate_compressibility(seq_len, 0)
                K = math.ceil(seq_len * compressibility)
            else:
                K = seq_len
            # Indices of the K highest-scoring tokens along the sequence dim.
            topk_indices = torch.topk(output, K, dim=1).indices
            # print(topk_indices)
            topk_indices, sorted_indices = torch.sort(topk_indices, dim=1)  # restore original document order
            # print(topk_indices)
            # Gather the hidden states of the selected tokens.
            topk_selected_last_hidden_state = torch.gather(last_hidden_state, 1, topk_indices.expand(-1, -1, ebd_dim))
            # print(last_hidden_state)
            # print(topk_selected_last_hidden_state)
            topk_selected_embedding = torch.gather(embedding, 1, topk_indices.expand(-1, -1, ebd_dim))
            # bert_text_atts = torch.gather(attention_mask, 1, torch.squeeze(topk_indices, dim=2))
            bert_text_embeds = self.ln_features(last_hidden_state)
            bert_text_atts = torch.ones(bert_text_embeds.size()[:-1], dtype=torch.long).to(self.device)
            # query_tokens = self.query_tokens.expand(bert_text_atts.shape[0], -1, -1)
            # query_tokens = topk_selected_embedding
            query_tokens = topk_selected_last_hidden_state
            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=bert_text_embeds,
                encoder_attention_mask=bert_text_atts,
                return_dict=True,
            )
            # topk_context_hidden_state = self.context2llm_proj(topk_selected_last_hidden_state)
            inputs_LLM = self.llm_proj(query_output.last_hidden_state)
            inputs_LLM = torch.cat([
                self.flag_context_start.expand(batch, -1, -1),
                # topk_context_hidden_state,
                inputs_LLM,
                self.flag_context_end.expand(batch, -1, -1)
            ], dim=1).view(-1, self.LLM_hidden_size)
            # NOTE(review): with compression on, inputs_LLM has K+2 rows per
            # sample but bert_text_atts spans seq_len columns, so this mask's
            # length will not match inputs_LLM — confirm intended behavior.
            input_att = torch.cat([torch.ones((batch, 1)).to(self.device), bert_text_atts, torch.ones((batch, 1)).to(self.device)], dim=1).view(-1)
            # print(inputs_LLM.shape)
            inputs_LLMs.append(inputs_LLM)
            input_atts.append(input_att)
        # context_inputs = torch.stack(inputs_LLMs)
        return inputs_LLMs, input_atts

    def wrap_prompt(self,
                    text_embeds,
                    context_embeds=None,
                    history=None,
                    add_special=True):
        """Assemble [user_token, text, context, assistant_token] embeddings,
        optionally prefixed by the embeddings of previous turns."""
        if add_special:
            if history is None:
                prompt_segs = [
                    self.user_token,
                    self.assisent_token
                ]
            else:
                prompt_segs = [self.user_token, self.assisent_token]
        else:
            prompt_segs = [self.user_token, self.assisent_token]  # used in wrap history
        prompt_seg_embeds = []
        for i, seg in enumerate(prompt_segs):
            if history is not None:
                add_special_tokens = False
            else:
                # Only the first segment of a fresh conversation gets specials.
                add_special_tokens = i == 0
            seg_embeds = self.encode_text(
                seg, add_special_tokens=add_special_tokens)
            prompt_seg_embeds.append(seg_embeds)
        if context_embeds is None:
            # No context: empty placeholder with matching batch/feature dims.
            context_embeds = text_embeds.new_empty(text_embeds.size(0), 0,
                                                   text_embeds.size(-1))
        else:
            # Take the first context's embeddings and add a batch dimension.
            context_embeds = context_embeds[0].unsqueeze(0)
        prompt_seg_embeds = [
            prompt_seg_embeds[0], text_embeds, context_embeds, prompt_seg_embeds[1]
        ]
        prompt_embeds = torch.cat(prompt_seg_embeds, dim=1)
        if history is not None:
            prompt_embeds = torch.cat([*history, prompt_embeds], dim=1)
        return prompt_embeds

    def generate(self, text, context=None, **kwargs):
        """Generate a response to `text`, optionally conditioned on
        `context` passages. Returns the decoded response string."""
        # Strip the context placeholder and role tokens from the raw input.
        text = text.replace("<context>", "").replace(self.user_token, "").replace(self.assisent_token, "")
        text_embeds = self.encode_text(text)
        context_embeds, _ = self.encode_context(context)
        prompt_embeds = self.wrap_prompt(text_embeds, context_embeds)
        # out_embeds = self.LLM_model.generate(input_ids=None,
        #     inputs_embeds=prompt_embeds, **self.get_gen_args(**kwargs))
        # out_text = self.decode_text(out_embeds)
        outputs = self.LLM_model.generate(input_ids=None, inputs_embeds=prompt_embeds, generation_config=self.LLM_model.generation_config)
        response = self.LLM_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response

    def chat(self, text, context=None, history=None, **kwargs):
        """Multi-turn chat: generate a reply and append the turn's embeddings
        to `history`.

        NOTE(review): this method references `self.internlm_model`,
        `self.tokenizer`, `self.get_gen_args` and `self.decode_text`, none of
        which are defined on this class (apparent leftovers from an
        InternLM-based implementation) — calling it will raise
        AttributeError. Also `encode_context` returns a tuple here but it is
        not unpacked as in `generate`.
        """
        text_embeds = self.encode_text(text)
        img_embeds = self.encode_context(context)
        prompt_embeds = self.wrap_prompt(text_embeds,
                                         img_embeds,
                                         history=history)
        out_embeds = self.internlm_model.generate(
            inputs_embeds=prompt_embeds, **self.get_gen_args(**kwargs))
        out_text = self.decode_text(out_embeds)
        # trunc at eoh and eoa
        clean_out_text_token_ids = self.tokenizer(
            out_text, return_tensors='pt').input_ids.to(self.device)
        clean_out_text_embeds = self.internlm_model.model.embed_tokens(
            clean_out_text_token_ids)
        clean_prompt_embeds = self.wrap_prompt(text_embeds,
                                               img_embeds,
                                               add_special=False)
        cur_history = torch.cat([clean_prompt_embeds, clean_out_text_embeds],
                                dim=1)
        if history is None:
            history = []
        history.append(cur_history)
        return out_text, history

    def align_text(self, samples, has_context=False):
        """Append the end-of-answer token to each sample; with
        `has_context`, keep only the text after the "<context>" marker."""
        text_new = []
        if has_context:  # drop the prefix so context features can be wrapped in its place
            text = [
                t.split("<context>")[-1] for t in samples["text_input"]
            ]
        else:
            text = [t for t in samples["text_input"]]
        text = [t + self.eoa for t in text]
        for i in range(len(text)):
            temp = text[i]
            # temp = temp.replace('<|Bot|>', self.eoh + ' <|Bot|>')
            # if temp.find(self.eoh) > temp.find(self.eoa):
            #     temp = temp.replace(self.eoa, '', 1)
            text_new.append(temp)
        return text_new

    def prompt_wrap(self, context_embeds, context_atts, prompt_list):
        """Prepend the embeddings of each prompt's text before "<context>"
        to the matching per-sample context embeddings.

        Returns (embeds, masks, targets); targets are -100 (ignored by the
        loss) except index 0, which is set to eos id 2.
        """
        batch_size = len(context_embeds)
        p_before = [prompt.split('<context>')[0] for prompt in prompt_list]
        p_before_tokens = self.LLM_tokenizer(p_before,
                                             padding=True,
                                             truncation=True,
                                             return_tensors="pt",
                                             add_special_tokens=True).to(
                                                 self.device)
        if self.use_lora:
            p_before_embeds = self.LLM_model.base_model.model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
        else:
            p_before_embeds = self.LLM_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
        # wrapped_context_embeds = torch.cat([p_before_embeds, context_embeds], dim=1)
        # wrapped_context_embeds = torch.cat([p_before_embeds] + context_embeds, dim=1)
        wrapped_context_embeds = []
        wrapped_atts_context = []
        wrapped_target = []
        for i, (context_embed, context_att) in enumerate(zip(context_embeds, context_atts)):
            # Concatenate prompt prefix and this sample's context along the sequence dim.
            concatenated = torch.cat((p_before_embeds[i], context_embed), dim=0)
            wrapped_context_embeds.append(concatenated)
            # concatenated_att = torch.cat((torch.ones(p_before_embeds[i].size()[:-1], dtype=torch.long).to(self.device), context_att), dim=0)
            wrapped_atts_context.append(torch.ones(concatenated.size()[:-1], dtype=torch.long).to(self.device))
            # wrapped_atts_context.append(concatenated_att)
            target = torch.ones(concatenated.size()[:-1], dtype=torch.long) * -100
            target[0] = 2
            target = target.to(self.device)
            wrapped_target.append(target)
        # wrapped_atts_context = torch.ones(wrapped_context_embeds.size()[:-1],
        #                                   dtype=torch.long).to(self.device)
        # wrapped_target = torch.ones(
        #     batch_size, wrapped_context_embeds.shape[1], dtype=torch.long).to(
        #         self.device) * -100
        return wrapped_context_embeds, wrapped_atts_context, wrapped_target

    def text2emb(self, text):
        """Tokenize the answer text (no specials, longest-padded) and build
        the loss targets via `mask_human_targets`."""
        to_regress_tokens = self.LLM_tokenizer(text,
                                               return_tensors="pt",
                                               padding="longest",
                                               truncation=True,
                                               max_length=4096,
                                               add_special_tokens=False).to(
                                                   self.device)
        targets = self.mask_human_targets(to_regress_tokens.input_ids)
        targets = targets.to(self.device)
        return to_regress_tokens, targets

    def mask_human_targets(self, input_ids, pure=False):
        """Build LM targets: mask everything up to and including each
        assistant-turn marker (id 196) and all padding (id 0) with -100.

        NOTE(review): `cur_idx` is never advanced, so every id-196 marker
        masks from the start of the sequence; `pure`, `last_eoa` and
        `last_eoh` are unused — confirm for multi-turn samples.
        """
        target_batch = []
        for bs in range(input_ids.shape[0]):
            cur_idx = 0
            ids = input_ids[bs]
            targets = copy.deepcopy(ids)
            last_eoa = 0
            last_eoh = 0
            for i, temp_id in enumerate(ids):
                if temp_id == 196:  # end of the human turn
                    targets[cur_idx:i + 1] = -100
            target_batch.append(targets.unsqueeze(0))
        target_batch = torch.cat(target_batch, dim=0)
        target_batch[target_batch == 0] = -100  # mask padding
        # print(input_ids)
        # print(target_batch)
        return target_batch

    def forward(self,
                input_ids=None,
                attention_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                context=None,
                text_input=None,
                **kwargs):
        """Training forward pass: embed the (optionally context-wrapped)
        text, pad to a uniform length, and return the backbone LM output
        (including the loss from `labels=targets`)."""
        # samples = kwargs  # .get('samples')
        # has_context = 'context' in samples.keys()
        if context:
            has_context = True
        else:
            has_context = False
        samples = {"text_input": text_input, "context": context}
        ### encode text (only the part after "<context>")
        text = self.align_text(samples=samples, has_context=has_context)
        to_regress_tokens, targets = self.text2emb(text)  # tokens + loss targets
        if self.use_lora:
            to_regress_embeds = self.LLM_model.base_model.model.model.embed_tokens(to_regress_tokens.input_ids)
        else:
            to_regress_embeds = self.LLM_model.model.embed_tokens(to_regress_tokens.input_ids)
        attention_mask = to_regress_tokens.attention_mask
        if has_context:
            prompt = samples["text_input"]
            ### encode context
            context = samples["context"]
            context_embeds, context_atts = self.encode_context(context)
            context_embeds, atts_context, wrapped_target = self.prompt_wrap(
                context_embeds, context_atts, prompt)
            ### combine text and context, sample by sample
            to_regress_embeds_ = []
            attention_mask_ = []
            targets_ = []
            for i, (tensor0, tensor1, tensor2) in enumerate(zip(to_regress_embeds, attention_mask, targets)):
                # Prepend this sample's wrapped context to its text part.
                to_regress_embed = torch.cat((context_embeds[i], tensor0), dim=0)
                to_regress_embeds_.append(to_regress_embed)
                attention_m = torch.cat((atts_context[i], tensor1), dim=0)
                attention_mask_.append(attention_m)
                target = torch.cat((wrapped_target[i], tensor2), dim=0)
                targets_.append(target)
            # to_regress_embeds = torch.cat([context_embeds, to_regress_embeds],
            #                               dim=1)
            # attention_mask = torch.cat([atts_context, attention_mask], dim=1)
            # targets = torch.cat([wrapped_target, targets], dim=1)
            # Determine the longest sequence in the batch.
            max_len = max(t.size(0) for t in to_regress_embeds_)
            # Right-pad embeddings (zeros), masks (0) and targets (-100).
            padded_to_regress_embeds_ = []
            padded_attention_mask_ = []
            padded_targets_ = []
            for (t, a, l) in zip(to_regress_embeds_, attention_mask_, targets_):
                if t.size(0) < max_len:
                    # How much padding this sample needs.
                    padding_size = max_len - t.size(0)
                    # Pad along the sequence dimension.
                    padded_regress = torch.nn.functional.pad(t, (0, 0, 0, padding_size))
                    padded_attention = torch.nn.functional.pad(a, (0, padding_size), value=0)
                    padded_target = torch.nn.functional.pad(l, (0, padding_size), value=-100)
                    padded_to_regress_embeds_.append(padded_regress)
                    padded_attention_mask_.append(padded_attention)
                    padded_targets_.append(padded_target)
                else:
                    padded_to_regress_embeds_.append(t)
                    padded_attention_mask_.append(a)
                    padded_targets_.append(l)
            # Stack into batch tensors.
            to_regress_embeds = torch.stack(padded_to_regress_embeds_)
            attention_mask = torch.stack(padded_attention_mask_)
            targets = torch.stack(padded_targets_)
        outputs = self.LLM_model(
            inputs_embeds=to_regress_embeds,
            attention_mask=attention_mask,
            return_dict=True,
            labels=targets,
        )
        return outputs
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--output", default="output", type=str)
parser.add_argument("--encoder", default="gpt2_large", type=str)
parser.add_argument("--query_tokens", default=32, type=int)
parser.add_argument("--load_path", default="/data2/xinyuuliu/InternLM-XComposer/output_rerank", type=str)
parser.add_argument("--local_rank", default="0", type=str)
args = parser.parse_args()
model = RALLM(args)
print(model)
# model.encode_context("我爱北京天安门")
# model.encode_text("我爱北京天安门")
# #<ContextHere>
# query = "Q:请重复内容:<cont_s><ContextHere><cont_e> \n A:"
# context = ["电饭煲不知道怎么选?想要吃一碗香喷喷的米饭,除了米要好之外,还需要一款性能优秀的电饭煲,所以大家在选购电饭煲的时候,一定要多花点心思看看攻略避免踩雷。我前前后后给亲朋好友选购过不下5台电饭煲,也算是积攒了不少选购经验,今天特意总结了一下想分享给大家。1、容量选择市面上电饭煲容量普遍在3L-5L之间,这个范围的容量足够满足绝大部分家庭使用,3L一般可以满足1-3人的家庭,4L一般可以满足2-5人的家庭,5L一般可以满足2-8人的家庭,如果人口超过8人建议直接选择5L以上的容量,使用会更方便。"]
# model.interleav_wrap(query,context)
# ==== second concatenated file: modeling_perceive_sampler.py ====
"""
* Copyright (c) 2023, salesforce.com, inc.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
* By Junnan Li
* Based on huggingface code base
* https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
"""
import math
from typing import Tuple
import torch
import torch.utils.checkpoint
from torch import Tensor, device
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
MaskedLMOutput,
)
from transformers.modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from transformers.models.bert.configuration_bert import BertConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class BertEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings.

    Unlike vanilla BERT, `forward` can also prepend learned query
    embeddings (Q-Former style) or run on query embeddings alone.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size,
                                            config.hidden_size,
                                            padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                                config.hidden_size)
        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config,
                                               "position_embedding_type",
                                               "absolute")
        self.config = config

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        query_embeds=None,
        past_key_values_length=0,
    ):
        """Embed `input_ids` (word + absolute position), then prepend
        `query_embeds` when given; with `input_ids=None` the query
        embeddings are used alone. LayerNorm and dropout are applied last.
        """
        if input_ids is not None:
            seq_length = input_ids.size()[1]
        else:
            seq_length = 0
        if position_ids is None:
            # Positions continue after any cached past key/values.
            position_ids = self.position_ids[:, past_key_values_length:
                                             seq_length +
                                             past_key_values_length].clone()
        if input_ids is not None:
            embeddings = self.word_embeddings(input_ids)
            if self.position_embedding_type == "absolute":
                position_embeddings = self.position_embeddings(position_ids)
                embeddings = embeddings + position_embeddings
            if query_embeds is not None:
                # Query tokens go in front of the text embeddings.
                embeddings = torch.cat((query_embeds, embeddings), dim=1)
        else:
            embeddings = query_embeds
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
class BertSelfAttention(nn.Module):
    """Multi-head (self or cross) attention.

    When `is_cross_attention`, keys/values are projected from
    `config.encoder_width`-dim encoder states; otherwise from
    `config.hidden_size`-dim hidden states. Supports cached past
    key/values and relative position embeddings.
    """

    def __init__(self, config, is_cross_attention):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
                config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" %
                (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size /
                                       config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            # Keys/values read the (possibly wider) encoder states.
            self.key = nn.Linear(config.encoder_width, self.all_head_size)
            self.value = nn.Linear(config.encoder_width, self.all_head_size)
        else:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config,
                                               "position_embedding_type",
                                               "absolute")
        if (self.position_embedding_type == "relative_key"
                or self.position_embedding_type == "relative_key_query"):
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(
                2 * config.max_position_embeddings - 1,
                self.attention_head_size)
        # When set, cross-attention maps/gradients are stashed for inspection.
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def transpose_for_scores(self, x):
        # (batch, seq, all_head) -> (batch, heads, seq, head_size)
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Returns (context, [attention_probs,] present_key_value)."""
        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None
        if is_cross_attention:
            key_layer = self.transpose_for_scores(
                self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(
                self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            # Append the new keys/values after the cached ones.
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
        mixed_query_layer = self.query(hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer)
        past_key_value = (key_layer, value_layer)
        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer,
                                        key_layer.transpose(-1, -2))
        if (self.position_embedding_type == "relative_key"
                or self.position_embedding_type == "relative_key_query"):
            # Add learned relative-distance terms to the raw scores.
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length,
                                          dtype=torch.long,
                                          device=hidden_states.device).view(
                                              -1, 1)
            position_ids_r = torch.arange(seq_length,
                                          dtype=torch.long,
                                          device=hidden_states.device).view(
                                              1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(
                distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(
                dtype=query_layer.dtype)  # fp16 compatibility
            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum(
                    "bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = (attention_scores +
                                    relative_position_scores_query +
                                    relative_position_scores_key)
        attention_scores = attention_scores / math.sqrt(
            self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask
        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)
        # Mask heads if we want to
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask
        context_layer = torch.matmul(attention_probs_dropped, value_layer)
        # (batch, heads, seq, head_size) -> (batch, seq, all_head)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.all_head_size, )
        context_layer = context_layer.view(*new_context_layer_shape)
        outputs = ((context_layer, attention_probs) if output_attentions else
                   (context_layer, ))
        outputs = outputs + (past_key_value, )
        return outputs
class BertSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertAttention(nn.Module):
    """BertSelfAttention followed by BertSelfOutput (residual + LayerNorm),
    with support for pruning attention heads."""

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given head indices from the Q/K/V and output
        projections and update the bookkeeping."""
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads,
            self.self.num_attention_heads,
            self.self.attention_head_size,
            self.pruned_heads,
        )
        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(
            heads)
        self.self.all_head_size = (self.self.attention_head_size *
                                   self.self.num_attention_heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Returns (attention_output, [attn_probs,] present_key_value)."""
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,
                   ) + self_outputs[1:]  # add attentions if we output them
        return outputs
class BertIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class BertOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertLayer(nn.Module):
    """One Q-Former transformer layer.

    Self-attention over [queries; text], then — on layers where
    `layer_num % cross_attention_freq == 0` — cross-attention of the first
    `query_length` positions into the encoder states. Query positions and
    text positions use separate feed-forward stacks.
    """

    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.layer_num = layer_num
        if (self.config.add_cross_attention
                and layer_num % self.config.cross_attention_freq == 0):
            self.crossattention = BertAttention(
                config, is_cross_attention=self.config.add_cross_attention)
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False
        # Feed-forward stack for text positions.
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)
        # Separate feed-forward stack for query positions.
        self.intermediate_query = BertIntermediate(config)
        self.output_query = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        """Returns (layer_output, [attn maps...,] present_key_value).
        The first `query_length` positions of `hidden_states` are treated
        as query tokens."""
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = (past_key_value[:2]
                                    if past_key_value is not None else None)
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:-1]
        present_key_value = self_attention_outputs[-1]
        if query_length > 0:
            query_attention_output = attention_output[:, :query_length, :]
            if self.has_cross_attention:
                assert (
                    encoder_hidden_states is not None
                ), "encoder_hidden_states must be given for cross-attention layers"
                # Only the query positions attend to the encoder states.
                cross_attention_outputs = self.crossattention(
                    query_attention_output,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_attention_output = cross_attention_outputs[0]
                outputs = (
                    outputs + cross_attention_outputs[1:-1]
                )  # add cross attentions if we output attention weights
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )
            if attention_output.shape[1] > query_length:
                # Text positions (after the queries) use the text FFN stack,
                # then the two halves are re-joined along the sequence dim.
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text],
                                         dim=1)
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )
        outputs = (layer_output, ) + outputs
        outputs = outputs + (present_key_value, )
        return outputs

    def feed_forward_chunk(self, attention_output):
        # FFN for text positions.
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def feed_forward_chunk_query(self, attention_output):
        # FFN for query positions.
        intermediate_output = self.intermediate_query(attention_output)
        layer_output = self.output_query(intermediate_output, attention_output)
        return layer_output
class BertEncoder(nn.Module):
    """Stack of ``BertLayer`` modules with optional gradient checkpointing,
    KV caching, and Q-Former style ``query_length`` forwarding.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [BertLayer(config, i) for i in range(config.num_hidden_layers)])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        """Run all layers, optionally collecting hidden states, attention maps
        and per-layer key/value caches.

        NOTE(review): when ``output_attentions`` is set, ``layer_outputs[2]``
        is collected unconditionally — this presumes every layer emits cross
        attentions; verify against ``BertLayer``'s output layout.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = (() if output_attentions
                                and self.config.add_cross_attention else None)
        next_decoder_cache = () if use_cache else None
        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states, )
            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[
                i] if past_key_values is not None else None
            if getattr(self.config, "gradient_checkpointing",
                       False) and self.training:
                # Checkpointing recomputes activations during backward, so a
                # KV cache cannot be maintained at the same time.
                if use_cache:
                    logger.warn(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    # Close over the non-tensor extras that
                    # torch.utils.checkpoint.checkpoint cannot pass through.
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value,
                                      output_attentions, query_length)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                )
            hidden_states = layer_outputs[0]
            if use_cache:
                # Each layer's present key/value is the last output element.
                next_decoder_cache += (layer_outputs[-1], )
            if output_attentions:
                all_self_attentions = all_self_attentions + (
                    layer_outputs[1], )
                all_cross_attentions = all_cross_attentions + (
                    layer_outputs[2], )
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states, )
        if not return_dict:
            return tuple(v for v in [
                hidden_states,
                next_decoder_cache,
                all_hidden_states,
                all_self_attentions,
                all_cross_attentions,
            ] if v is not None)
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
class BertPooler(nn.Module):
    """Pools a sequence by projecting the first ([CLS]) token through a tanh layer."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # The first token's hidden state stands in for the whole sequence.
        cls_state = hidden_states[:, 0]
        return self.activation(self.dense(cls_state))
class BertPredictionHeadTransform(nn.Module):
    """Dense -> activation -> LayerNorm transform applied before the LM decoder."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # `hidden_act` may be a key into the ACT2FN registry or a callable.
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        return self.LayerNorm(self.transform_act_fn(self.dense(hidden_states)))
class BertLMPredictionHead(nn.Module):
    """Transforms hidden states and projects them onto the vocabulary."""

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)
        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size,
                                 config.vocab_size,
                                 bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        # Tie the bias onto the decoder so `resize_token_embeddings`
        # resizes both together.
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        return self.decoder(self.transform(hidden_states))
class BertOnlyMLMHead(nn.Module):
    """Thin wrapper exposing the LM prediction head as BERT's MLM head."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        return self.predictions(sequence_output)
class BertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0,
                                       std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        # Linear biases are zeroed regardless of which branch above ran.
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
class BertModel(BertPreTrainedModel):
    """
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
    argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
    input to the forward pass.
    """

    def __init__(self, config, add_pooling_layer=False):
        super().__init__(config)
        self.config = config
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        # Pooler is optional; Q-Former usage constructs this without it.
        self.pooler = BertPooler(config) if add_pooling_layer else None
        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def get_extended_attention_mask(
        self,
        attention_mask: Tensor,
        input_shape: Tuple[int],
        device: device,
        is_decoder: bool,
        has_query: bool = False,
    ) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model.
            device: (:obj:`torch.device`):
                The device of the input to the model.
        Returns:
            :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if is_decoder:
                batch_size, seq_length = input_shape
                seq_ids = torch.arange(seq_length, device=device)
                # Lower-triangular [batch, seq, seq]: position i attends to j <= i.
                causal_mask = (seq_ids[None, None, :].repeat(
                    batch_size, seq_length, 1) <= seq_ids[None, :, None])
                # add a prefix ones mask to the causal mask
                # causal and attention masks must have same type with pytorch version < 1.3
                causal_mask = causal_mask.to(attention_mask.dtype)
                if causal_mask.shape[1] < attention_mask.shape[1]:
                    # attention_mask covers extra prefix positions (e.g. query
                    # tokens) that are not part of `input_shape`.
                    prefix_seq_len = attention_mask.shape[
                        1] - causal_mask.shape[1]
                    if has_query:  # UniLM style attention mask
                        # Prefix rows cannot attend to anything causally here;
                        # the ones-block below re-opens the prefix columns.
                        causal_mask = torch.cat(
                            [
                                torch.zeros(
                                    (batch_size, prefix_seq_len, seq_length),
                                    device=device,
                                    dtype=causal_mask.dtype,
                                ),
                                causal_mask,
                            ],
                            axis=1,
                        )
                    # Every (remaining) row may attend to all prefix positions.
                    causal_mask = torch.cat(
                        [
                            torch.ones(
                                (batch_size, causal_mask.shape[1],
                                 prefix_seq_len),
                                device=device,
                                dtype=causal_mask.dtype,
                            ),
                            causal_mask,
                        ],
                        axis=-1,
                    )
                extended_attention_mask = (causal_mask[:, None, :, :] *
                                           attention_mask[:, None, None, :])
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})"
                .format(input_shape, attention_mask.shape))
        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(
            dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        is_decoder=False,
    ):
        r"""
        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        """
        output_attentions = (output_attentions if output_attentions is not None
                             else self.config.output_attentions)
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)
        # use_cache = use_cache if use_cache is not None else self.config.use_cache
        if input_ids is None:
            assert (
                query_embeds is not None
            ), "You have to specify query_embeds when input_ids is None"
        # past_key_values_length
        # The cached length includes the query tokens, which are re-fed every
        # step, hence the subtraction of config.query_length.
        past_key_values_length = (past_key_values[0][0].shape[2] -
                                  self.config.query_length
                                  if past_key_values is not None else 0)
        query_length = query_embeds.shape[1] if query_embeds is not None else 0
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            query_embeds=query_embeds,
            past_key_values_length=past_key_values_length,
        )
        input_shape = embedding_output.size()[:-1]
        batch_size, seq_length = input_shape
        device = embedding_output.device
        if attention_mask is None:
            attention_mask = torch.ones(
                ((batch_size, seq_length + past_key_values_length)),
                device=device)
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if is_decoder:
            # NOTE(review): `input_ids.shape` is used here, so decoder mode
            # with `input_ids is None` (query-only) would raise — confirm
            # callers always pass token ids when is_decoder=True.
            extended_attention_mask = self.get_extended_attention_mask(
                attention_mask,
                input_ids.shape,
                device,
                is_decoder,
                has_query=(query_embeds is not None),
            )
        else:
            extended_attention_mask = self.get_extended_attention_mask(
                attention_mask, input_shape, device, is_decoder)
        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            if type(encoder_hidden_states) == list:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
                    0].size()
            else:
                (
                    encoder_batch_size,
                    encoder_sequence_length,
                    _,
                ) = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size,
                                    encoder_sequence_length)
            if type(encoder_attention_mask) == list:
                encoder_extended_attention_mask = [
                    self.invert_attention_mask(mask)
                    for mask in encoder_attention_mask
                ]
            elif encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape,
                                                    device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(
                    encoder_attention_mask)
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(
                    encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None
        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask,
                                       self.config.num_hidden_layers)
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            query_length=query_length,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = (self.pooler(sequence_output)
                         if self.pooler is not None else None)
        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]
        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
class BertLMHeadModel(BertPreTrainedModel):
    """BertModel with a causal language-modeling head, used as the Q-Former
    decoder (query embeddings are prepended and excluded from the LM loss).
    """

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [
        r"position_ids", r"predictions.decoder.bias"
    ]

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)
        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=True,
        reduction="mean",
    ):
        r"""
        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
            ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        Returns:
        Example::
            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
            >>> import torch
            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            >>> config = BertConfig.from_pretrained("bert-base-cased")
            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
            >>> outputs = model(**inputs)
            >>> prediction_logits = outputs.logits
        """
        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)
        if labels is not None:
            # Training with labels: a KV cache is pointless and wasteful.
            use_cache = False
        if past_key_values is not None:
            # The queries are already baked into the cache; do not re-feed them.
            query_embeds = None
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )
        sequence_output = outputs[0]
        if query_embeds is not None:
            # Drop the query positions; the LM head only scores text tokens.
            sequence_output = outputs[0][:, query_embeds.shape[1]:, :]
        prediction_scores = self.cls(sequence_output)
        if return_logits:
            return prediction_scores[:, :-1, :].contiguous()
        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shifted_prediction_scores = prediction_scores[:, :
                                                          -1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss(reduction=reduction,
                                        label_smoothing=0.1)
            lm_loss = loss_fct(
                shifted_prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1),
            )
            if reduction == "none":
                # Per-sample loss: sum the per-token losses within each sequence.
                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
        if not return_dict:
            output = (prediction_scores, ) + outputs[2:]
            return ((lm_loss, ) + output) if lm_loss is not None else output
        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(self,
                                      input_ids,
                                      query_embeds,
                                      past=None,
                                      attention_mask=None,
                                      **model_kwargs):
        """Assemble the per-step kwargs used by `generate()`."""
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)
        # Queries are always attendable; prepend a mask of ones for them.
        query_mask = input_ids.new_ones(query_embeds.shape[:-1])
        attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]
        return {
            "input_ids":
            input_ids,
            "query_embeds":
            query_embeds,
            "attention_mask":
            attention_mask,
            "past_key_values":
            past,
            "encoder_hidden_states":
            model_kwargs.get("encoder_hidden_states", None),
            "encoder_attention_mask":
            model_kwargs.get("encoder_attention_mask", None),
            "is_decoder":
            True,
        }

    def _reorder_cache(self, past, beam_idx):
        """Reorder each layer's cached states to follow beam-search reshuffling."""
        reordered_past = ()
        for layer_past in past:
            reordered_past += (tuple(
                past_state.index_select(0, beam_idx)
                for past_state in layer_past), )
        return reordered_past
class BertForMaskedLM(BertPreTrainedModel):
    """BertModel with a masked-LM head; query embeddings, when given, are
    stripped from the output before scoring so only text positions are
    evaluated against ``labels``.
    """

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [
        r"position_ids", r"predictions.decoder.bias"
    ]

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)
        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=False,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
        """
        return_dict = (return_dict if return_dict is not None else
                       self.config.use_return_dict)
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )
        # BUG FIX: `sequence_output` was previously assigned only inside the
        # `query_embeds is not None` branch, so calling this head without
        # query embeddings raised NameError. Default to the full encoder
        # output and strip the query positions only when they exist.
        sequence_output = outputs[0]
        if query_embeds is not None:
            sequence_output = outputs[0][:, query_embeds.shape[1]:, :]
        prediction_scores = self.cls(sequence_output)
        if return_logits:
            return prediction_scores
        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1))
        if not return_dict:
            output = (prediction_scores, ) + outputs[2:]
            return (((masked_lm_loss, ) +
                     output) if masked_lm_loss is not None else output)
        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# ===================== dataset_batch.py =====================
# -*- coding: utf-8 -*-
from torch.utils.data import Dataset
import torch
import json
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
import random
class QADataset(Dataset):
    """Dataset that builds compression-training samples for the RALLM model.

    Two sample families are produced in training mode:
      * "retell" samples: 1-10 passages drawn from ``data/corpus.tsv`` that the
        model must reproduce verbatim after compression;
      * "summary" samples: (answer, summary) pairs from the CSV at
        ``data_path``, where the model must summarize the compressed answer.
    In eval mode (``train=False``) only retell prompts without the gold
    continuation are generated.

    Each item is a dict with keys ``context`` (list of passages),
    ``text_input`` (Baichuan-style prompt) and ``label``.
    """

    def __init__(self, data_path, train) -> None:
        super().__init__()
        self.data = []
        data = pd.read_csv(data_path).dropna()
        print(data.columns)
        # Keep only rows short enough for the prompt budget.
        condition = (data['answer'].str.len() <= 1000) & (data['summary'].str.len() <= 500)
        filtered_data = data[condition]
        # Each corpus line's last whitespace-separated field is the passage text.
        with open("data/corpus.tsv", "r") as f_read:
            corpus = [i.split()[-1] for i in f_read.readlines()]
        retell_prompts = ["请复述这段被压缩的内容",
                          "复述这段被压缩的内容",
                          "请将被压缩的内容复述出来",]
        summary_prompts = ["请总结被压缩的信息",
                           "还原被压缩信息的主要内容",
                           "请写出被压缩信息的主要内容",
                           "请对之前压缩的信息进行概括",
                           "请提炼出之前被压缩信息的核心要点",
                           "请归纳一下之前被压缩的内容的主旨"]
        if train:
            # Filter articles that satisfy the length criteria.
            # filtered_data1000 = list(filter(self.filter_by_length1000, data["answer"]))
            for idx in range(5000):
                # if not line or line == "" or len(line) < 50 or len(line) > 2000:
                #     continue
                # Randomly choose how many passages to pack (1 to 10),
                # with one "<context> " placeholder per passage.
                repeat_count = random.randint(1, 10)
                flag_context = "<context> " * repeat_count
                prompt = random.choice(retell_prompts)
                selected_articles = random.sample(corpus, repeat_count)
                selected_articles_ = "[SEP]".join(selected_articles)
                # user token: <reserved_106>, assistant token: <reserved_107>
                text = f'<reserved_106>{prompt}{flag_context}<reserved_107>{selected_articles_}'
                test_data = {"context": selected_articles, "text_input": text, "label": selected_articles_}
                self.data.append(
                    test_data
                )
            # for idx in range(5000):
            #     repeat_count = random.randint(1, 1)
            #     flag_context = "<context> "*repeat_count
            #     selected_articles = random.sample(filtered_data150, repeat_count)
            #     selected_articles_ = " ".join(selected_articles)
            #     text = f'<|User|>:请复述这段话{flag_context} <|Bot|>:{selected_articles_}'
            #     test_data = {"samples":{"context":selected_articles,"text_input":[text]}}
            #     self.data.append(
            #         test_data
            #     )
            for idx, (answer, summary) in enumerate(zip(filtered_data["answer"], filtered_data["summary"])):
                # Single-passage summarization sample; answer clipped to 1000 chars.
                answer = [answer[:1000]]
                flag_context = "<context> "
                prompt = random.choice(summary_prompts)
                # user_token: <reserved_106> assisent_token: <reserved_107> eoa: </s>
                text = f'<reserved_106>{prompt}{flag_context}<reserved_107>{summary}'
                test_data = {"context": answer, "text_input": text, "label": summary}
                self.data.append(
                    test_data
                )
            # for idx in range(10000):
            #     repeat_count = random.randint(1, 1)
            #     flag_context = "<context> "*repeat_count
            #     selected_articles = random.sample(filtered_data1500, repeat_count)
            #     selected_articles_ = " ".join(selected_articles)
            #     text = f'<|User|>:请复述这段话{flag_context} <|Bot|>:{selected_articles_}'
            #     test_data = {"samples":{"context":selected_articles,"text_input":[text]}}
            #     self.data.append(
            #         test_data
            #     )
            print("data load , size:", len(self.data))
        else:
            for idx in range(100):
                # if not line or line == "" or len(line) < 50 or len(line) > 2000:
                #     continue
                # Eval prompts pack 3-5 passages and omit the gold continuation.
                repeat_count = random.randint(3, 5)
                flag_context = "<context> " * repeat_count
                prompt = random.choice(retell_prompts)
                selected_articles = random.sample(corpus, repeat_count)
                selected_articles_ = "[SEP]".join(selected_articles)
                text = f'<reserved_106>{prompt}{flag_context}<reserved_107>'
                test_data = {"context": selected_articles, "text_input": text, "label": selected_articles_}
                self.data.append(
                    test_data
                )

    # Helper predicates for filtering articles by length.
    @staticmethod
    def filter_by_length150(article):
        return 180 <= len(article) <= 200

    @staticmethod
    def filter_by_length1000(article):
        return 50 <= len(article) <= 1000

    @staticmethod
    def filter_by_length1500(article):
        return 500 <= len(article) <= 1500

    def __getitem__(self, index):
        item_data = self.data[index]
        return item_data

    def __len__(self):
        return len(self.data)
if __name__ == "__main__":
    # Smoke test: build the dataset and print one collated batch.
    data_path = "QA_5000_summary.csv"
    dataset = QADataset(data_path, train=True)
    # print(dataset[0])
    val_params = {
        "batch_size": 2,
        "shuffle": False,
        "num_workers": 0,
    }

    def collate_fn(batch):
        """Merge a list of per-sample dicts into a single dict of lists.

        :param batch: list of ``__getitem__`` results
        :return: dict mapping each key to the list of its values across the batch
        """
        merged_dict = {}
        # Walk every sample dict in the batch.
        for d in batch:
            for key, value in d.items():
                # Append to an existing key's list, or start a new one.
                if key in merged_dict:
                    merged_dict[key].append(value)
                else:
                    merged_dict[key] = [value]
        # print(merged_dict)
        return merged_dict

    val_loader = DataLoader(dataset, **val_params, collate_fn=collate_fn)
    for i in val_loader:
        print(i)
        break
# ===================== train_batch.py =====================
# -*- coding: utf-8 -*-
import pandas as pd
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel
# from dataset_batch_en import QADataset
# from dataset_rerank import QADataset
# from dataset_rerank_en_gpt import QADataset
from dataset_rerank_en import QADataset
from peft import LoraConfig, get_peft_model, TaskType
from tqdm import tqdm
import torch
import os, time, sys
import numpy as np
from modeling_RALLM import RALLM
import argparse
import deepspeed
from torch.nn.parallel import DataParallel
# 设置CUDA设备可见性,例如仅使用第一个GPU
# os.environ["CUDA_VISIBLE_DEVICES"] = "3"
parser = argparse.ArgumentParser()


def _str2bool(value):
    """Parse a command-line boolean.

    argparse's ``type=bool`` treats ANY non-empty string (including "False")
    as True, so boolean flags could never be turned off from the CLI.
    """
    if isinstance(value, bool):
        return value
    return value.lower() in ("true", "1", "yes", "y")


parser.add_argument("--is_compress", default=True, type=_str2bool)
# BUG FIX: this argument used `dest="0-1"`, which stored the value under the
# non-identifier attribute name "0-1" (unreachable as
# `args.compressibility_factor`); "0-1" was clearly meant as help text.
parser.add_argument("--compressibility_factor", default=0, type=float,
                    help="compression ratio in the range 0-1")
parser.add_argument("--LLM_model", default="Baichuan2_7B", type=str)
parser.add_argument("--output", default="output_english_longformer_rerank100k_2", type=str)
parser.add_argument("--encoder", default="longformer", type=str)
parser.add_argument("--query_tokens", default=98, type=int)
parser.add_argument("--load_path", default="output_english_longformer_msmarco2019/checkpoint-87500", type=str)
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument("--num_train_epochs", default=10, type=int)
# BUG FIX: learning_rate/weight_decay were declared `type=int` with float
# defaults — any value supplied on the command line would crash
# (int("5e-3") raises ValueError) or be silently truncated.
parser.add_argument("--learning_rate", default=5e-3, type=float)
parser.add_argument("--weight_decay", default=0.005, type=float)
parser.add_argument("--per_device_train_batch_size", default=6, type=int)
parser.add_argument("--max_length", default=4096, type=int)
parser.add_argument("--use_lora", default=True, type=_str2bool)
parser.add_argument("--use_lora_gpt2", default=False, type=_str2bool)
parser.add_argument("--train_dataset", default="data/qa_3w_summary.csv", type=str)
parser.add_argument("--epochs", default=1, type=int)
parser.add_argument("--batch_size", default=1, type=int)
args = parser.parse_args()
def train(epoch, model, loader, optimizer, scaler, gradient_accumulation_steps, model_output_dir):
    """Run one training epoch with CUDA autocast (fp16) and GradScaler.

    Saves a full state_dict checkpoint every 5000 steps under
    ``model_output_dir/epoch{epoch}/index_{step}/``.

    NOTE(review): `gradient_accumulation_steps` is accepted but never used —
    every step calls optimizer.step(); confirm whether accumulation was
    intended.
    """
    model.train()
    time1 = time.time()
    losses = []
    train_bar = tqdm(loader, total=len(loader))
    for index, data in enumerate(train_bar):
        optimizer.zero_grad()
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            # print(data)
            # NOTE(review): the model instance is also passed as the first
            # positional argument — looks intentional for this codebase's
            # forward signature, but verify against RALLM.forward.
            outputs = model(model, **data)
            loss = outputs.loss
        # Backward pass to compute gradients.
        # NOTE(review): forcing requires_grad on the loss is a red flag — if
        # the loss were already attached to the graph this is a no-op, and if
        # it is detached, backward() will produce no parameter gradients at
        # all. Worth confirming the graph is intact.
        loss.requires_grad_(True)
        losses.append(loss.item())
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        if (index + 1) % 5000 == 0:
            # Periodic checkpoint, tagged with the running mean loss.
            model_output_dir_ = os.path.join(model_output_dir, f"epoch{epoch}")
            model_save_path = os.path.join(model_output_dir_, "index_{}".format(index))
            if os.path.exists(model_save_path):
                pass
            else:
                os.makedirs(model_save_path)
            torch.save(model.state_dict(), os.path.join(model_save_path, "LLM_model_{:.6f}.pth".format(np.mean(losses))))
        train_bar.set_description("epoch:{} idx:{} loss:{:.6f}".format(epoch, index, np.mean(losses)))
def validate(model, loader):
    """Greedy evaluation pass: generate a prediction per (text, context) pair.

    Returns ``(predictions, actuals)`` — generated strings and gold labels.
    Runs under no_grad + fp16 autocast; prints inputs/outputs for inspection.
    """
    model.eval()
    predictions = []
    actuals = []
    with torch.no_grad():
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            for _, data in enumerate(tqdm(loader, file=sys.stdout, desc="Validation Data")):
                text = data["text_input"]
                context = data["context"]
                label = data["label"]
                print(text)
                print("len context:", len(context))
                # Generate one prediction per sample in the batch; each sample's
                # context list is wrapped in an extra list for model.generate.
                for text_, context_ in zip(text, context):
                    preds = model.generate(
                        text=text_, context=[context_]
                    )
                    print(preds)
                    print(label)
                    predictions.append(preds)
                actuals.extend(label)
    return predictions, actuals
def main():
    """Training-script entry point.

    Builds the RALLM model, optionally wraps sub-modules with LoRA adapters,
    loads a checkpoint, freezes/unfreezes sub-modules, then runs validation
    once per epoch.  NOTE(review): the actual `train(...)` call is commented
    out below, so this currently only validates; `scaler` and
    `gradient_accumulation_steps` are therefore unused.
    """
    epochs = args.epochs
    batch_size = args.batch_size
    lr = 1e-5
    gradient_accumulation_steps = 16
    model_output_dir = args.output
    # train_path = "qa_3w_summary.csv"
    train_path = args.train_dataset
    # NOTE(review): validation reuses the training CSV — confirm this is intended
    val_path = args.train_dataset
    device = torch.device(f"cuda:{args.local_rank}")
    model = RALLM(args)
    model = model.to(device)
    if args.use_lora:
        # LoRA on the Baichuan LLM's fused QKV projection ("W_pack")
        print("使用lora训练模型"+"*"*10)
        from peft import LoraConfig, TaskType, get_peft_model,PeftModel,AutoPeftModelForCausalLM
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["W_pack",],
            inference_mode=False,
            r=256,
            lora_alpha=512,
            lora_dropout=0.1,
        )
        model.LLM_model.enable_input_require_grads()
        model.LLM_model = get_peft_model(model.LLM_model, peft_config)
    if args.use_lora_gpt2:
        # LoRA on the GPT-2 context encoder (token embedding + attention)
        print("使用lora训练模型"+"*"*10)
        from peft import LoraConfig, TaskType, get_peft_model,PeftModel,AutoPeftModelForCausalLM
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["wte","c_attn",],
            inference_mode=False,
            r=256,
            lora_alpha=512,
            lora_dropout=0.1,
        )
        # model.LLM_model.enable_input_require_grads()
        model.context_encoder = get_peft_model(model.context_encoder, peft_config)
    print(model)
    torch.cuda.empty_cache()  # release cached GPU memory
    if args.load_path:
        base_load_path = args.load_path
        # Checkpoint may be a single .pth file or a directory of shard files
        if base_load_path.endswith(".pth"):
            state_dict = torch.load(base_load_path,map_location=device)
        else:
            # file names of the (possibly sharded) checkpoint pieces
            file_list = ['pytorch_model.bin']
            # accumulate the full state dict shard by shard
            state_dict = {}
            for file_name in file_list:
                # load one shard of model parameters
                part_state_dict = torch.load(os.path.join(base_load_path,file_name),map_location=device)
                # merge the shard into the full state dict
                state_dict.update(part_state_dict)
        # strict=False: LoRA wrapping renames keys, so some may not match
        print("state_dict:")
        print(state_dict.keys())
        model.load_state_dict(state_dict,strict=False)
    # Freeze the context encoder entirely; only LN + Q-Former train below
    for param in model.context_encoder.parameters():
        param.requires_grad = False
    # layers_to_modify = [30,31,32,33,34,35]
    # # Iterate over all named parameters in the model
    # for name, param in model.context_encoder.named_parameters():
    #     # Check if the parameter belongs to the specified layers
    #     if any(f"context_encoder.h.{layer}." in name for layer in layers_to_modify):
    #         # Set requires_grad to True or False depending on whether you want the parameter to be trainable or not
    #         param.requires_grad = True  # or False if you want to freeze the layer
    for param in model.ln_features.parameters():
        param.requires_grad = True
    for param in model.Qformer.parameters():
        param.requires_grad = True
    # (commented experiment: freeze every LLM layer)
    # for param in model.LLM_model.parameters():
    #     param.requires_grad = False
    # (commented experiment: freeze everything except lora_A / lora_B)
    # trained = []
    # untrained = []
    # for name, param in model.LLM_model.named_parameters():
    #     # if 'v_proj.lora_A' in name or 'v_proj.lora_B' in name or 'q_proj.lora_B' in name or 'q_proj.lora_B' in name:
    #     if 'lora_A' in name or 'lora_B' in name:
    #         param.requires_grad = True
    #         trained.append(name)
    #     else:
    #         param.requires_grad = False
    #         untrained.append(name)
    # Print trainable and non-trainable parameters
    trainable_params = []
    non_trainable_params = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable_params.append(name)
        else:
            non_trainable_params.append(name)
    print("Trainable Parameters:")
    print("\n".join(trainable_params))
    print("\nNon-Trainable Parameters:")
    print("\n".join(non_trainable_params))
    # setup peft (kept for reference)
    # peft_config = LoraConfig(
    #     task_type=TaskType.CAUSAL_LM,
    #     target_modules=["q_proj","v_proj"],  # W_pack / query_key_value
    #     inference_mode=False,
    #     r=lora_rank,
    #     lora_alpha=lora_alpha,
    #     lora_dropout=0.1
    # )
    # model = get_peft_model(model, peft_config)
    # model.is_parallelizable = True
    # model.model_parallel = True
    # model.print_trainable_parameters()
    # (commented: half-precision conversion)
    # model.LLM_model = model.LLM_model.half()
    # model.float()
    scaler = torch.cuda.amp.GradScaler()
    def collate_fn(batch):
        """Merge a batch of per-sample dicts into one dict of lists.

        :param batch: list of dicts, one per __getitem__ result
        :return: dict mapping each key to the list of its values, in batch order
        """
        merged_dict = {}
        for d in batch:
            for key, value in d.items():
                # append to an existing key's list, otherwise start a new one
                if key in merged_dict:
                    merged_dict[key].append(value)
                else:
                    merged_dict[key] = [value]
        # print(merged_dict)
        return merged_dict
    print("Start Load Train Data...")
    train_params = {
        "batch_size": batch_size,
        "shuffle": True,
        "num_workers": 0,
    }
    training_set = QADataset(train_path,train=True)
    training_loader = DataLoader(training_set, **train_params,collate_fn=collate_fn)
    print("Start Load Validation Data...")
    val_params = {
        "batch_size": batch_size,
        "shuffle": False,
        "num_workers": 0,
    }
    val_set = QADataset(val_path,train=False)
    val_loader = DataLoader(val_set, **val_params,collate_fn=collate_fn)
    # (commented: per-module learning rates)
    # optimizer = torch.optim.AdamW([{'params': model.bert_encoder.parameters(), 'lr': 1e-5},
    #                                {'params': model.Qformer.parameters(), 'lr': 1e-3},
    #                                {'params': model.ln_features.parameters(), 'lr': 1e-3},
    #                                {'params': model.internlm_model.parameters(), 'lr': 1e-5},
    #                                {'params': query_tokens_clone, 'lr': 1e-3}]
    #                               )
    optimizer = torch.optim.AdamW([{'params': model.parameters(), 'lr': lr}])
    # device_ids = [1,3,6,7]
    # model = DataParallel(model, device_ids=device_ids)
    print("Start Training...")
    for epoch in range(epochs):
        # NOTE(review): training is disabled — only validation runs per epoch
        # train(epoch, model, training_loader, optimizer,scaler, gradient_accumulation_steps,model_output_dir)
        # model.save_pretrained(model_output_dir)
        # print("Start Validation...")
        with torch.no_grad():
            predictions, actuals = validate(model, val_loader)
        # persist validation results for this epoch
        final_df = pd.DataFrame({"Generated Text": predictions, "Actual Text": actuals})
        val_data_path = os.path.join(model_output_dir, f"predictions_{epoch}.csv")
        final_df.to_csv(val_data_path)
        print("Validation Data To ", val_data_path)
if __name__ == '__main__':
    # Script entry point: run training/validation with the module-level args.
    main()
test_chat.py
# -*- coding: utf-8 -*-
import pandas as pd
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer,AutoModelForMaskedLM,BertTokenizer, GPT2LMHeadModel, TextGenerationPipeline,AutoModelForCausalLM,AutoConfig
from transformers.generation.utils import GenerationConfig
from peft import LoraConfig, get_peft_model, TaskType
from tqdm import tqdm
import torch
import os, time, sys
import numpy as np
from modeling_RALLM import RALLM
import argparse
from torch import autocast
# Command-line options for the interactive chat script.
# NOTE: `type=bool` is an argparse footgun — bool("False") is True, so any
# explicit CLI value enables the flag.  Kept for interface compatibility;
# the defaults below are what this script actually relies on.
parser = argparse.ArgumentParser()
parser.add_argument("--is_compress", default=False, type=bool)
# BUG FIX: this previously used `dest="0-1"`, which stored the value under the
# invalid attribute name "0-1" and left `args.compressibility_factor`
# undefined.  "0-1" was clearly meant as help text describing the valid range.
parser.add_argument("--compressibility_factor", default=0.1, type=float, help="0-1")
parser.add_argument("--LLM_model", default="Baichuan2_7B", type=str)
parser.add_argument("--output", default="output_corpus_2", type=str)
parser.add_argument("--encoder", default="gpt2_large", type=str)
parser.add_argument("--query_tokens", default=98, type=int)
parser.add_argument("--load_path", default="output_corpus_lora/checkpoint-200004", type=str)
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument("--num_train_epochs", default=10, type=int)
# BUG FIX: learning rate and weight decay are floats; `type=int` would reject
# CLI values such as "5e-5" (int("5e-5") raises ValueError).
parser.add_argument("--learning_rate", default=5e-5, type=float)
parser.add_argument("--weight_decay", default=0.005, type=float)
parser.add_argument("--per_device_train_batch_size", default=6, type=int)
parser.add_argument("--max_length", default=4096, type=int)
parser.add_argument("--use_lora", default=False, type=bool)
parser.add_argument("--use_lora_gpt2", default=True, type=bool)
args = parser.parse_args()
def chat( model):
    """Interactive REPL: read two context passages from stdin, ask the model
    to reproduce the compressed content, print the reply.  Loops forever
    (interrupt with Ctrl-C)."""
    while True:
        context1 = input("输入context:")
        context2 = input("输入context2:")
        # sample context (zh) kept for manual testing:
        # context = """NVIDIA的A6000显卡是一款面向专业领域的高性能显卡。关于它的双精度(Double Precision)、单精度(Single Precision)和半精度(Half Precision)的算力,我们可以参考官方提供的规格参数。截至我最后更新的信息(2023年4月),以下是A6000显卡的相关算力数据:双精度(Double Precision): A6000显卡在双精度计算方面的性能通常不如单精度和半精度,因为双精度计算需要更多的计算资源和带宽。具体数值因显卡的不同批次和制造工艺的微小差异可能有所不同。单精度(Single Precision): A6000在单精度计算方面的性能通常很高,适合于大多数图形处理和一些科学计算任务。单精度计算是大多数显卡的主要优势。半精度(Half Precision): 半精度计算主要用于某些机器学习和深度学习应用,能提供更高的吞吐量。A6000显卡在半精度计算方面的性能通常很高。
        # """
        # two "<context>" placeholder tokens — one per compressed passage
        flag_context = "<context> "*2
        text = f'<reserved_106>请复述这段被压缩的内容{flag_context} <reserved_107>'
        data = {"context":[[context1,context2]],"text_input":text}
        model.eval()
        with torch.no_grad():
            # fp16 autocast for generation; assumes CUDA
            with autocast(device_type="cuda",dtype=torch.float16):
                text = data["text_input"]
                context = data["context"]
                preds = model.generate(
                    text=text,context = context
                )
                print("输出:",preds)
def main():
    """Chat-script entry point: build RALLM, attach LoRA adapters, load a
    checkpoint, convert to half precision, then enter the interactive loop.

    NOTE(review): `args.load_path` is ignored — the checkpoint directory is
    hard-coded to "output_qa3w_lora_gpt2_base_corpus" below; confirm which
    path is intended.
    """
    model = RALLM(args)
    device = torch.device(f"cuda:{args.local_rank}")
    model.to(device)
    if args.use_lora:
        # LoRA on the Baichuan LLM's fused QKV projection ("W_pack")
        from peft import LoraConfig, TaskType, get_peft_model,PeftModel,AutoPeftModelForCausalLM
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["W_pack",],
            inference_mode=False,
            r=256,
            lora_alpha=512,
            lora_dropout=0.1,
        )
        model.LLM_model.enable_input_require_grads()
        model.LLM_model = get_peft_model(model.LLM_model, peft_config)
    if args.use_lora_gpt2:
        # LoRA on the GPT-2 context encoder's attention ("c_attn").
        # NOTE(review): r=64/alpha=256 here differs from the training scripts
        # (r=256/alpha=512) — must match the checkpoint being loaded.
        print("使用lora训练模型"+"*"*10)
        from peft import LoraConfig, TaskType, get_peft_model,PeftModel,AutoPeftModelForCausalLM
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["c_attn",],
            inference_mode=False,
            r=64,
            lora_alpha=256,
            lora_dropout=0.1,
        )
        # model.LLM_model.enable_input_require_grads()
        model.context_encoder = get_peft_model(model.context_encoder, peft_config)
    print(model)
    base_load_path = "output_qa3w_lora_gpt2_base_corpus"
    # file names of the (possibly sharded) checkpoint pieces
    file_list = ['pytorch_model.bin']
    # accumulate the full state dict shard by shard
    state_dict = {}
    for file_name in file_list:
        # load one shard of model parameters
        part_state_dict = torch.load(os.path.join(base_load_path,file_name),map_location=f"cuda:{args.local_rank}")
        # merge the shard into the full state dict
        state_dict.update(part_state_dict)
    # strict load (default): every key must match the wrapped model exactly
    model.load_state_dict(state_dict)
    model.LLM_model.generation_config = GenerationConfig.from_pretrained(base_load_path)
    # (commented: loading a raw .pth checkpoint instead)
    # load_path = '/data2/xinyuuliu/InternLM-XComposer/output12/epoch9/index_29999/LLM_model_0.109371.pth'
    # checkpoint = torch.load(load_path,map_location="cuda:0")
    # model.load_state_dict(checkpoint)
    # convert the whole model to half precision for inference
    # model.LLM_model = model.LLM_model.half()
    model = model.half()
    # model.float()
    chat(model)
if __name__ == '__main__':
    # Script entry point: load the model and start the interactive chat.
    main()
ds_config.json
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "none",
"pin_memory": true
},
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 10,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
fine-tune.py
# This code is based on the revised code from fastchat based on tatsu-lab/stanford_alpaca.
from dataclasses import dataclass, field
import json
import math
import logging
import os
import random
from typing import Dict, Optional, List
import torch
from torch.utils.data import Dataset
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
import transformers
from transformers import Trainer, deepspeed
from transformers.trainer_pt_utils import LabelSmoother
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from accelerate.utils import DistributedType
from torchvision import transforms
from typing import Dict, Optional, Sequence, List
from modeling_RALLM import RALLM
# from dataset_batch import QADataset
from dataset_rerank_en import QADataset
# from dataset_rerank_en_gpt import QADataset
# from dataset_rerank import QADataset
import argparse
from transformers import AutoModel, AutoTokenizer,AutoModelForMaskedLM,BertTokenizer, GPT2LMHeadModel, TextGenerationPipeline,AutoModelForCausalLM,AutoConfig
from torch.optim import AdamW
# Label id excluded from the loss (HF convention: -100).
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
# Command-line options for fine-tuning.
# NOTE: `type=bool` is an argparse footgun — bool("False") is True, so any
# explicit CLI value enables the flag.  Kept for interface compatibility;
# only the defaults are relied on by fine-tune.sh.
parser = argparse.ArgumentParser()
parser.add_argument("--is_compress", default=True, type=bool)
# BUG FIX: this previously used `dest="0-1"`, which stored the value under the
# invalid attribute name "0-1" and left `args.compressibility_factor`
# undefined.  "0-1" was clearly meant as help text describing the valid range.
parser.add_argument("--compressibility_factor", default=0.1, type=float, help="0-1")
parser.add_argument("--LLM_model", default="Baichuan2_7B", type=str)
parser.add_argument("--output", default="output_corpus_lora2", type=str)
parser.add_argument("--encoder", default="gpt2_large", type=str)
parser.add_argument("--query_tokens", default=98, type=int)
parser.add_argument("--load_path", default="output_english_longformer_rerank100k/checkpoint-112356", type=str)
parser.add_argument("--local_rank", default=-1, type=int)
parser.add_argument("--num_train_epochs", default=10, type=int)
parser.add_argument("--learning_rate", default=5e-5, type=float)
parser.add_argument("--weight_decay", default=0.01, type=float)
parser.add_argument("--per_device_train_batch_size", default=6, type=int)
parser.add_argument("--max_length", default=4096, type=int)
parser.add_argument("--use_lora", default=True, type=bool)
parser.add_argument("--use_lora_gpt2", default=False, type=bool)
parser.add_argument("--train_dataset", default="data/qa_3w_summary.csv", type=str)
args = parser.parse_args()
# NOTE(review): with the default local_rank=-1 this sets
# CUDA_VISIBLE_DEVICES="-1", which hides every GPU; it only works because the
# deepspeed launcher passes a real --local_rank.  Confirm this is intended.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.local_rank)
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """HF TrainingArguments specialised with this project's defaults."""
    # cache_dir: Optional[str] = field(default=None)
    # use PyTorch's AdamW implementation
    optim: str = field(default="adamw_torch")
    # Optional[int]: the default really is None (transformers itself uses
    # -1 to mean "no distributed training"); annotated to match the default.
    local_rank: Optional[int] = field(default=None)
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning.

    Merges a batch of per-sample dicts into a single dict that maps each key
    to the list of its values across the batch, preserving batch order.  No
    padding or tensorisation happens here — the model consumes the raw lists.
    """

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, List]:
        """Merge *instances* into one dict of lists.

        :param instances: batch of per-sample feature dicts (one per __getitem__)
        :return: dict mapping each key to the list of its values, in batch order
        (the previous ``Dict[str, torch.Tensor]`` annotation was wrong — no
        tensors are ever built here)
        """
        merged_dict: Dict[str, List] = {}
        for d in instances:
            for key, value in d.items():
                # setdefault replaces the explicit key-membership branch
                merged_dict.setdefault(key, []).append(value)
        return merged_dict
def train():
    """Fine-tune entry point: build RALLM, attach LoRA adapters, freeze the
    context encoder, optionally resume from a checkpoint, then train with the
    HF Trainer under DeepSpeed (config in ds_config.json)."""
    global model
    train_path = args.train_dataset
    # train_path = "data/news_summary_30w.csv"
    # val_path = "QA_5000_summary.csv"
    device_map = None
    # detect multi-process (DDP) launch via the deepspeed-provided env var
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    torch.cuda.device(0)
    torch.cuda.empty_cache()  # release cached GPU memory
    # init model and tokenizer
    model = RALLM(args)
    device = torch.device(f"cuda:{args.local_rank}")
    model.to(device)
    # torch.cuda.device(0)
    torch.cuda.empty_cache()  # release cached GPU memory
    if args.use_lora:
        # LoRA on the Baichuan LLM's fused QKV projection ("W_pack")
        print("使用lora训练模型"+"*"*10)
        from peft import LoraConfig, TaskType, get_peft_model,PeftModel,AutoPeftModelForCausalLM
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["W_pack",],
            inference_mode=False,
            r=256,
            lora_alpha=512,
            lora_dropout=0.1,
        )
        model.LLM_model.enable_input_require_grads()
        model.LLM_model = get_peft_model(model.LLM_model, peft_config)
    if args.use_lora_gpt2:
        # LoRA on the Longformer context encoder's attention projections
        print("使用lora训练模型"+"*"*10)
        from peft import LoraConfig, TaskType, get_peft_model,PeftModel,AutoPeftModelForCausalLM
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            # target_modules=["wte","c_attn",],
            target_modules=["query","key","value","query_global","key_global","value_global"],
            inference_mode=False,
            r=128,
            lora_alpha=512,
            lora_dropout=0.1,
        )
        # model.LLM_model.enable_input_require_grads()
        model.context_encoder = get_peft_model(model.context_encoder, peft_config)
    print(model)
    # Freeze the context encoder entirely; only ln_features trains below
    for param in model.context_encoder.parameters():
        param.requires_grad = False
    # (commented experiment: selectively unfreeze encoder layers/embeddings)
    # layers_to_modify = [27,28,29,30,31,32,33,34, 35]
    # for name, param in model.context_encoder.named_parameters():
    #     if any(f"h.{layer}." in name for layer in layers_to_modify):
    #         param.requires_grad = True
    #     # if f"ln_f" in name:
    #     #     param.requires_grad = True
    #     if f"wte" in name:
    #         param.requires_grad = True
    #     if f"wpe" in name:
    #         param.requires_grad = True
    for param in model.ln_features.parameters():
        param.requires_grad = True
    # (commented experiment: freeze every LLM layer)
    # for param in model.LLM_model.parameters():
    #     param.requires_grad = False
    # (commented experiment: freeze everything except lora_A / lora_B)
    trained = []
    untrained = []
    # for name, param in model.LLM_model.named_parameters():
    #     # if 'v_proj.lora_A' in name or 'v_proj.lora_B' in name or 'q_proj.lora_B' in name or 'q_proj.lora_B' in name:
    #     # if 'lora_A' in name or 'lora_B' in name or "layers.30" in name or "layers.31" in name or "embed_tokens" in name:
    #     if 'lora_A' in name or 'lora_B' in name or "embed_tokens" in name:
    #         param.requires_grad = True
    #         trained.append(name)
    #     else:
    #         param.requires_grad = False
    #         untrained.append(name)
    # print("trainable LLM layers", trained)
    # print("frozen LLM layers", untrained)
    # Print trainable and non-trainable parameters
    trainable_params = []
    non_trainable_params = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable_params.append(name)
        else:
            non_trainable_params.append(name)
    print("Trainable Parameters:")
    print("\n".join(trainable_params))
    print("\nNon-Trainable Parameters:")
    print("\n".join(non_trainable_params))
    if args.load_path:
        base_load_path = args.load_path
        # Checkpoint may be a single .pth file or a directory of shard files
        if base_load_path.endswith(".pth"):
            state_dict = torch.load(base_load_path,map_location=device)
        else:
            # file names of the (possibly sharded) checkpoint pieces
            file_list = ['pytorch_model.bin']
            # accumulate the full state dict shard by shard
            state_dict = {}
            for file_name in file_list:
                # load one shard of model parameters
                part_state_dict = torch.load(os.path.join(base_load_path,file_name),map_location=device)
                # merge the shard into the full state dict
                state_dict.update(part_state_dict)
        # strict=False: LoRA wrapping renames keys, so some may not match
        print("state_dict:")
        print(state_dict.keys())
        model.load_state_dict(state_dict,strict=False)
    # (commented: separate learning rates for Qformer vs everything else)
    # qformer_params = set(model.Qformer.parameters())
    # other_params = [p for p in model.parameters() if p not in qformer_params]
    # param_groups = [
    #     {'params': list(qformer_params), 'lr': 1e-3},
    #     {'params': other_params, 'lr': 1e-5}
    # ]
    # optimizer = AdamW(param_groups)
    training_set = QADataset(train_path,train=True)
    # val_set = QADataset(val_path,train=False)
    print(training_set[0])
    # Training hyper-parameters.  NOTE(review): weight_decay and learning_rate
    # are hard-coded here and silently override the --weight_decay /
    # --learning_rate CLI options — confirm which should win.
    training_args = TrainingArguments(
        local_rank=args.local_rank,
        output_dir=args.output,  # output directory
        num_train_epochs=args.num_train_epochs,  # number of training epochs
        per_device_train_batch_size=args.per_device_train_batch_size,  # batch size per device
        warmup_steps=500,  # LR warmup steps
        weight_decay=0.01,  # weight decay
        logging_dir='./logs',  # log directory
        deepspeed = "ds_config.json",
        gradient_accumulation_steps = 1 ,
        save_strategy = "epoch" ,
        learning_rate = 5e-5 ,
        # lr_scheduler_type='linear',
        # logging_steps= 100,
    )
    data_collator = DataCollatorForSupervisedDataset()
    # Start trainer
    trainer = Trainer(
        model = model,
        tokenizer = model.LLM_tokenizer,
        train_dataset=training_set,
        # eval_dataset=val_set,
        data_collator=data_collator,
        args = training_args,
        # optimizers=(optimizer, None)  # custom optimizer hook
    )
    trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=args.output)
if __name__ == "__main__":
    # Script entry point (launched per-rank by deepspeed; see fine-tune.sh).
    train()
# Related work: https://arxiv.org/pdf/2102.05951.pdf
fine-tune.sh
# Launch script for fine-tune.py under deepspeed.
hostfile=""
# Previous Baichuan2-13B LoRA run, kept for reference:
# deepspeed --include localhost:1,2,3 --hostfile=$hostfile fine-tune.py \
#     --report_to "none" \
#     --data_path "/data1/xinyuuliu/qa_data/professional_data/train_二阶段.json" \
#     --model_name_or_path "/data1/xinyuuliu/Baichuan2-13B-Chat" \
#     --output_dir "output_lora3_1_2" \
#     --model_max_length 4000\
#     --num_train_epochs 10 \
#     --per_device_train_batch_size 4 \
#     --gradient_accumulation_steps 1 \
#     --save_strategy epoch \
#     --learning_rate 2e-4 \
#     --lr_scheduler_type constant \
#     --adam_beta1 0.9 \
#     --adam_beta2 0.98 \
#     --adam_epsilon 1e-8 \
#     --max_grad_norm 1.0 \
#     --weight_decay 1e-4 \
#     --warmup_ratio 0.0 \
#     --logging_steps 1 \
#     --gradient_checkpointing True \
#     --deepspeed ds_config.json \
#     --bf16 True \
#     --tf32 True \
#     --use_lora True \
#     --load_lora_path /data1/xinyuuliu/Baichuan2-main/fine-tune/output_lora3_1/checkpoint-8260
#     --use_NEFT True
#     --use_frozen True
# export CUDA_LAUNCH_BLOCKING=1
# CUDA_VISIBLE_DEVICES="2,3,4,5,6,7"
# Current run: all 8 local GPUs, Longformer encoder, 32 query tokens.
# NOTE(review): no space before the backslash after the --output value — safe
# only because the continuation line starts with whitespace; verify.
deepspeed --include localhost:0,1,2,3,4,5,6,7 --master_port 29501 --hostfile=$hostfile fine-tune.py \
    --encoder longformer \
    --query_tokens 32 \
    --output output_english_longformer_msmarco2019\
    --num_train_epochs 20 \
    --per_device_train_batch_size 1 \
    # --load_path /data2/xinyuuliu/Baichuan2_qformer_bert/output_30w/checkpoint-22488 \