2021SC@SDUSC
Building the Dataset
Build datasets for training and evaluation:
class build_dataset(Dataset):
    """ build datasets for train & eval """

    def __init__(self, args, tokenizer, mode):
        pretrain_model = (
            "bert" if "roberta" not in args.pretrain_model_type else "roberta"
        )
        # --------------------------------------------------------------------------------------------
        self.mode = mode
        self.tokenizer = tokenizer
        self.examples = cached_examples
        self.model_class = args.model_class
        self.max_phrase_words = args.max_phrase_words
        # --------------------------------------------------------------------------------------------

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        return feature_converter[self.model_class](
            index,
            self.examples[index],
            self.tokenizer,
            self.mode,
            self.max_phrase_words,
        )
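Note that __getitem__ does not hard-code a feature format: it looks the conversion function up in feature_converter, a dict keyed by model_class. A self-contained sketch of that dispatch pattern (the converter functions and model-class keys below are made up for demonstration, not copied from the project):

# Toy illustration of the feature_converter dispatch used by __getitem__.
def convert_for_tagging(index, example, tokenizer, mode, max_phrase_words):
    return {"index": index, "words": example, "mode": mode}

def convert_for_spans(index, example, tokenizer, mode, max_phrase_words):
    return (index, example, max_phrase_words)

feature_converter = {
    "bert2tag": convert_for_tagging,
    "bert2span": convert_for_spans,
}

# build_dataset stays model-agnostic: only the selected converter knows the feature layout.
item = feature_converter["bert2tag"](0, ["keyphrase", "extraction"], None, "train", 5)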
Inside __init__, cached_examples is obtained by first trying to reload the cached features:
        try:
            cached_examples = reload_cached_features(
                **{
                    "cached_features_dir": args.cached_features_dir,
                    "model_class": args.model_class,
                    "dataset_class": args.dataset_class,
                    "pretrain_model": pretrain_model,
                    "mode": mode,
                }
            )
If no cache can be loaded, the features are preprocessed from the source data and saved:
        except:
            logger.info(
                "start loading source %s %s data ..." % (args.dataset_class, mode)
            )
            examples = load_dataset(
                os.path.join(
                    args.preprocess_folder, "%s.%s.json" % (args.dataset_class, mode)
                )
            )
            cached_examples = example_preprocessor[args.model_class](
                **{
                    "examples": examples,
                    "tokenizer": tokenizer,
                    "max_token": args.max_token,
                    "pretrain_model": pretrain_model,
                    "mode": mode,
                    "max_phrase_words": args.max_phrase_words,
                    "stem_flag": True if args.dataset_class == "kp20k" else False,
                }
            )
            if args.local_rank in [-1, 0]:
                save_cached_features(
                    **{
                        "cached_examples": cached_examples,
                        "cached_features_dir": args.cached_features_dir,
                        "model_class": args.model_class,
                        "dataset_class": args.dataset_class,
                        "pretrain_model": pretrain_model,
                        "mode": mode,
                    }
                )
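The try/except above is a plain cache-or-rebuild pattern: use the cached features when they can be loaded, otherwise rebuild them and (only on the main process) write them back. A minimal self-contained sketch of the same idea, with made-up file names and helpers:

import json

def load_or_build(cache_path, build_fn):
    # Reload from cache if possible; otherwise build once and save for next time.
    try:
        with open(cache_path, "r", encoding="utf-8") as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        data = build_fn()
        with open(cache_path, "w", encoding="utf-8") as f:
            json.dump(data, f)
        return data

examples = load_or_build("demo.cache.json", lambda: [{"doc_words": ["hello", "world"]}])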
Related method implementations
tokenize_for_bert(doc_words, tokenizer)
Tokenization with the pretrained model's tokenizer:
def tokenize_for_bert(doc_words, tokenizer):
    valid_mask = []
    all_doc_tokens = []
    tok_to_orig_index = []
    orig_to_tok_index = []
    for (i, token) in enumerate(doc_words):
        orig_to_tok_index.append(len(all_doc_tokens))
        sub_tokens = tokenizer.tokenize(token)
        if len(sub_tokens) < 1:
            # UNK_WORD is a module-level unknown-token placeholder, used when
            # the tokenizer produces no sub-tokens for a word
            sub_tokens = [UNK_WORD]
        for num, sub_token in enumerate(sub_tokens):
            tok_to_orig_index.append(i)
            all_doc_tokens.append(sub_token)
            # only the first sub-token of each original word is marked as valid
            if num == 0:
                valid_mask.append(1)
            else:
                valid_mask.append(0)
    return {
        "tokens": all_doc_tokens,
        "valid_mask": valid_mask,
        "tok_to_orig_index": tok_to_orig_index,
        "orig_to_tok_index": orig_to_tok_index,
    }
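To see the alignment these lists encode, here is a toy run with a stand-in WordPiece-style tokenizer (the real code passes a BERT/RoBERTa tokenizer; the class below only mimics its tokenize method on two hard-coded words):

class ToyWordpieceTokenizer:
    # Stand-in tokenizer: splits two known words into sub-tokens, keeps the rest whole.
    _vocab = {"keyphrase": ["key", "##phrase"], "extraction": ["extract", "##ion"]}
    def tokenize(self, word):
        return self._vocab.get(word, [word])

result = tokenize_for_bert(["automatic", "keyphrase", "extraction"], ToyWordpieceTokenizer())
# result["tokens"]            -> ['automatic', 'key', '##phrase', 'extract', '##ion']
# result["valid_mask"]        -> [1, 1, 0, 1, 0]   (1 marks the first sub-token of each word)
# result["tok_to_orig_index"] -> [0, 1, 1, 2, 2]
# result["orig_to_tok_index"] -> [0, 1, 3]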
reload_cached_features
Reload cached features:
def reload_cached_features(
    cached_features_dir, model_class, dataset_class, pretrain_model, mode
):
    logger.info(
        "start reloading: %s (%s) for %s (%s) cached features ..."
        % (model_class, pretrain_model, dataset_class, mode)
    )
    filename = os.path.join(
        cached_features_dir,
        "cached.%s.%s.%s.%s.json" % (model_class, pretrain_model, dataset_class, mode),
    )
    examples = load_dataset(filename)
    return examples
Save cached features:
def save_cached_features(
    cached_examples,
    cached_features_dir,
    model_class,
    dataset_class,
    pretrain_model,
    mode,
):
    logger.info(
        "start saving: %s (%s) for %s (%s) cached features ..."
        % (model_class, pretrain_model, dataset_class, mode)
    )
    if not os.path.exists(cached_features_dir):
        os.mkdir(cached_features_dir)
    save_filename = os.path.join(
        cached_features_dir,
        "cached.%s.%s.%s.%s.json" % (model_class, pretrain_model, dataset_class, mode),
    )
    save_dataset(data_list=cached_examples, filename=save_filename)
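Both helpers build the same cache filename; with illustrative argument values (not taken from an actual run) the path looks like this:

import os

# illustrative values; the real ones come from the command-line arguments
filename = os.path.join(
    "cached_features",
    "cached.%s.%s.%s.%s.json" % ("bert2tag", "bert", "kp20k", "train"),
)
print(filename)  # cached_features/cached.bert2tag.bert.kp20k.train.json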
Functions for converting labels
def flat_rank_pos(start_end_pos):
    flatten_postions = [pos for poses in start_end_pos for pos in poses]
    sorted_positions = sorted(flatten_postions, key=lambda x: x[0])
    return sorted_positions
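A quick worked example of the flattening and sorting (the positions are illustrative; each inner list holds the occurrence spans of one keyphrase):

start_end_pos = [[[7, 8], [2, 3]], [[0, 1]]]   # two keyphrases, each with its occurrence spans
print(flat_rank_pos(start_end_pos))            # [[0, 1], [2, 3], [7, 8]]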
Remove overlapping keyphrase positions:
def strict_filter_overlap(positions):
    """delete overlap keyphrase positions."""
    previous_e = -1
    filter_positions = []
    for i, (s, e) in enumerate(positions):
        if s <= previous_e:
            continue
        filter_positions.append(positions[i])
        previous_e = e
    return filter_positions
def loose_filter_overlap(positions):
    """delete overlap keyphrase positions."""
    previous_s = -1
    filter_positions = []
    for i, (s, e) in enumerate(positions):
        if previous_s == s:
            continue
        elif previous_s < s:
            filter_positions.append(positions[i])
            previous_s = s
        else:
            logger.info("Error! previous start large than new start")
    return filter_positions
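The two filters differ in how aggressively they drop overlaps; an illustrative input makes the contrast visible:

positions = [[0, 2], [1, 3], [1, 5], [4, 6]]
print(strict_filter_overlap(positions))  # [[0, 2], [4, 6]]         drops every span overlapping a kept one
print(loose_filter_overlap(positions))   # [[0, 2], [1, 3], [4, 6]] only drops spans sharing a start index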
Limit the phrase length:
def limit_phrase_length(positions, max_phrase_words):
    filter_positions = [
        pos for pos in positions if (pos[1] - pos[0] + 1) <= max_phrase_words
    ]
    return filter_positions
Remove out-of-scope keyphrase positions (token length > 510) and phrases longer than max_phrase_words (5):
def limit_scope_length(start_end_pos, valid_length, max_phrase_words):
    """filter out positions over scope & phase_length > 5"""
    filter_positions = []
    for positions in start_end_pos:
        _filter_position = [
            pos
            for pos in positions
            if pos[1] < valid_length and (pos[1] - pos[0] + 1) <= max_phrase_words
        ]
        if len(_filter_position) > 0:
            filter_positions.append(_filter_position)
    return filter_positions
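Both limits can be checked quickly on made-up spans (valid_length would normally be the length of the truncated sub-token sequence):

start_end_pos = [[[3, 4], [600, 602]], [[10, 16]]]   # keyphrase spans in word positions
print(limit_scope_length(start_end_pos, valid_length=510, max_phrase_words=5))
# -> [[[3, 4]]]   the out-of-scope occurrence and the 7-word phrase are removed
print(limit_phrase_length([[3, 4], [10, 16]], max_phrase_words=5))
# -> [[3, 4]]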
Stemming:
def stemming(phrase):
    # `stemmer` is a module-level stemmer instance (e.g. NLTK's PorterStemmer)
    norm_chars = unicodedata.normalize("NFD", phrase)
    stem_chars = " ".join([stemmer.stem(w) for w in norm_chars.split(" ")])
    return norm_chars, stem_chars
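Assuming stemmer is NLTK's PorterStemmer (a common choice for such a module-level stemmer, not confirmed by this excerpt), the helper behaves like this:

import unicodedata
from nltk.stem.porter import PorterStemmer

stemmer = PorterStemmer()   # assumed module-level stemmer

print(stemming("neural networks"))
# -> ('neural networks', 'neural network')   the plural is reduced to its stem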
Check whether the normalized form unicoding(gram) or the stemmed form stemming(gram) already exists in phrase2index:
def whether_stem_existing(gram, phrase2index, tot_phrase_list):
    """If :
    unicoding(gram) and stemming(gram) not in phrase2index,
    Return : not_exist_flag
    Else :
    Return : index already in phrase2index.
    """
    norm_gram, stem_gram = stemming(gram)
    if norm_gram in phrase2index:
        index = phrase2index[norm_gram]
        phrase2index[stem_gram] = index
        return index
    elif stem_gram in phrase2index:
        index = phrase2index[stem_gram]
        phrase2index[norm_gram] = index
        return index
    else:
        index = len(tot_phrase_list)
        phrase2index[norm_gram] = index
        phrase2index[stem_gram] = index
        tot_phrase_list.append(gram)
        return index
Check whether gram already exists in phrase2index:
def whether_existing(gram, phrase2index, tot_phrase_list):
    """If :
    gram not in phrase2index,
    Return : not_exist_flag
    Else :
    Return : index already in phrase2index.
    """
    if gram in phrase2index:
        index = phrase2index[gram]
        return index
    else:
        index = len(tot_phrase_list)
        phrase2index[gram] = index
        tot_phrase_list.append(gram)
        return index
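Putting the two lookups together, surface variants of one keyphrase end up sharing a single index (continuing with the PorterStemmer assumption above):

phrase2index, tot_phrase_list = {}, []
i1 = whether_stem_existing("neural networks", phrase2index, tot_phrase_list)
i2 = whether_stem_existing("neural network", phrase2index, tot_phrase_list)
i3 = whether_existing("neural network", phrase2index, tot_phrase_list)
print(i1, i2, i3)          # 0 0 0   -- all variants map to the same phrase index
print(tot_phrase_list)     # ['neural networks']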
Supplementary knowledge
word2Vec
Paper reading:
Distributional semantics
The main idea of word2Vec
The skip-gram model
- Formula analysis
Training the model: gradient descent
Model workflow, using skip-gram as an example
Hierarchical Softmax
Negative Sampling
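For reference, the two formulas behind the skip-gram items above are the standard ones from the word2vec papers (notation follows Mikolov et al., not any formula given in this post):

J(\theta) = -\frac{1}{T}\sum_{t=1}^{T}\sum_{-c \le j \le c,\, j \ne 0} \log p(w_{t+j} \mid w_t),
\qquad
p(w_O \mid w_I) = \frac{\exp(u_{w_O}^{\top} v_{w_I})}{\sum_{w=1}^{|V|} \exp(u_{w}^{\top} v_{w_I})}

and the negative-sampling objective that replaces the full softmax:

\log \sigma(u_{w_O}^{\top} v_{w_I}) + \sum_{i=1}^{k} \mathbb{E}_{w_i \sim P_n(w)} \big[ \log \sigma(-u_{w_i}^{\top} v_{w_I}) \big]

where v and u are the input ("center") and output ("context") embeddings, k is the number of negative samples, and P_n(w) is the noise distribution (typically the unigram distribution raised to the 3/4 power).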