三角函数的位置编码(原始Transformer)
def get_pos_embedding(seq_len, dim):
    """Build the sinusoidal position-encoding table of the original Transformer.

    Row ``pos`` holds ``sin(pos * w_k)`` in even column ``2k`` and
    ``cos(pos * w_k)`` in odd column ``2k + 1``, where
    ``w_k = 1 / 10000 ** (2k / dim)``.

    Only the trigonometric values of a single step are computed with
    ``math.sin`` / ``math.cos``; every later row is derived from the previous
    one through the angle-addition identities.

    Args:
        seq_len: Number of positions (rows) to generate.
        dim: Width of each row. An odd ``dim`` ends with a sine column that
            has no matching cosine column.

    Returns:
        A list of ``seq_len`` rows, each a list of ``dim`` floats. A
        non-positive ``seq_len`` yields an empty list.
    """
    import math

    # One (sin, cos) pair per two columns; round up so an odd dim still gets
    # its trailing sine column.
    n_pairs = (dim + 1) // 2
    freqs = [1. / math.pow(10000., (2 * k) / dim) for k in range(n_pairs)]
    step_sin = [math.sin(w) for w in freqs]
    step_cos = [math.cos(w) for w in freqs]

    # Position 0: sin(0) = 0 and cos(0) = 1 for every frequency.
    cur_sin = [.0] * n_pairs
    cur_cos = [1.] * n_pairs

    vectors = []
    for _ in range(seq_len):
        interleaved = [v for pair in zip(cur_sin, cur_cos) for v in pair]
        # Slice drops the surplus cosine column when dim is odd.
        vectors.append(interleaved[:dim])
        # Advance one position:
        #   sin(a + w) = sin(a) * cos(w) + cos(a) * sin(w)
        #   cos(a + w) = cos(a) * cos(w) - sin(a) * sin(w)
        cur_sin, cur_cos = (
            [s * c1 + c * s1
             for s, c, s1, c1 in zip(cur_sin, cur_cos, step_sin, step_cos)],
            [c * c1 - s * s1
             for s, c, s1, c1 in zip(cur_sin, cur_cos, step_sin, step_cos)],
        )
    return vectors
获取文本词表映射
def get_word2index(texts, min_count=0, max_count=1e10, stopwords=None):
    """Build a word -> index mapping from tokenized texts.

    Words are ordered by ascending frequency (ties keep first-seen order) and
    numbered consecutively from 0 after filtering.

    Args:
        texts: Iterable of token sequences.
        min_count: Minimum occurrence count (inclusive) for a word to be kept.
        max_count: Maximum occurrence count (inclusive) for a word to be kept.
        stopwords: Container of words to exclude; ``None`` excludes nothing.

    Returns:
        A dict mapping each kept word to its index.
    """
    from collections import Counter

    def is_valid(w):
        # Placeholder hook: every word is currently accepted.
        return True

    excluded = stopwords if stopwords is not None else ()
    counts = Counter(w for words in texts for w in words)
    # sorted() is stable, so equally frequent words keep first-seen order.
    by_frequency = sorted(counts.items(), key=lambda item: item[1])
    kept = [
        w for w, c in by_frequency
        if is_valid(w) and min_count <= c <= max_count and w not in excluded
    ]
    # enumerate() supplies the index; unpacking the (word, count) pair
    # directly as (index, (word, _)) raises TypeError.
    return {w: i for i, w in enumerate(kept)}
清洗texts
def clean_texts_(texts, word2index):
    """Drop out-of-vocabulary words and pad/truncate texts to a common length.

    The target length is the length found at the 95% position of the sorted
    cleaned-text lengths. Shorter texts are right-padded with ``"<PAD>"``;
    longer ones are truncated.

    Args:
        texts: Sequence of token lists.
        word2index: Container of known words; only membership is checked.

    Returns:
        A new list of token lists, all of the same length. Empty input gives
        an empty list.
    """
    PAD_CHAR = "<PAD>"

    def padding_seq(seq, pad_idx, pad_len):
        # Truncate when too long, otherwise right-pad up to pad_len.
        if len(seq) > pad_len:
            return seq[:pad_len]
        return seq + [pad_idx] * (pad_len - len(seq))

    texts_clean = [[w for w in words if w in word2index] for words in texts]
    if not texts_clean:
        # No lengths to take a percentile from.
        return []
    lengths = sorted(len(e) for e in texts_clean)
    seq_len = lengths[int(0.95 * len(lengths))]
    # Pad the filtered texts, not the raw input, so removed words stay removed.
    return [padding_seq(e, PAD_CHAR, seq_len) for e in texts_clean]
