1. Randomly Initialized Embeddings
1.1 Principle
An embedding layer is essentially a lookup table; in TensorFlow 1.x it is queried with tf.nn.embedding_lookup().
Note: after the embedding lookup, consider applying a dropout layer.
Note: inside the embedding layer, consider scaling the retrieved vectors, as in "Attention Is All You Need".
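Conceptually, the lookup is just row indexing into a matrix. The following NumPy sketch (with made-up sizes) illustrates what tf.nn.embedding_lookup does, including the optional sqrt(num_units) scaling mentioned above:

```python
import numpy as np

# The embedding table is a (vocab_size, num_units) matrix; looking up a batch
# of token ids is plain row indexing (a gather along axis 0).
vocab_size, num_units = 5, 4
rng = np.random.default_rng(0)
lookup_table = rng.normal(size=(vocab_size, num_units))

token_ids = np.array([[1, 2, 3], [0, 0, 1]])  # a batch of id sequences
vectors = lookup_table[token_ids]             # shape (2, 3, num_units)

# Optional scaling by sqrt(num_units), as in "Attention Is All You Need".
scaled = vectors * num_units ** 0.5
```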
1.2 Example Code
The code below implements the embedding layer.
import numpy as np
import tensorflow as tf


def embedding(inputs, vocab_size, num_units, pre_embed=None, scale=True,
              scope="embedding", reuse=None):
    with tf.variable_scope(name_or_scope=scope, reuse=reuse):
        if pre_embed is not None:
            # A pretrained table was supplied: initialize from it and freeze it.
            lookup_table = tf.get_variable(name="lookup_table",
                                           dtype=tf.float32,
                                           initializer=pre_embed.astype(np.float32),
                                           trainable=False)
        else:
            lookup_table = tf.get_variable(name="lookup_table",
                                           shape=[vocab_size, num_units],
                                           initializer=tf.contrib.layers.xavier_initializer())
        outputs = tf.nn.embedding_lookup(lookup_table, inputs)
        if scale:
            # Scale by sqrt(num_units), as in "Attention Is All You Need".
            outputs = outputs * (num_units ** 0.5)
        return outputs
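The dropout note in 1.1 is not shown in the function above; in TF 1.x it would be a call to tf.nn.dropout on the embedding output. As a rough sketch of what that does, here is inverted dropout in NumPy (keep_prob=0.9 is an assumed hyperparameter, not from the original code):

```python
import numpy as np

def inverted_dropout(x, keep_prob, rng):
    # Zero out each element with probability (1 - keep_prob), then rescale the
    # survivors by 1/keep_prob so the expected activation is unchanged.
    mask = rng.random(x.shape) < keep_prob
    return x * mask / keep_prob

rng = np.random.default_rng(42)
embedded = rng.normal(size=(2, 3, 8))   # e.g. the output of the embedding layer
out = inverted_dropout(embedded, keep_prob=0.9, rng=rng)
```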
The next two functions are:
- load_my_vocab: loads the model's vocabulary, whose entries are then looked up in the pretrained embedding.
- get_vocab_embedding: extracts the needed vectors from the pretrained embedding file and returns them.
def load_my_vocab(vocab_file):
    """Read the vocabulary file, one token per line."""
    vocab_list = []
    with open(vocab_file, 'r', encoding='utf-8') as fr:
        for line in fr:
            vocab_list.append(line.strip())
    return vocab_list
def get_vocab_embedding(vocab_list, pre_embedding_file):
    """Collect the pretrained vector for each vocab entry."""
    lookup_dict = {vocab: [] for vocab in vocab_list}
    embed_dim = None
    with open(pre_embedding_file, 'r', encoding='utf-8') as fr:
        for line in fr:
            line = line.rstrip("\n")
            if not line:
                continue
            vocab, embed = line.split(" ", 1)
            if vocab in lookup_dict:
                embed = [float(digit_str) for digit_str in embed.strip().split()]
                lookup_dict[vocab].append(embed)
                embed_dim = len(embed)
    # Average duplicate entries; tokens missing from the pretrained file fall
    # back to a zero vector (np.mean over an empty list would give NaN).
    lookup_table = [np.mean(lookup_dict[vocab], axis=0) if lookup_dict[vocab]
                    else np.zeros(embed_dim)
                    for vocab in vocab_list]
    return np.array(lookup_table)
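To make the parsing and averaging logic concrete, here is a toy in-memory version of the same loop (the three GloVe-style lines are invented for illustration). Each line is "word v1 v2 ...", and if a word appears more than once its vectors are averaged:

```python
import numpy as np

lines = ["cat 1.0 2.0", "dog 3.0 4.0", "cat 3.0 6.0"]
lookup = {}
for line in lines:
    # Split off the word; the remainder is the whitespace-separated vector.
    word, embed = line.split(" ", 1)
    lookup.setdefault(word, []).append([float(x) for x in embed.split()])

# Duplicate occurrences of a word are averaged element-wise.
table = {w: np.mean(vecs, axis=0) for w, vecs in lookup.items()}
print(table["cat"])  # [2. 4.]
```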
The main program:
if __name__ == "__main__":
    vocab_file = "./vocab.txt"
    pre_embedding_file = "./glove.840B.300d.txt"
    vocab_list = load_my_vocab(vocab_file)
    lookup_table = get_vocab_embedding(vocab_list, pre_embedding_file)
    vocab_size = 5
    # Token ids to embed; every index must be smaller than vocab_size.
    inputs = tf.constant([[1, 2, 3], [2, 2, 2], [0, 0, 1]], dtype=tf.int32)
    word_embed = embedding(inputs=inputs, vocab_size=vocab_size, num_units=300,
                           pre_embed=lookup_table)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        outputs = sess.run(word_embed)
        print(outputs)