loader_utils.py + supplementary study of word2vec

2021SC@SDUSC


Building the dataset

Build datasets for training and evaluation.

class build_dataset(Dataset):
    """ build datasets for train & eval """
    def __init__(self, args, tokenizer, mode):
        pretrain_model = (
            "bert" if "roberta" not in args.pretrain_model_type else "roberta"
        )
        # --------------------------------------------------------------------------------------------
        # cached_examples is produced by the try/except block shown below, which sits here in __init__
        self.mode = mode
        self.tokenizer = tokenizer
        self.examples = cached_examples
        self.model_class = args.model_class
        self.max_phrase_words = args.max_phrase_words
        # --------------------------------------------------------------------------------------------

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        return feature_converter[self.model_class](
            index,
            self.examples[index],
            self.tokenizer,
            self.mode,
            self.max_phrase_words,
        )
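To see how the class is meant to be used, here is a minimal usage sketch; the `args` namespace, the `batchify_fn` collate function and the checkpoint name are placeholders for illustration, not the project's actual entry point.

from torch.utils.data import DataLoader
from transformers import BertTokenizer

# Hypothetical setup: args must carry the fields read in __init__
# (pretrain_model_type, model_class, max_phrase_words, cached_features_dir, ...).
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
train_dataset = build_dataset(args, tokenizer, mode="train")

# batchify_fn stands in for the model-specific collate function that pads
# the converted features into tensors.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=batchify_fn)

for batch in train_loader:
    pass  # feed the batch to the model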

Try to reload the cached features

try:
            cached_examples = reload_cached_features(
                **{
                    "cached_features_dir": args.cached_features_dir,
                    "model_class": args.model_class,
                    "dataset_class": args.dataset_class,
                    "pretrain_model": pretrain_model,
                    "mode": mode,
                }
            )

If reloading fails, rerun the preprocessing

except:
            logger.info(
                "start loading source %s %s data ..." % (args.dataset_class, mode)
            )
            examples = load_dataset(
                os.path.join(
                    args.preprocess_folder, "%s.%s.json" % (args.dataset_class, mode)
                )
            )
            cached_examples = example_preprocessor[args.model_class](
                **{
                    "examples": examples,
                    "tokenizer": tokenizer,
                    "max_token": args.max_token,
                    "pretrain_model": pretrain_model,
                    "mode": mode,
                    "max_phrase_words": args.max_phrase_words,
                    "stem_flag": True if args.dataset_class == "kp20k" else False,
                }
            )
            if args.local_rank in [-1, 0]:
                save_cached_features(
                    **{
                        "cached_examples": cached_examples,
                        "cached_features_dir": args.cached_features_dir,
                        "model_class": args.model_class,
                        "dataset_class": args.dataset_class,
                        "pretrain_model": pretrain_model,
                        "mode": mode,
                    }
                )

Related method implementations

tokenize_for_bert(doc_words, tokenizer)

Tokenization for the pretrained model

def tokenize_for_bert(doc_words, tokenizer):
    valid_mask = []
    all_doc_tokens = []
    tok_to_orig_index = []
    orig_to_tok_index = []
    for (i, token) in enumerate(doc_words):
        orig_to_tok_index.append(len(all_doc_tokens))
        sub_tokens = tokenizer.tokenize(token)
        if len(sub_tokens) < 1:
            sub_tokens = [UNK_WORD]
        for num, sub_token in enumerate(sub_tokens):
            tok_to_orig_index.append(i)
            all_doc_tokens.append(sub_token)
            if num == 0:
                valid_mask.append(1)
            else:
                valid_mask.append(0)
    return {
        "tokens": all_doc_tokens,
        "valid_mask": valid_mask,
        "tok_to_orig_index": tok_to_orig_index,
        "orig_to_tok_index": orig_to_tok_index,
    }
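A quick worked example makes the four outputs concrete; the word list and the exact word pieces below are illustrative, since they depend on the tokenizer's vocabulary.

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
doc_words = ["keyphrase", "extraction", "works"]

out = tokenize_for_bert(doc_words, tokenizer)
# If "keyphrase" is split into two word pieces, the outputs look like:
# tokens            -> ['key', '##phrase', 'extraction', 'works']
# valid_mask        -> [1, 0, 1, 1]   (1 marks the first piece of each original word)
# tok_to_orig_index -> [0, 0, 1, 2]   (each piece points back to its source word)
# orig_to_tok_index -> [0, 2, 3]      (each word points to its first piece)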

reload_cached_features

Reload cached features:

def reload_cached_features(
    cached_features_dir, model_class, dataset_class, pretrain_model, mode
):
    logger.info(
        "start reloading:  %s (%s) for %s (%s) cached features ..."
        % (model_class, pretrain_model, dataset_class, mode)
    )
    filename = os.path.join(
        cached_features_dir,
        "cached.%s.%s.%s.%s.json" % (model_class, pretrain_model, dataset_class, mode),
    )
    examples = load_dataset(filename)
    return examples

Save cached features:

def save_cached_features(
    cached_examples,
    cached_features_dir,
    model_class,
    dataset_class,
    pretrain_model,
    mode,
):
    logger.info(
        "start saving:  %s (%s) for %s (%s) cached features ..."
        % (model_class, pretrain_model, dataset_class, mode)
    )
    if not os.path.exists(cached_features_dir):
        os.mkdir(cached_features_dir)
    save_filename = os.path.join(
        cached_features_dir,
        "cached.%s.%s.%s.%s.json" % (model_class, pretrain_model, dataset_class, mode),
    )
    save_dataset(data_list=cached_examples, filename=save_filename)

Helper functions for converting labels

def flat_rank_pos(start_end_pos):
    flatten_postions = [pos for poses in start_end_pos for pos in poses]
    sorted_positions = sorted(flatten_postions, key=lambda x: x[0])
    return sorted_positions
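For example, with occurrences grouped per keyphrase (the nested-list layout implied by the double loop), flattening and sorting by start index gives:

start_end_pos = [[[10, 11], [3, 4]],   # occurrences of keyphrase A
                 [[0, 2]]]             # occurrences of keyphrase B
flat_rank_pos(start_end_pos)
# -> [[0, 2], [3, 4], [10, 11]]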

Remove overlapping keyphrase positions

def strict_filter_overlap(positions):
    """delete overlap keyphrase positions. """
    previous_e = -1
    filter_positions = []
    for i, (s, e) in enumerate(positions):
        if s <= previous_e:
            continue
        filter_positions.append(positions[i])
        previous_e = e
    return filter_positions

def loose_filter_overlap(positions):
    """delete overlap keyphrase positions. """
    previous_s = -1
    filter_positions = []
    for i, (s, e) in enumerate(positions):
        if previous_s == s:
            continue
        elif previous_s < s:
            filter_positions.append(positions[i])
            previous_s = s
        else:
            logger.info("Error! previous start large than new start")
    return filter_positions
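The two filters differ in how aggressive they are; on the same (already sorted) input:

positions = [[0, 2], [1, 3], [4, 5], [4, 6]]

strict_filter_overlap(positions)
# -> [[0, 2], [4, 5]]           # any span overlapping a kept span is dropped

loose_filter_overlap(positions)
# -> [[0, 2], [1, 3], [4, 5]]   # only spans sharing the same start index are dropped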

Limit the phrase length

def limit_phrase_length(positions, max_phrase_words):
    filter_positions = [
        pos for pos in positions if (pos[1] - pos[0] + 1) <= max_phrase_words
    ]
    return filter_positions

Remove keyphrase positions that are out of scope (token length > 510) or whose phrase length exceeds 5 words

def limit_scope_length(start_end_pos, valid_length, max_phrase_words):
    """filter out positions over scope & phase_length > 5"""
    filter_positions = []
    for positions in start_end_pos:
        _filter_position = [
            pos
            for pos in positions
            if pos[1] < valid_length and (pos[1] - pos[0] + 1) <= max_phrase_words
        ]
        if len(_filter_position) > 0:
            filter_positions.append(_filter_position)
    return filter_positions
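A small illustration: positions are inclusive (start, end) indices, and valid_length caps how far a keyphrase may reach (510 once the special tokens are reserved).

start_end_pos = [[[3, 4], [600, 602]],   # second occurrence lies beyond the scope
                 [[10, 16]]]             # a 7-word span, longer than max_phrase_words

limit_scope_length(start_end_pos, valid_length=510, max_phrase_words=5)
# -> [[[3, 4]]]   # the out-of-scope and over-long occurrences are removed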

Stemming (extracting the word stem)

def stemming(phrase):
    norm_chars = unicodedata.normalize("NFD", phrase)
    stem_chars = " ".join([stemmer.stem(w) for w in norm_chars.split(" ")])
    return norm_chars, stem_chars
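The module-level stemmer is not shown in this snippet; assuming it is an NLTK PorterStemmer, the helper behaves as follows.

import unicodedata
from nltk.stem.porter import PorterStemmer

stemmer = PorterStemmer()  # assumed module-level stemmer

norm, stem = stemming("Neural Networks")
# norm -> 'Neural Networks'   (NFD-normalized characters, unchanged here)
# stem -> 'neural network'    (each word lower-cased and stemmed)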

Check whether unicoding(gram) or stemming(gram) already exists in phrase2index

def whether_stem_existing(gram, phrase2index, tot_phrase_list):
    """If :
    unicoding(gram) and stemming(gram) not in phrase2index,
    Return : not_exist_flag
    Else :
    Return : index already in phrase2index.
    """
    norm_gram, stem_gram = stemming(gram)
    if norm_gram in phrase2index:
        index = phrase2index[norm_gram]
        phrase2index[stem_gram] = index
        return index

    elif stem_gram in phrase2index:
        index = phrase2index[stem_gram]
        phrase2index[norm_gram] = index
        return index

    else:
        index = len(tot_phrase_list)
        phrase2index[norm_gram] = index
        phrase2index[stem_gram] = index
        tot_phrase_list.append(gram)
        return index
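The effect is to merge surface variants of the same phrase into a single candidate index. A sketch, reusing the stemming helper assumed above:

phrase2index, tot_phrase_list = {}, []

idx1 = whether_stem_existing("neural networks", phrase2index, tot_phrase_list)
idx2 = whether_stem_existing("neural network", phrase2index, tot_phrase_list)
idx3 = whether_stem_existing("deep learning", phrase2index, tot_phrase_list)

# idx1 == idx2 == 0: "neural network" stems to the same key as "neural networks",
# so the existing index 0 is reused instead of creating a new candidate.
# idx3 == 1, and tot_phrase_list -> ['neural networks', 'deep learning']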

Check whether gram already exists in phrase2index

def whether_existing(gram, phrase2index, tot_phrase_list):
    """If :
    gram not in phrase2index,
    Return : not_exist_flag
    Else :
    Return : index already in phrase2index.
    """
    if gram in phrase2index:
        index = phrase2index[gram]
        return index
    else:
        index = len(tot_phrase_list)
        phrase2index[gram] = index
        tot_phrase_list.append(gram)
        return index

Supplementary learning

word2vec

Paper reading:

Distributional semantics

Main idea of word2vec

The skip-gram model

  • Formula analysis

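For reference, the standard skip-gram objective: given a corpus of T words and a window of size m, the model maximizes the probability of every context word given its center word, where each probability is a softmax over the inner product of the center vector v_c and the outside vector u_o.

$$
J(\theta) = -\frac{1}{T}\sum_{t=1}^{T}\;\sum_{\substack{-m \le j \le m \\ j \ne 0}} \log P(w_{t+j} \mid w_t),
\qquad
P(o \mid c) = \frac{\exp(u_o^{\top} v_c)}{\sum_{w \in V} \exp(u_w^{\top} v_c)}
$$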

Training the model: gradient descent

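Training minimizes J(θ) by (stochastic) gradient descent on the word vectors. The key gradient, for a single center/outside pair, takes the familiar "observed minus expected" form, and each step moves the parameters against it with learning rate α:

$$
\frac{\partial}{\partial v_c}\bigl(-\log P(o \mid c)\bigr) = -u_o + \sum_{w \in V} P(w \mid c)\, u_w,
\qquad
\theta \leftarrow \theta - \alpha \,\nabla_{\theta} J(\theta)
$$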

Model workflow, using skip-gram as an example

Hierarchical Softmax

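Hierarchical softmax replaces the flat softmax over the vocabulary with a binary (Huffman) tree whose leaves are words, so computing one probability costs O(log|V|) instead of O(|V|). In Mikolov et al.'s formulation, with n(w, j) the j-th node on the path from the root to w, L(w) the path length, ch(n) a fixed child of n, and [[x]] equal to 1 if x is true and -1 otherwise:

$$
P(w \mid w_I) = \prod_{j=1}^{L(w)-1} \sigma\!\Bigl(\bigl[\!\bigl[\, n(w, j{+}1) = \mathrm{ch}(n(w, j)) \,\bigr]\!\bigr]\cdot {v'_{n(w,j)}}^{\top} v_{w_I}\Bigr)
$$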

Negative Sampling

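Negative sampling instead trains a binary classifier that separates the observed (center, outside) pair from k noise words drawn from a smoothed unigram distribution; the per-pair objective to maximize is:

$$
\log \sigma\bigl(u_o^{\top} v_c\bigr) + \sum_{i=1}^{k} \mathbb{E}_{w_i \sim P_n(w)}\bigl[\log \sigma\bigl(-u_{w_i}^{\top} v_c\bigr)\bigr],
\qquad
P_n(w) \propto U(w)^{3/4}
$$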

Subsampling

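Subsampling discards very frequent words (e.g. "the", "a") with a probability that grows with their corpus frequency f(w_i), given a threshold t (around 10^{-5} in the paper):

$$
P(\text{discard } w_i) = 1 - \sqrt{\frac{t}{f(w_i)}}
$$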