获取GloVe-word2vec
def get_word2vec(word2index: dict, embedding_dim=300):
    """Assemble an embedding table for the given vocabulary.

    Pretrained vectors are loaded from ``origin-pretrain-golve.bin`` under
    ``config.TEMP_PATH`` (joined with a backslash separator). Words present in
    the pretrained model use its vector; all others get a random vector of
    ``embedding_dim`` values from ``np.random.random``.

    Args:
        word2index: Mapping word -> row index. Indices must be exactly
            ``0 .. len(word2index) - 1``, otherwise the row lookup raises
            ``KeyError``.
        embedding_dim: Length of the random and all-zero vectors.

    Returns:
        A list with one vector per index, followed by one extra all-zero
        vector (presumably the padding row; confirm with callers).
    """
    # NOTE(review): `.vocab` below is the pre-4.0 gensim attribute; confirm
    # the installed gensim version still provides it.
    from config import TEMP_PATH
    import gensim

    model_path = TEMP_PATH + r"\origin-pretrain-golve.bin"
    pretrained = gensim.models.KeyedVectors.load_word2vec_format(model_path, binary=True)
    known_words = set(pretrained.vocab)

    vec_by_index = {}
    for word, idx in word2index.items():
        if word in known_words:
            vec_by_index[idx] = pretrained[word]
        else:
            vec_by_index[idx] = list(np.random.random(embedding_dim))

    table = [vec_by_index[row] for row in range(len(vec_by_index))]
    table.append([0.] * embedding_dim)
    return table
序列补全
def padding_seq(seq, pad_idx, pad_len):
    """Return ``seq`` truncated or right-padded to exactly ``pad_len`` items.

    Args:
        seq: List to adjust; it is not modified.
        pad_idx: Filler value appended when ``seq`` is too short.
        pad_len: Target length.

    Returns:
        A new list: the first ``pad_len`` items when ``seq`` is longer,
        otherwise ``seq`` followed by as many ``pad_idx`` as needed.
    """
    shortfall = pad_len - len(seq)
    if shortfall < 0:
        return seq[:pad_len]
    return seq + [pad_idx] * shortfall
显示长度分布信息
def show_lens_info(data):
    """Print and plot the length distribution of the items in ``data``.

    Prints the maximum, minimum and average length, then the length found at
    ten evenly spaced positions between 80% and 98% of the sorted lengths.
    Lengths below 3000 are passed to ``distplot`` and shown with
    ``pyplot.show()`` (both names come from elsewhere in the module,
    presumably seaborn and matplotlib; confirm).

    Args:
        data: Non-empty iterable of sized items.

    Returns:
        A list: the ten percentile lengths, then max, min and the average
        rounded to two decimals.

    Raises:
        ValueError: If ``data`` is empty.
    """
    lens = sorted(len(e) for e in data)
    if not lens:
        raise ValueError("data must not be empty")

    longest, shortest = lens[-1], lens[0]
    mean = sum(lens) / len(lens)
    print("max", longest)
    print("min", shortest)
    print("avg", mean)

    infos = [[e, lens[int(len(lens) * e)]] for e in np.linspace(0.8, 1 - 0.02, 10)]
    for e, res in infos:
        # Round before truncating: e * 100 can land just below the intended
        # integer because of float error, and int() alone would under-report.
        print(f"{int(round(e * 100))}%", res)

    # Only lengths below 3000 are plotted.
    distplot([e for e in lens if e < 3000])
    pyplot.show()

    # float() parses the two-decimal string without resorting to eval().
    return [res for _, res in infos] + [longest, shortest, float("{:.2f}".format(mean))]