Attention is all you need源码学习1

Process.py


main函数中有关命令行输入的代码


def main():
    ''' Main function '''

    #1.命令行运行时需要传入的参数required=True为必须传入的
    parser = argparse.ArgumentParser()
    parser.add_argument('-train_src', required=True)
    parser.add_argument('-train_tgt', required=True)
    parser.add_argument('-valid_src', required=True)
    parser.add_argument('-valid_tgt', required=True)
    parser.add_argument('-save_data', required=True)
    parser.add_argument('-max_len', '--max_word_seq_len', type=int, default=50)
    parser.add_argument('-min_word_count', type=int, default=5)
    parser.add_argument('-keep_case', action='store_true')
    parser.add_argument('-share_vocab', action='store_true')
    parser.add_argument('-vocab', default=None)

    #2.解析命令行
    opt = parser.parse_args()
    #3.计算命令行输入句子序列长度=最大单词数+2,因为可能存在</s>
    opt.max_token_seq_len = opt.max_word_seq_len + 2 # include the <s> and </s>

main函数中training data处理部分

# Training set
    #4.read_instances_from_file(数据集的绝对路径、最长的句子词数,是否全是小写)
    train_src_word_insts = read_instances_from_file(
        opt.train_src, opt.max_word_seq_len, opt.keep_case)
    train_tgt_word_insts = read_instances_from_file(
        opt.train_tgt, opt.max_word_seq_len, opt.keep_case)

补充:函数read_instances_from_file()的主要功能是逐行读入目标文件的内容(文件中一行就是一个句子),上述代码做的是将每行的句子进行分词转换成一个词的列表,并将所有句子的词的列表组合成一个大的句子的列表,例如:
[[什么,?,宋慧乔,宋仲基,离婚,了,?],
[如果,是,我, 嫁给,他],
[一定,不会,离婚,的]
[喜欢,他,!]]
返回值就是列表,函数代码如下:

def read_instances_from_file(inst_file, max_sent_len, keep_case):
    ''' Convert file into word seq lists and vocab '''

    word_insts = []
    trimmed_sent_count = 0
    with open(inst_file) as f:
        for sent in f:
            if not keep_case:
                sent = sent.lower()
            words = sent.split()
            if len(words) > max_sent_len:
                trimmed_sent_count += 1
            word_inst = words[:max_sent_len]

            if word_inst:
                word_insts += [[Constants.BOS_WORD] + word_inst + [Constants.EOS_WORD]]
            else:
                word_insts += [None]

    print('[Info] Get {} instances from {}'.format(len(word_insts), inst_file))

    if trimmed_sent_count > 0:
        print('[Warning] {} instances are trimmed to the max sentence length {}.'
              .format(trimmed_sent_count, max_sent_len))

    return word_insts

    #5.要求数据集数据条数=标签集数据条数,若不同则规范成相同条数
    if len(train_src_word_insts) != len(train_tgt_word_insts):
        print('[Warning] The training instance count is not equal.')
        min_inst_count = min(len(train_src_word_insts), len(train_tgt_word_insts))#取两者长度短的那个
        train_src_word_insts = train_src_word_insts[:min_inst_count]#使数据集条数为最短长度
        train_tgt_word_insts = train_tgt_word_insts[:min_inst_count]#使标签集条数为最短长度

    #- Remove empty instances
    #6.清洗不合法的数据和标签
    train_src_word_insts, train_tgt_word_insts = list(zip(*[
        (s, t) for s, t in zip(train_src_word_insts, train_tgt_word_insts) if s and t]))

补充:
1.zip() 函数用于将可迭代的对象作为参数,将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的列表。如果各个迭代器的元素个数不一致,则返回列表长度与最短的对象相同。
2.zip(* )为解压,返回二维矩阵式
3.for x, y in zip(list_1, list_2)用序列解包同时遍历多个序列
所以这句话的意思是:s为train_src_word_insts中的一句,t为train_tgt_word_insts中的一句,当s与t均不为空时,才将元组(s,t)解压到list列表中吗?以此清洗了空语句或空标签?

main函数中validation data处理部分

同上

 # Validation set
    #7.验证集,将数据转列表
    valid_src_word_insts = read_instances_from_file(
        opt.valid_src, opt.max_word_seq_len, opt.keep_case)
    valid_tgt_word_insts = read_instances_from_file(
        opt.valid_tgt, opt.max_word_seq_len, opt.keep_case)

    #8.规范验证集数据条数=验证集标签数据条数
    if len(valid_src_word_insts) != len(valid_tgt_word_insts):
        print('[Warning] The validation instance count is not equal.')
        min_inst_count = min(len(valid_src_word_insts), len(valid_tgt_word_insts))
        valid_src_word_insts = valid_src_word_insts[:min_inst_count]
        valid_tgt_word_insts = valid_tgt_word_insts[:min_inst_count]

    #- Remove empty instances
    #9.清洗验证集数据
    valid_src_word_insts, valid_tgt_word_insts = list(zip(*[
        (s, t) for s, t in zip(valid_src_word_insts, valid_tgt_word_insts) if s and t]))

