tBERT code notes (for self-study)
tbert.py
from src.loaders.load_data import load_data
from src.models.base_model_bert import model,test_opt
import argparse
# run tbert with different learning rates on a certain dataset
# example usage: python src/experiments/tbert.py -learning_rate 5e-05 -gpu 0 -topic_type ldamallet -topic word -dataset MSRP --debug
# command-line argument parsing
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument('-dataset', action="store", dest="dataset", type=str, default='MSRP')
parser.add_argument('-learning_rate', action="store", dest="learning_rate", type=str, default='3e-5') # learning_rates = [5e-5, 2e-5, 4e-5, 3e-5]
parser.add_argument('-layers', action="store", dest="hidden_layers", type=str, default='0')
parser.add_argument('-topic', action="store", dest="topic", type=str, default='word')
parser.add_argument('-gpu', action="store", dest="gpu", type=int, default=-1)
parser.add_argument("--speedup_new_layers",type="bool",nargs="?",const=True,default=False,help="Use 100 times higher learning rate for new layers.")
parser.add_argument("--debug",type="bool",nargs="?",const=True,default=False,help="Try to use small number of examples for troubleshooting")
parser.add_argument("--train_longer",type="bool",nargs="?",const=True,default=False,help="Train for 9 epochs")
parser.add_argument("--early_stopping",type="bool",nargs="?",const=True,default=False)
parser.add_argument("--unk_topic_zero",type="bool",nargs="?",const=True,default=False)
parser.add_argument('-seed', action="store", dest="seed", type=str, default='fixed')
parser.add_argument('-topic_type', action="store", dest="topic_type", type=str, default='ldamallet')
FLAGS, unparsed = parser.parse_known_args()
# sanity check command line arguments
if len(unparsed)>0:
parser.print_help()
    raise ValueError('Unidentified command line arguments passed: {}\n'.format(str(unparsed)))  # manually raise an exception at this point in the program
# raise ExceptionClass(message): raises an exception of the given type and attaches the descriptive message.
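# Study note (example, not in the original file): raise ValueError('unsupported dataset') would abort the script here and print that message.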
# setting model options based on flags
# set and validate the dataset
dataset = FLAGS.dataset
assert dataset in ['MSRP','Semeval_A','Semeval_B','Semeval_C','Quora']  # fail early if the dataset is not one of the supported sets
# set and validate the hidden layer configuration
hidden_layers = [int(h) for h in FLAGS.hidden_layers.split(',')]
for h in hidden_layers:
assert h in [0,1,2]
# set and validate the topic scope (word-level or document-level)
topics = FLAGS.topic.split(',')
for t in topics:
    assert t in ['word','doc']  # topic scope must be 'word' or 'doc'
# initialise the experiment lists and shared parameters
priority = []
todo = []
last = []
stopping_criterion = None #'F1'
patience = None
batch_size = 32 # standard minibatch size
# set task, batch_size, num_topics and alpha depending on the dataset
if 'Semeval' in dataset:
dataset, task = dataset.split('_')
subsets = ['train_large', 'test2016', 'test2017']
if task in ['A']:
batch_size = 16 # need smaller minibatch to fit on GPU due to long sentences
num_topics = 70
if FLAGS.topic_type=='gsdmm':
alpha = 0.1
else:
alpha = 50
elif task == 'B':
num_topics = 80
if FLAGS.topic_type=='gsdmm':
alpha = 0.1
else:
alpha = 10
elif task == 'C':
batch_size = 16 # need smaller minibatch to fit on GPU due to long sentences
num_topics = 70
if FLAGS.topic_type=='gsdmm':
alpha = 0.1
else:
alpha = 10
else:
task = 'B'
if dataset== 'Quora':
subsets = ['train', 'dev', 'test']
num_topics = 90
if FLAGS.topic_type=='gsdmm':
alpha = 0.1
else:
alpha = 1
task = 'B'
else:
subsets = ['train', 'dev', 'test'] # MSRP
num_topics = 80
if FLAGS.topic_type=='gsdmm':
alpha = 0.1
else:
alpha = 1
task = 'B'
if FLAGS.debug:
max_m = 100
else:
max_m = None
if FLAGS.train_longer:
epochs = 9
predict_every_epoch = True
else:
epochs = 3
predict_every_epoch = False
if FLAGS.early_stopping:
patience = 2
stopping_criterion = 'F1'
try:
seed = int(FLAGS.seed)
except ValueError:  # non-integer seed (e.g. the default 'fixed') falls back to None
seed = None
if FLAGS.unk_topic_zero:
unk_topic = 'zero'
else:
unk_topic = 'uniform'
# build one option set for each topic scope (document- vs word-level) and each hidden layer setting
for topic_scope in topics:
for hidden_layer in hidden_layers:
opt = {
'dataset': dataset,
'datapath': 'data/',
            'model': 'bert_simple_topic',  # topic-augmented BERT model (tBERT)
'bert_update':True,
'bert_cased':False,
'tasks': [task],
'subsets': subsets,
'seed':seed,
'minibatch_size': batch_size,
'L2': 0,
'max_m': max_m,
'load_ids': True,
            'topic':topic_scope,  # 'word' or 'doc': scope of the topic features
            'topic_update':False,  # do not update the topic representations during training
            'num_topics':num_topics,  # number of topics (set per dataset above)
            'topic_alpha':alpha,  # topic model alpha hyperparameter (set per dataset above)
'unk_topic': unk_topic,
'topic_type':FLAGS.topic_type,
'unk_sub': False,
'padding': False,
'simple_padding': True,
'learning_rate': float(FLAGS.learning_rate),
'num_epochs': epochs,
'hidden_layer':hidden_layer,
'sparse_labels': True,
'max_length': 'minimum',
'optimizer': 'Adam',
'dropout':0.1,
'gpu': FLAGS.gpu,
'speedup_new_layers':FLAGS.speedup_new_layers,
'predict_every_epoch': predict_every_epoch,
'stopping_criterion':stopping_criterion,
'patience':patience
}
        todo.append(opt)  # append this option set to the todo list defined above
tasks = todo  # alias for the list of experiment configurations (same object, not a copy)
if __name__ == '__main__':
for i,opt in enumerate(tasks):
print('Starting experiment {} of {}'.format(i+1,len(tasks)))
        l_rate = str(opt['learning_rate']).replace('-0','-')  # drop the leading zero in the exponent, e.g. '3e-05' -> '3e-5'
if FLAGS.speedup_new_layers:
log = 'tbert_{}_seed_speedup_new_layers.json'.format(str(seed))
elif FLAGS.train_longer:
log = 'tbert_{}_seed_train_longer.json'.format(str(seed))
else:
log = 'tbert_{}_seed.json'.format(str(seed))
if FLAGS.early_stopping:
log = log.replace('.json','_early_stopping.json')
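        # Study note: with the default seed 'fixed' (converted to None above) and no optional flags,
        # the resulting log file name is tbert_None_seed.json.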
print(log)
print(opt)
        test_opt(opt)  # check that the supplied options are valid
        data = load_data(opt, cache=True, write_vocab=False)  # load the dataset together with its topic features
if FLAGS.debug:
# print(data[''])
print(data['E1'][0].shape)
print(data['E1'][1].shape)
print(data['E1'][2].shape)
print(data['E1_mask'][0])
print(data['E1_seg'][0])
opt = model(data, opt, logfile=log, print_dim=True)
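Study note (not part of tbert.py): with the default flags (MSRP, ldamallet topics, word scope, hidden layer 0, no optional switches), the branching above boils down to a single option set roughly like the hand-assembled sketch below; the remaining keys follow the opt template in the loop.

default_opt = {  # sketch only, assembled by hand from the defaults above, not printed by the script
    'dataset': 'MSRP', 'tasks': ['B'], 'subsets': ['train', 'dev', 'test'],
    'model': 'bert_simple_topic', 'topic': 'word', 'topic_type': 'ldamallet',
    'num_topics': 80, 'topic_alpha': 1, 'unk_topic': 'uniform',
    'minibatch_size': 32, 'num_epochs': 3, 'learning_rate': 3e-05,
    'hidden_layer': 0, 'seed': None, 'gpu': -1,
    # remaining keys (L2, max_m, load_ids, padding, optimizer, dropout, ...) follow the opt template above
}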
load_data.py
import importlib
import os
import pickle
import numpy as np
from src.loaders.Quora.build import build
from src.loaders.augment_data import create_large_train, double_task_training_data
from src.preprocessing.Preprocessor import Preprocessor, get_onehot_encoding, reduce_embd_id_len
from src.topic_model.topic_loader import load_document_topics, load_word_topics
from src.topic_model.topic_visualiser import read_topic_key_table
# build cache file names that combine dataset prefix, subset and task: filenames.append(prefix+s+'_'+t), e.g. m_train_B, m_dev_B, m_test_B
def get_filenames(opt):
filenames = [] # name of cache files
for s in opt['subsets']:
for t in opt['tasks']:
prefix = ''
if opt['dataset'] == 'Quora':
if s.startswith('p_'):
prefix = ''
else:
prefix = 'q_'
if opt['dataset'] == 'PAWS':
prefix = 'p_'
if opt['dataset'] == 'MSRP':
prefix = 'm_'
filenames.append(prefix+s+'_'+t) #m_train_B m_dev_B m_test_B
return filenames
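# Study note: for an MSRP opt with subsets ['train', 'dev', 'test'] and tasks ['B'],
# get_filenames(opt) returns ['m_train_B', 'm_dev_B', 'm_test_B'].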
# build full file paths via os.path.join(opt['datapath'], opt['dataset'], name + '.txt'), e.g. data/MSRP/m_train_B.txt
def get_filepath(opt):
filepaths = []
for name in get_filenames(opt):
        if 'quora' in name:  # why the special case? quora-named files always resolve under data/Quora/ regardless of opt['dataset']
filepaths.append(os.path.join(opt['datapath'], 'Quora', name + '.txt'))
print('quora in filename')
else:
filepaths.append(os.path.join(opt['datapath'], opt['dataset'], name + '.txt'))
    return filepaths  # e.g. data/MSRP/m_train_B.txt
# load a data file and return the tuple (ID1, ID2, D1, D2, L): document ids, document texts, labels
def load_file(filename,onehot=True):
"""
    Reads file and returns tuple of (ID1, ID2, D1, D2, L)
"""
# todo: return dictionary
ID1 = []
ID2 = []
D1 = []
D2 = []
L = []
with open(filename,'r',encoding='utf-8') as read:
for i,line in enumerate(read):
if not len(line.split('\t'))==5:
print(line.split('\t'))
id1, id2, d1, d2, label = line.rstrip().split('\t')
ID1.append(id1)
ID2.append(id2)
D1.append(d1)
D2.append(d2)
if 's_' in filename:
if float(label)>=4:
label = 1
elif float(label)<4:
label = 0
                else:
                    raise ValueError('Unexpected label value: {}'.format(label))
L.append(int(label))
    L = np.array(L)  # convert the label list to a numpy array
# L = L.reshape(len(D1),1)
if onehot:
        classes = L.shape[1] + 1  # assumes the commented-out reshape above; unused in this function
L = get_onehot_encoding(L)
print('Encoding labels as one hot vector.')
return (ID1, ID2, D1, D2, L)
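# Study note: each input line is expected to contain five tab-separated fields:
#   id1 <TAB> id2 <TAB> sentence1 <TAB> sentence2 <TAB> label
# For files whose name contains 's_', a numeric score label is binarised (>= 4 -> 1, otherwise 0).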
# set the maximum lengths of the two sentences s1 and s2 depending on the dataset
def get_dataset_max_length(opt):
'''
    Determine the maximum number of tokens for both sentences, as well as the highest max length for the current task
    :param opt:
    :return: (max length of sentence 1, max length of sentence 2, overall maximum length) in tokens
'''
tasks = opt['tasks']
if opt['dataset'] in ['Quora','PAWS','GlueQuora']:
cutoff = opt.get('max_length', 24)
if cutoff == 'minimum':
cutoff = 24
s1_len, s2_len = cutoff, cutoff
elif opt['dataset']=='MSRP':
cutoff = opt.get('max_length', 40)
if cutoff == 'minimum':
cutoff = 40
s1_len, s2_len = cutoff, cutoff
elif 'B' in tasks:
cutoff = opt.get('max_length', 100)
if cutoff == 'minimum':
cutoff = 100
s1_len, s2_len = cutoff, cutoff
elif 'A' in tasks or 'C' in tasks:
cutoff = opt.get('max_length', 200)
if cutoff == 'minimum':
s1_len = 100
s2_len = 200
else:
s1_len, s2_len = cutoff,cutoff
return s1_len,s2_len,max([s1_len,s2_len])
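# Study note: with 'max_length' set to 'minimum', Quora/PAWS/GlueQuora sentences are capped at 24 tokens,
# MSRP at 40, Semeval task B at 100, and Semeval tasks A/C at 100 (sentence 1) and 200 (sentence 2).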
# truncate the examples, keeping only the first m
def reduce_examples(matrices, m):
'''
Reduces the size of matrices
:param matrices:
:param m: maximum number of examples
:return:
'''
    return [matrix[:m] for matrix in matrices]  # matrix[:m] keeps the first m rows of each matrix
def create_missing_datafiles(opt,datafile,datapath):
if not os.path.exists(datapath) and 'large' in datafile:
create_large_train()
if not os.path.exists(datapath) and 'double' in datafile:
double_task_training_data()
if not os.path.exists(datapath) and 'quora' in datafile:
quora_opt = opt
quora_opt['dataset'] = 'Quora'
build(quora_opt)
# return the cache folder path, e.g. data/cache/
def get_cache_folder(opt):
return opt['datapath'] + 'cache/'
# load cached features or preprocess from scratch; returns ID1, ID2, R1, R2, T1, T2, L
def load_cache_or_process(opt, cache, onehot):
ID1 = []
ID2 = []
R1 = []
R2 = []
T1 = []
T2 = []
L = []
filenames = get_filenames(opt) #m_train_B m_dev_B m_test_B
print(filenames)
    filepaths = get_filepath(opt)  # e.g. data/MSRP/m_train_B.txt
print(filepaths)
    for datafile,datapath in zip(filenames,filepaths):  # zip() pairs corresponding elements of the two lists into (filename, filepath) tuples
create_missing_datafiles(opt,datafile,datapath) # if necessary
cache_folder = get_cache_folder(opt) #data/cache/
if not os.path.exists(cache_folder):