Transformer-XL — Vocabulary(词表构建)

data_dir 目录存放原始数据。

def main(unused_argv):
    """Convert the configured dataset splits into TFRecord files.

    Loads (or builds and caches) the corpus via ``get_lm_corpus``, makes sure
    ``<FLAGS.data_dir>/tfrecords`` exists, and then:

    * test mode (``FLAGS.per_host_test_bsz > 0``): converts only the "test"
      split and returns;
    * otherwise: converts "train" and "valid", skipping any split whose
      per-host batch size is not positive.
    """
    del unused_argv  # Unused

    corpus = get_lm_corpus(FLAGS.data_dir, FLAGS.dataset)

    save_dir = os.path.join(FLAGS.data_dir, "tfrecords")
    if not tf.gfile.Exists(save_dir):
        tf.gfile.MakeDirs(save_dir)

    # Test mode: only the test split is written, nothing else.
    if FLAGS.per_host_test_bsz > 0:
        corpus.convert_to_tfrecords("test", save_dir, FLAGS.per_host_test_bsz,
                                    FLAGS.tgt_len, FLAGS.num_core_per_host,
                                    FLAGS=FLAGS)
        return

    split_batch_sizes = (
        ("train", FLAGS.per_host_train_bsz),
        ("valid", FLAGS.per_host_valid_bsz),
    )
    for split_name, split_bsz in split_batch_sizes:
        # A non-positive batch size means "do not convert this split".
        if split_bsz > 0:
            print("Converting {} set...".format(split_name))
            corpus.convert_to_tfrecords(split_name, save_dir, split_bsz,
                                        FLAGS.tgt_len, FLAGS.num_core_per_host,
                                        FLAGS=FLAGS)

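读取字典:字典(Corpus 对象)使用 pickle 序列化后缓存在磁盘上的 cache.pkl 文件中。若缓存不存在(即初次获取字典),则新建 Corpus 并将其序列化保存。
创建 Corpus 主要有四步:

1、count_file:逐行读取原文,去除首尾的空格和换行符 \n,然后逐字拆分为数组,并在数组末尾添加 <eos> 标记;用 counter = Counter() 统计每个字的出现次数。
2、build_vocab:创建词汇表。先加入 <eos>,再把统计到的所有字按字符编码(Unicode 码点)排序,并去除出现次数低于 min_freq 的低频字。
3、add_symbol:建立原始符号到索引的映射 sym2idx,以及索引到原始符号的映射 idx2sym(按加入顺序存放,数组下标即是索引)。
4、encode_file:用建好的词表把数据集文本逐行转换为索引数组(convert_to_nparray)。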

def get_lm_corpus(data_dir, dataset):
    """Return the language-model corpus for ``dataset`` stored in ``data_dir``.

    The corpus is cached at ``<data_dir>/cache.pkl``:

    * if the cache file exists, it is unpickled and returned as-is;
    * otherwise a new ``Corpus`` is built with ``special=["<eos>"]`` and
      ``lower_case=False``. It is pickled to the cache (protocol 2), and a
      summary (vocab size, cutoffs, dataset name) is written to
      ``<data_dir>/corpus-info.json``.

    Security note: ``pickle.load`` can execute arbitrary code. Only point
    ``data_dir`` at directories whose cache file you trust.
    """
    cache_path = os.path.join(data_dir, "cache.pkl")

    if tf.gfile.Exists(cache_path):
        print("Loading cached dataset...")
        with open(cache_path, "rb") as cache_file:
            return pickle.load(cache_file)

    print("Producing dataset...")
    corpus = Corpus(data_dir, dataset, special=["<eos>"], lower_case=False)

    print("Saving dataset...")
    with open(cache_path, "wb") as cache_file:
        pickle.dump(corpus, cache_file, protocol=2)

    # Build the summary before opening the output file.
    summary = {
        "vocab_size": len(corpus.vocab),
        "cutoffs": corpus.cutoffs,
        "dataset": corpus.dataset
    }
    info_path = os.path.join(data_dir, "corpus-info.json")
    with open(info_path, "w") as info_file:
        json.dump(summary, info_file)

    return corpus
class Vocab(object):
    """Character-level vocabulary mapping symbols to integer indices.

    Symbol frequencies are accumulated in ``self.counter`` by ``count_file`` /
    ``count_sents``. ``build_vocab`` then assigns indices, stored in
    ``idx2sym`` (index -> symbol; the list position is the index) and
    ``sym2idx`` (symbol -> index).

    Args:
        special: Special symbols; stored as ``self.special``. Note that
            ``build_vocab`` itself only registers ``"<eos>"``.
        min_freq: ``build_vocab`` drops counted symbols whose count is below
            this value.
        max_size: Stored and printed by ``build_vocab`` but currently NOT
            applied.
        lower_case: Stored only; ``tokenize`` does not change case.
        delimiter: Stored only; ``tokenize`` always splits into characters.
        vocab_file: If set, ``build_vocab`` loads symbols from this file
            instead of from the counter.
    """

    def __init__(self, special=None, min_freq=0, max_size=None, lower_case=True,
                 delimiter=None, vocab_file=None):
        self.counter = Counter()
        # Use None instead of a mutable [] default: a default list would be
        # shared (and mutated) across all instances.
        self.special = [] if special is None else special
        self.min_freq = min_freq
        self.max_size = max_size
        self.lower_case = lower_case
        self.delimiter = delimiter
        self.vocab_file = vocab_file
        self.idx2sym = []               # index -> symbol
        self.sym2idx = OrderedDict()    # symbol -> index, in insertion order

        # NOTE: for the zhihu dataset the original author used min_freq = 100
        # plus an '<UNK>' symbol (setting self.unk_idx); apply per dataset.

    def tokenize(self, line, add_eos=False, add_double_eos=False):
        """Split ``line`` (surrounding whitespace stripped) into characters.

        Args:
            line: A single line of text.
            add_eos: Append '<eos>' (ignored when ``add_double_eos`` is set).
            add_double_eos: Wrap the characters as ['<S>', ..., '<S>'] (lm1b
                style). '<S>' is also registered in the vocabulary so that
                later index lookups succeed.

        Returns:
            The list of symbols.
        """
        symbols = list(line.strip())

        if add_double_eos:  # lm1b
            # Make sure '<S>' can be found in the symbol table.
            self.add_symbol('<S>')
            return ['<S>'] + symbols + ['<S>']
        if add_eos:
            return symbols + ['<eos>']
        return symbols

    def count_file(self, path, verbose=False, add_eos=True):
        """Tokenize every line of a UTF-8 text file and count its symbols.

        Args:
            path: Path to the text file; asserted to exist.
            verbose: Print progress every 500000 lines.
            add_eos: Append '<eos>' to every line. Defaults to True, matching
                the previous behaviour, which always appended it and ignored
                this flag.

        Returns:
            A list holding the symbol list of each line.
        """
        if verbose: print('counting file {} ...'.format(path))
        assert tf.gfile.Exists(path)

        sents = []
        with open(path, 'r', encoding='UTF-8') as f:
            for idx, line in enumerate(f):
                if verbose and idx > 0 and idx % 500000 == 0:
                    print('  line {}'.format(idx))
                symbols = self.tokenize(line, add_eos=add_eos)
                self.counter.update(symbols)
                sents.append(symbols)

        return sents

    def count_sents(self, sents, verbose=False):
        """Add the symbols of already-tokenized sentences to ``self.counter``.

        Args:
            sents: A list of sentences, each a list of tokenized symbols.
            verbose: Print progress every 500000 sentences.
        """
        if verbose: print('counting {} sents ...'.format(len(sents)))
        for idx, symbols in enumerate(sents):
            if verbose and idx > 0 and idx % 500000 == 0:
                print('  line {}'.format(idx))
            self.counter.update(symbols)

    def _build_from_file(self, vocab_file):
        """Load symbols from ``vocab_file`` in file order.

        Each non-blank line contributes its first whitespace-separated field.
        The file must contain '<UNK>'; its index becomes ``self.unk_idx``
        (a KeyError is raised otherwise).
        """
        # Explicit UTF-8, consistent with the other readers in this class; the
        # locale default may not be UTF-8 (e.g. on Windows).
        with open(vocab_file, 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue  # blank line: nothing to add
                self.add_symbol(fields[0])
        self.unk_idx = self.sym2idx['<UNK>']

    def build_vocab(self):
        """Populate ``idx2sym`` / ``sym2idx``.

        With ``vocab_file`` set, symbols are read from that file. Otherwise
        '<eos>' is registered first (index 0 for a fresh vocabulary, exposed as
        ``self.eos_idx``). Every counted symbol with count >= ``min_freq`` is
        then added in sorted symbol order.
        """
        if self.vocab_file:
            print('building vocab from {}'.format(self.vocab_file))
            self._build_from_file(self.vocab_file)
            print('final vocab size {}'.format(len(self)))
        else:
            print('building vocab with min_freq={}, max_size={}'.format(
                self.min_freq, self.max_size))

            self.add_special("<eos>")

            # The original author replaced frequency ordering
            # (counter.most_common(max_size)) with this sort and flagged the
            # frequency version as a pitfall. Sorting by symbol makes the
            # index assignment depend only on which symbols pass min_freq.
            # NOTE: max_size is therefore not applied.
            for sym, cnt in sorted(self.counter.items(), key=lambda item: item[0]):
                if cnt >= self.min_freq:
                    self.add_symbol(sym)
            print('final vocab size {} from {} unique tokens'.format(
                len(self), len(self.counter)))

    def encode_file(self, path, ordered=False, verbose=False,
                    add_double_eos=False):
        """Convert every line of a UTF-8 text file to an int64 index array.

        Each line is tokenized with '<eos>' appended (or wrapped in '<S>' when
        ``add_double_eos`` is set).

        Returns:
            A list of per-line arrays, or, if ``ordered`` is true, one array
            concatenating them. For an empty file that array is empty.
        """
        if verbose: print('encoding file {} ...'.format(path))
        assert tf.gfile.Exists(path)
        encoded = []
        with open(path, 'r', encoding="utf-8") as f:
            for idx, line in enumerate(f):
                if verbose and idx > 0 and idx % 500000 == 0:
                    print('  line {}'.format(idx))
                symbols = self.tokenize(line, add_eos=True,
                                        add_double_eos=add_double_eos)
                encoded.append(self.convert_to_nparray(symbols))

        if ordered:
            # np.concatenate raises on an empty list (empty input file).
            encoded = (np.concatenate(encoded) if encoded
                       else np.array([], dtype=np.int64))

        return encoded

    def encode_sents(self, sents, ordered=False, verbose=False):
        """Encode ``sents`` as index arrays.

        Despite the name, ``sents`` is tokenized as ONE line of text (no
        '<eos>' appended), so the result holds a single array. With
        ``ordered`` set, that array itself is returned.
        """
        if verbose: print('encoding {} sents ...'.format(len(sents)))
        encoded = []

        symbols = self.tokenize(sents)
        encoded.append(self.convert_to_nparray(symbols))

        if ordered:
            encoded = np.concatenate(encoded)

        return encoded

    def add_special(self, sym):
        """Register a special symbol (if new) and expose its index as an
        attribute, e.g. '<eos>' -> ``self.eos_idx``."""
        if sym not in self.sym2idx:
            self.idx2sym.append(sym)
            self.sym2idx[sym] = len(self.idx2sym) - 1
            setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym])

    def add_symbol(self, sym):
        """Append ``sym`` with the next free index if it is not present yet."""
        if sym not in self.sym2idx:
            self.idx2sym.append(sym)
            self.sym2idx[sym] = len(self.idx2sym) - 1

    def get_sym(self, idx):
        """Return the symbol at index ``idx`` (asserted to be in range)."""
        assert 0 <= idx < len(self.idx2sym), 'Index {} out of range'.format(idx)
        return self.idx2sym[idx]

    def get_idx(self, sym):
        """Return the index of ``sym``.

        Unknown symbols map to ``self.unk_idx``. An AssertionError is raised
        if ``unk_idx`` has not been set.
        """
        if sym in self.sym2idx:
            return self.sym2idx[sym]
        assert hasattr(self, 'unk_idx')
        return self.unk_idx

    def get_symbols(self, indices):
        """Map a sequence of indices to their symbols."""
        return [self.get_sym(idx) for idx in indices]

    def get_indices(self, symbols):
        """Map a sequence of symbols to their indices."""
        return [self.get_idx(sym) for sym in symbols]

    def convert_to_nparray(self, symbols):
        """Convert symbols to an int64 numpy array of indices."""
        return np.array(self.get_indices(symbols), dtype=np.int64)

    def convert_to_sent(self, indices, exclude=None):
        """Convert indices back to a space-joined string of symbols,
        skipping any index contained in ``exclude``."""
        if exclude is None:
            return ' '.join([self.get_sym(idx) for idx in indices])
        else:
            return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])

    def __len__(self):
        return len(self.idx2sym)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值