main函数中构建词表部分

    # Build vocabulary
    #构建词表
    #10.可选参数处理,未用到
    if opt.vocab:
        predefined_data = torch.load(opt.vocab)
        assert 'dict' in predefined_data

        print('[Info] Pre-defined vocabulary found.')
        src_word2idx = predefined_data['dict']['src']
        tgt_word2idx = predefined_data['dict']['tgt']
    else:
        if opt.share_vocab:
            print('[Info] Build shared vocabulary for source and target.')
            word2idx = build_vocab_idx(
                train_src_word_insts + train_tgt_word_insts, opt.min_word_count)
            src_word2idx = tgt_word2idx = word2idx
    #11.用示例做的时候未用到10中,从这里开始真正构建词表
        #训练集词表,标签集词表
        else:
            print('[Info] Build vocabulary for source.')
            src_word2idx = build_vocab_idx(train_src_word_insts, opt.min_word_count)
            print('[Info] Build vocabulary for target.')
            tgt_word2idx = build_vocab_idx(train_tgt_word_insts, opt.min_word_count)

build_vocab_idx()函数就是用来将词语转化成词表的:原理很简单,就是将刚刚产生的所有的句子列表里面的所有的词给拿出来,并给每一个词一个编号,做成一个字典并返回,返回的字典就是词表,函数代码如下。

def build_vocab_idx(word_insts, min_word_count):
    ''' Trim vocab by number of occurence '''

    full_vocab = set(w for sent in word_insts for w in sent)
    print('[Info] Original Vocabulary size =', len(full_vocab))

    word2idx = {
        Constants.BOS_WORD: Constants.BOS,
        Constants.EOS_WORD: Constants.EOS,
        Constants.PAD_WORD: Constants.PAD,
        Constants.UNK_WORD: Constants.UNK}

    word_count = {w: 0 for w in full_vocab}

    for sent in word_insts:
        for word in sent:
            word_count[word] += 1

    ignored_word_count = 0
    for word, count in word_count.items():
        if word not in word2idx:
            if count > min_word_count:
                word2idx[word] = len(word2idx)
            else:
                ignored_word_count += 1

    print('[Info] Trimmed vocabulary size = {},'.format(len(word2idx)),
          'each with minimum occurrence = {}'.format(min_word_count))
    print("[Info] Ignored word count = {}".format(ignored_word_count))
    return word2idx

main函数中word2idx部分

# word to index
    #12.将数据集的训练集和验证集转化为对应词表的下标
    print('[Info] Convert source word instances into sequences of word index.')
    train_src_insts = convert_instance_to_idx_seq(train_src_word_insts, src_word2idx)
    valid_src_insts = convert_instance_to_idx_seq(valid_src_word_insts, src_word2idx)
    
    #将标签集的训练集和验证集转化为对应词表的下标
    print('[Info] Convert target word instances into sequences of word index.')
    train_tgt_insts = convert_instance_to_idx_seq(train_tgt_word_insts, tgt_word2idx)
    valid_tgt_insts = convert_instance_to_idx_seq(valid_tgt_word_insts, tgt_word2idx)

补充:将每一个训练集里面出现过的单词转化为词表里面的一个下标index,并将原本是词语序列构成的句子转化为以词语在词表中的下标序列构成的列表。例如:喜欢=1,他=2,!=3,那么原本的句子 [喜欢,他,!]就变成[1, 2, 3] 实现这个功能的函数就是convert_instance_to_idx_seq,它的返回值就是上述的这个列表,代码如下。

def convert_instance_to_idx_seq(word_insts, word2idx):
    ''' Mapping words to idx sequence. '''
    return [[word2idx.get(w, Constants.UNK) for w in s] for s in word_insts]

构建数据集字典对象

#13.构建一个数据集的字典对象
    data = {
        'settings': opt,#传入参数
        'dict': {       #传入词表
            'src': src_word2idx,
            'tgt': tgt_word2idx},
        'train': {      #传入训练集
            'src': train_src_insts,
            'tgt': train_tgt_insts},
        'valid': {      #传入验证集
            'src': valid_src_insts,
            'tgt': valid_tgt_insts}}

    print('[Info] Dumping the processed data to pickle file', opt.save_data)
    torch.save(data, opt.save_data)#持久化这个字典对象,方便以后调用这个数据集进行训练和测试。
    print('[Info] Finish.')

最后别忘了介个嘻嘻嘻

if __name__ == '__main__':
    main()

我是在服务器上跑的,下载数据:

mkdir -p data/multi30k
wget http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz &&  tar -xf training.tar.gz -C data/multi30k && rm training.tar.gz
wget http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz && tar -xf validation.tar.gz -C data/multi30k && rm validation.tar.gz
wget http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/mmt16_task1_test.tar.gz && tar -xf mmt16_task1_test.tar.gz -C data/multi30k && rm mmt16_task1_test.tar.gz

下载数据运行结果
运行process部分

python preprocess.py -train_src data/multi30k/train.en -train_tgt data/multi30k/train.de -valid_src data/multi30k/val.en -valid_tgt data/multi30k/val.de -save_data data/multi30k.atok.low.pt

process部分运行结果

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值