Selected tBERT Code (Self-Study Notes)

This document contains the implementation code of the tBERT model, covering four files (tbert.py, load_data.py, bert_simple_topic.py and base_model_bert.py) plus the TensorFlow helper functions in tf_helpers. It is intended for readers teaching themselves NLP and Python.


tbert.py

from src.loaders.load_data import load_data
from src.models.base_model_bert import model,test_opt
import argparse

# run tbert with different learning rates on a certain dataset
# example usage: python src/experiments/tbert.py -learning_rate 5e-05 -gpu 0 -topic_type ldamallet -topic word -dataset MSRP --debug

# command line arguments
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument('-dataset', action="store", dest="dataset", type=str, default='MSRP')
parser.add_argument('-learning_rate', action="store", dest="learning_rate", type=str, default='3e-5') # learning_rates = [5e-5, 2e-5, 4e-5, 3e-5]
parser.add_argument('-layers', action="store", dest="hidden_layers", type=str, default='0')
parser.add_argument('-topic', action="store", dest="topic", type=str, default='word')
parser.add_argument('-gpu', action="store", dest="gpu", type=int, default=-1)
parser.add_argument("--speedup_new_layers",type="bool",nargs="?",const=True,default=False,help="Use 100 times higher learning rate for new layers.")
parser.add_argument("--debug",type="bool",nargs="?",const=True,default=False,help="Try to use small number of examples for troubleshooting")
parser.add_argument("--train_longer",type="bool",nargs="?",const=True,default=False,help="Train for 9 epochs")
parser.add_argument("--early_stopping",type="bool",nargs="?",const=True,default=False)
parser.add_argument("--unk_topic_zero",type="bool",nargs="?",const=True,default=False)
parser.add_argument('-seed', action="store", dest="seed", type=str, default='fixed')
parser.add_argument('-topic_type', action="store", dest="topic_type", type=str, default='ldamallet')

FLAGS, unparsed = parser.parse_known_args()

# sanity check command line arguments
if len(unparsed)>0:
    parser.print_help()
    raise ValueError('Unidentified command line arguments passed: {}\n'.format(str(unparsed))) # manually raise an exception at this point in the program
    # raise ExceptionClass(message): raises an exception of the given type together with a descriptive message

# setting model options based on flags

# set and validate the dataset
dataset = FLAGS.dataset
assert dataset in ['MSRP','Semeval_A','Semeval_B','Semeval_C','Quora'] # fail early if the dataset is not one of the supported ones

# set and validate hidden_layers
hidden_layers = [int(h) for h in FLAGS.hidden_layers.split(',')]
for h in hidden_layers:
    assert h in [0,1,2]

# set and validate the topic scope (word level or document level)
topics = FLAGS.topic.split(',')
for t in topics:
    assert t in ['word','doc']  # validate topic scope

# initialise lists and default hyperparameters
priority = []
todo = []
last = []
stopping_criterion = None #'F1'
patience = None
batch_size = 32 # standard minibatch size

# set task, batch_size, num_topics and alpha depending on the dataset
if 'Semeval' in dataset:
    dataset, task = dataset.split('_')
    subsets = ['train_large', 'test2016', 'test2017']
    if task in ['A']:
        batch_size = 16 # need smaller minibatch to fit on GPU due to long sentences
        num_topics = 70
        if FLAGS.topic_type=='gsdmm':
            alpha = 0.1
        else:
            alpha = 50
    elif task == 'B':
        num_topics = 80
        if FLAGS.topic_type=='gsdmm':
            alpha = 0.1
        else:
            alpha = 10
    elif task == 'C':
        batch_size = 16 # need smaller minibatch to fit on GPU due to long sentences
        num_topics = 70
        if FLAGS.topic_type=='gsdmm':
            alpha = 0.1
        else:
            alpha = 10
else:
    task = 'B'
    if dataset== 'Quora':
        subsets = ['train', 'dev', 'test'] 
        num_topics = 90
        if FLAGS.topic_type=='gsdmm':
            alpha = 0.1
        else:
            alpha = 1
        task = 'B'
    else:
        subsets = ['train', 'dev', 'test'] # MSRP
        num_topics = 80
        if FLAGS.topic_type=='gsdmm':
            alpha = 0.1
        else:
            alpha = 1
        task = 'B'


if FLAGS.debug:
    max_m = 100
else:
    max_m = None

if FLAGS.train_longer:
    epochs = 9
    predict_every_epoch = True
else:
    epochs = 3
    predict_every_epoch = False

if FLAGS.early_stopping:
    patience = 2
    stopping_criterion = 'F1'

try:
    seed = int(FLAGS.seed)
except:
    seed = None

if FLAGS.unk_topic_zero:
    unk_topic = 'zero'
else:
    unk_topic = 'uniform'

# build one experiment configuration per combination of topic scope and hidden layer
for topic_scope in topics:
    for hidden_layer in hidden_layers:
        opt = {'dataset': dataset,
               'datapath': 'data/',
               'model': 'bert_simple_topic',   # topic-augmented BERT variant
               'bert_update': True,
               'bert_cased': False,
               'tasks': [task],
               'subsets': subsets,
               'seed': seed,
               'minibatch_size': batch_size,
               'L2': 0,
               'max_m': max_m,
               'load_ids': True,
               'topic': topic_scope,           # 'word' or 'doc' level topics
               'topic_update': False,          # topic features are not updated during training
               'num_topics': num_topics,       # dataset-specific number of topics
               'topic_alpha': alpha,           # dataset-specific topic model alpha
               'unk_topic': unk_topic,
               'topic_type': FLAGS.topic_type,
               'unk_sub': False,
               'padding': False,
               'simple_padding': True,
               'learning_rate': float(FLAGS.learning_rate),
               'num_epochs': epochs,
               'hidden_layer': hidden_layer,
               'sparse_labels': True,
               'max_length': 'minimum',
               'optimizer': 'Adam',
               'dropout': 0.1,
               'gpu': FLAGS.gpu,
               'speedup_new_layers': FLAGS.speedup_new_layers,
               'predict_every_epoch': predict_every_epoch,
               'stopping_criterion': stopping_criterion,
               'patience': patience
               }
        todo.append(opt)  # append the configuration to the todo list defined above

tasks = todo        # note: this is just another reference to the same list, not a copy

if __name__ == '__main__':

    for i,opt in enumerate(tasks):
        print('Starting experiment {} of {}'.format(i+1,len(tasks)))
        l_rate = str(opt['learning_rate']).replace('-0','-')   # strip the redundant zero from the learning rate string, e.g. '3e-05' -> '3e-5'
        if FLAGS.speedup_new_layers:
            log = 'tbert_{}_seed_speedup_new_layers.json'.format(str(seed))
        elif FLAGS.train_longer:
            log = 'tbert_{}_seed_train_longer.json'.format(str(seed))
        else:
            log = 'tbert_{}_seed.json'.format(str(seed))
        if FLAGS.early_stopping:
            log = log.replace('.json','_early_stopping.json')

        print(log)
        print(opt)
        test_opt(opt)       # sanity-check the option dictionary
        data = load_data(opt, cache=True, write_vocab=False)        # load the dataset together with the topic features
        if FLAGS.debug:
            # print(data[''])
            print(data['E1'][0].shape)
            print(data['E1'][1].shape)
            print(data['E1'][2].shape)

            print(data['E1_mask'][0])
            print(data['E1_seg'][0])
        opt = model(data, opt, logfile=log, print_dim=True)
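
The nested loops above build one option dictionary per combination of topic scope and hidden layer, so comma-separated flag values such as -topic word,doc -layers 0,1 expand into a small grid of experiments. Below is a minimal standalone sketch of this expansion; the keys shown are illustrative and not the full option dictionary used by the script:

from itertools import product

topics = 'word,doc'.split(',')                         # e.g. from -topic word,doc
hidden_layers = [int(h) for h in '0,1'.split(',')]     # e.g. from -layers 0,1

# one configuration per (topic scope, hidden layer) pair
experiments = [{'topic': t, 'hidden_layer': h} for t, h in product(topics, hidden_layers)]
print(len(experiments))  # 4 experiments: (word,0), (word,1), (doc,0), (doc,1)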

load_data.py

import importlib
import os
import pickle
import numpy as np

from src.loaders.Quora.build import build
from src.loaders.augment_data import create_large_train, double_task_training_data
from src.preprocessing.Preprocessor import Preprocessor, get_onehot_encoding, reduce_embd_id_len
from src.topic_model.topic_loader import load_document_topics, load_word_topics
from src.topic_model.topic_visualiser import read_topic_key_table

# return file names that combine dataset prefix, subset and task (prefix+subset+'_'+task), e.g. m_train_B, m_dev_B, m_test_B
def get_filenames(opt):
    filenames = [] # name of cache files
    for s in opt['subsets']:
        for t in opt['tasks']:
            prefix = ''
            if opt['dataset'] == 'Quora':
                if s.startswith('p_'):
                    prefix = ''
                else:
                    prefix = 'q_'
            if opt['dataset'] == 'PAWS':
                prefix = 'p_'
            if opt['dataset'] == 'MSRP':
                prefix = 'm_'
            filenames.append(prefix+s+'_'+t)  #m_train_B  m_dev_B  m_test_B
    return filenames
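
# Illustrative check (not part of the original file): for an MSRP-style option dict,
# get_filenames yields one cache name per subset/task combination.
if __name__ == '__main__':
    _demo_opt = {'dataset': 'MSRP', 'subsets': ['train', 'dev', 'test'], 'tasks': ['B']}
    print(get_filenames(_demo_opt))  # expected: ['m_train_B', 'm_dev_B', 'm_test_B']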

# return file paths os.path.join(opt['datapath'], opt['dataset'], name + '.txt'), e.g. data/MSRP/m_train_B.txt
def get_filepath(opt):
    filepaths = []
    for name in get_filenames(opt):
        if 'quora' in name:     # why is Quora handled as a special case here?
            filepaths.append(os.path.join(opt['datapath'], 'Quora', name + '.txt'))
            print('quora in filename')
        else:
            filepaths.append(os.path.join(opt['datapath'], opt['dataset'], name + '.txt'))
    return filepaths        # e.g. data/MSRP/m_train_B.txt

# load a file and return the tuple (ID1, ID2, D1, D2, L): document ids, document texts and labels
def load_file(filename,onehot=True):
    """
    Reads a file and returns a tuple of (ID1, ID2, D1, D2, L)
    """
    # todo: return dictionary
    ID1 = []
    ID2 = []
    D1 = []
    D2 = []
    L = []
    with open(filename,'r',encoding='utf-8') as read:
        for i,line in enumerate(read):
            if not len(line.split('\t'))==5:
                print(line.split('\t'))
            id1, id2, d1, d2, label = line.rstrip().split('\t')
            ID1.append(id1)
            ID2.append(id2)
            D1.append(d1)
            D2.append(d2)
            if 's_' in filename:
                if float(label)>=4:
                    label = 1
                elif float(label)<4:
                    label = 0
                else:
                    raise ValueError()
            L.append(int(label))
    L = np.array(L)     # convert the label list to a NumPy array
    # L = L.reshape(len(D1),1)
    if onehot:
        classes = L.shape[1] + 1
        L = get_onehot_encoding(L)
        print('Encoding labels as one hot vector.')
    return (ID1, ID2, D1, D2, L)
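
# Illustrative usage (an assumption, not part of the original file): each input line is expected
# to be tab-separated as  id1 <TAB> id2 <TAB> sentence1 <TAB> sentence2 <TAB> label
if __name__ == '__main__':
    with open('demo_pairs.txt', 'w', encoding='utf-8') as f:
        f.write('1\t2\tHow old are you?\tWhat is your age?\t1\n')
    print(load_file('demo_pairs.txt', onehot=False))  # (['1'], ['2'], ['How old are you?'], ['What is your age?'], array([1]))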

# set the maximum length of the two sentences s1 and s2 depending on the dataset
def get_dataset_max_length(opt):
    '''
    Determine the maximum number of tokens for both sentences, as well as the highest maximum length for the current task
    :param opt:
    :return: (s1_len, s2_len, overall maximum length in tokens)
    '''
    tasks = opt['tasks']
    if opt['dataset'] in ['Quora','PAWS','GlueQuora']:
        cutoff = opt.get('max_length', 24)
        if cutoff == 'minimum':
            cutoff = 24
        s1_len, s2_len = cutoff, cutoff
    elif opt['dataset']=='MSRP':
        cutoff = opt.get('max_length', 40)
        if cutoff == 'minimum':
            cutoff = 40
        s1_len, s2_len = cutoff, cutoff
    elif 'B' in tasks:
        cutoff = opt.get('max_length', 100)
        if cutoff == 'minimum':
            cutoff = 100
        s1_len, s2_len = cutoff, cutoff
    elif 'A' in tasks or 'C' in tasks:
        cutoff = opt.get('max_length', 200)
        if cutoff == 'minimum':
            s1_len = 100
            s2_len = 200
        else:
            s1_len, s2_len = cutoff,cutoff
    return s1_len,s2_len,max([s1_len,s2_len])
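
# Illustrative values (not part of the original file): with the default 'minimum' setting the
# cutoffs follow directly from the branches above.
if __name__ == '__main__':
    print(get_dataset_max_length({'dataset': 'MSRP', 'tasks': ['B'], 'max_length': 'minimum'}))   # (40, 40, 40)
    print(get_dataset_max_length({'dataset': 'Quora', 'tasks': ['B'], 'max_length': 'minimum'}))  # (24, 24, 24)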

# truncate the examples, keeping only the first m
def reduce_examples(matrices, m):
    '''
    Reduces the size of matrices
    :param matrices: 
    :param m: maximum number of examples
    :return: 
    '''
    return [matrix[:m] for matrix in matrices]  # matrix[:m] keeps the first m rows

def create_missing_datafiles(opt,datafile,datapath):
    if not os.path.exists(datapath) and 'large' in datafile:
        create_large_train()
    if not os.path.exists(datapath) and 'double' in datafile:
        double_task_training_data()
    if not os.path.exists(datapath) and 'quora' in datafile:
        quora_opt = opt
        quora_opt['dataset'] = 'Quora'
        build(quora_opt)

# return the cache folder path, e.g. data/cache/
def get_cache_folder(opt):
    return opt['datapath'] + 'cache/'

# return ID1, ID2, R1, R2, T1, T2, L
def load_cache_or_process(opt, cache, onehot):
    ID1 = []
    ID2 = []
    R1 = []
    R2 = []
    T1 = []
    T2 = []
    L = []

    filenames = get_filenames(opt)      # e.g. m_train_B, m_dev_B, m_test_B
    print(filenames)

    filepaths = get_filepath(opt)       # e.g. data/MSRP/m_train_B.txt
    print(filepaths)

    for datafile,datapath in zip(filenames,filepaths):          # zip() pairs up corresponding elements of the two lists into tuples
        create_missing_datafiles(opt,datafile,datapath) # if necessary
        cache_folder = get_cache_folder(opt)    #data/cache/
        if not os.path.exists(cache_folder):
            os.makedirs(cache_folder)  # presumably creates the missing cache folder; the original listing is cut off at this point
        # ... (the remainder of load_cache_or_process is omitted from this excerpt)