from __future__ import division, print_function, absolute_import
import collections
import os
import random
import urllib
import zipfile
import numpy as np
import tensorflow as tf
# Training parameters.
learning_rate = 0.1
batch_size = 128
num_steps = 3000000
display_step = 10000  # print the loss every `display_step` steps
eval_step = 200000  # print nearest neighbours every `eval_step` steps

# Evaluation parameters: words whose nearest neighbours are printed during training.
eval_words = ['five', 'of', 'going', 'hardware', 'american', 'britain']

# Word2Vec parameters.
embedding_size = 200  # Dimension of each embedding vector.
max_vocabulary_size = 50000  # Upper bound on vocabulary size (including 'UNK').
min_occurrence = 10  # Words occurring fewer than this many times are dropped.
skip_window = 3  # How many words to consider on each side of the center word.
num_skips = 2  # How many (center, context) pairs to generate per center word.
num_sampled = 64  # Number of negative classes sampled by the NCE loss.
# Download a small collection of Wikipedia articles (the text8 corpus) once.
url = 'http://mattmahoney.net/dc/text8.zip'
data_path = 'text8.zip'
if not os.path.exists(data_path):
    # `import urllib` by itself does not guarantee that the `urllib.request`
    # submodule is loaded, so import it explicitly before using it.
    import urllib.request

    print("Downloading the dataset... (It may take some time)")
    # Download under a temporary name and rename only once the transfer has
    # finished. An interrupted download therefore never leaves a truncated
    # text8.zip behind that the existence check above would accept next time.
    tmp_path = data_path + '.part'
    urllib.request.urlretrieve(url, tmp_path)
    os.replace(tmp_path, data_path)
    print("Done!")
# Read the first member of the downloaded archive (the text is already
# preprocessed), lower-case it, and split it on whitespace.
# The resulting words are bytes objects; they are never decoded.
with zipfile.ZipFile(data_path) as archive:
    text_words = (
        archive.read(archive.namelist()[0])
        .lower()
        .split()
    )
def read_data(filename):
    """Return the words of the first file stored inside the zip archive `filename`.

    The member's raw bytes are converted to `str` with `tf.compat.as_str` and
    split on whitespace. No lower-casing is applied.
    """
    with zipfile.ZipFile(filename) as archive:
        members = archive.namelist()
        raw = archive.read(members[0])
    return tf.compat.as_str(raw).split()
打印语料的前 100 个单词看一下:
# Show the first 100 words of the corpus. `text_words` holds bytes objects,
# so decode them to get a readable listing.
print([w.decode('utf-8') for w in text_words[:100]])
['anarchism', 'originated', 'as', 'a', 'term', 'of', 'abuse', 'first', 'used', 'against', 'early', 'working', 'class', 'radicals', 'including', 'the', 'diggers', 'of', 'the', 'english', 'revolution', 'and', 'the', 'sans', 'culottes', 'of', 'the', 'french', 'revolution', 'whilst', 'the', 'term', 'is', 'still', 'used', 'in', 'a', 'pejorative', 'way', 'to', 'describe', 'any', 'act', 'that', 'used', 'violent', 'means', 'to', 'destroy', 'the', 'organization', 'of', 'society', 'it', 'has', 'also', 'been', 'taken', 'up', 'as', 'a', 'positive', 'label', 'by', 'self', 'defined', 'anarchists', 'the', 'word', 'anarchism', 'is', 'derived', 'from', 'the', 'greek', 'without', 'archons', 'ruler', 'chief', 'king', 'anarchism', 'as', 'a', 'political', 'philosophy', 'is', 'the', 'belief', 'that', 'rulers', 'are', 'unnecessary', 'and', 'should', 'be', 'abolished', 'although', 'there', 'are', 'differing']
# Build the vocabulary and map low-frequency words to the 'UNK' token.
# Entry 0 is reserved for 'UNK'. Its count is a placeholder (-1) until the
# corpus has been encoded below.
count = [('UNK', -1)]
# Keep the (max_vocabulary_size - 1) most common words, most frequent first.
count.extend(collections.Counter(text_words).most_common(max_vocabulary_size - 1))
# Drop words that occur fewer than `min_occurrence` times. most_common() returns
# entries sorted by descending count, so trimming from the tail is enough.
# Entry 0 ('UNK') is never removed. Its placeholder count of -1 would otherwise
# make it look "rare" whenever every real word gets pruned, which would leave
# `count` empty and make `count[0] = ...` below fail.
while len(count) > 1 and count[-1][1] < min_occurrence:
    count.pop()
# Number of words in the vocabulary (including 'UNK').
vocabulary_size = len(count)
# Give every vocabulary word an id equal to its position in `count` (0 = 'UNK').
word2id = {word: i for i, (word, _) in enumerate(count)}
# Encode the corpus as ids. Words outside the vocabulary map to 0 ('UNK').
data = [word2id.get(word, 0) for word in text_words]
unk_count = data.count(0)
count[0] = ('UNK', unk_count)
# Reverse mapping: id -> word.
id2word = {i: word for word, i in word2id.items()}
print("Words count:", len(text_words))
print("Unique words:", len(set(text_words)))
print("Vocabulary size:", vocabulary_size)
print("Most common words:", count[:10])
输出如下:
Words count: 17005207
Unique words: 253854
Vocabulary size: 47135
Most common words: [('UNK', 444176), (b'the', 1061396), (b'of', 593677), (b'and', 416629), (b'one', 411764), (b'in', 372201), (b'a', 325873), (b'to', 316376), (b'zero', 264975), (b'nine', 250430)]
# Cursor into `data`: the index of the next word to pull into the sliding window.
# It is kept between calls, so consecutive batches walk through the corpus.
data_index = 0


def next_batch(batch_size, num_skips, skip_window):
    """Generate one skip-gram training batch from the global id sequence `data`.

    For each center word, `num_skips` distinct context positions are drawn at
    random from the `skip_window` words on either side of it. One
    (center, context) pair is emitted per draw.

    Args:
        batch_size: number of pairs to produce; must be a multiple of num_skips.
        num_skips: pairs generated per center word; at most 2 * skip_window.
        skip_window: number of words considered on each side of the center word.

    Returns:
        (batch, labels): int32 arrays of shape (batch_size,) and (batch_size, 1)
        holding the center-word ids and the context-word ids respectively.
    """
    global data_index
    assert batch_size % num_skips == 0
    assert num_skips <= 2 * skip_window
    centers = np.empty(batch_size, dtype=np.int32)
    contexts = np.empty((batch_size, 1), dtype=np.int32)
    span = 2 * skip_window + 1  # left context + center word + right context
    window = collections.deque(maxlen=span)
    if data_index + span > len(data):
        # Not enough words left for a full window: restart from the beginning.
        data_index = 0
    window.extend(data[data_index:data_index + span])
    data_index += span
    # Window positions other than the center; identical for every center word.
    candidates = [pos for pos in range(span) if pos != skip_window]
    for group in range(batch_size // num_skips):
        chosen = random.sample(candidates, num_skips)
        first = group * num_skips
        for offset, pos in enumerate(chosen):
            centers[first + offset] = window[skip_window]
            contexts[first + offset, 0] = window[pos]
        if data_index == len(data):
            # End of the corpus: refill the window from the start.
            window.extend(data[0:span])
            data_index = span
        else:
            # Slide the window one word to the right.
            window.append(data[data_index])
            data_index += 1
    # Step back by `span` so the words still in the window are not skipped
    # at the start of the next batch.
    data_index = (data_index + len(data) - span) % len(data)
    return centers, contexts
# Place the model variables on the CPU (per the original note, some of the ops
# used with them are not compatible with the GPU).
with tf.device('/cpu:0'):
    # Two normally-initialised (vocabulary_size, embedding_size) matrices,
    # created in this order:
    #   embedding   - one embedding row per vocabulary word;
    #   nce_weights - the NCE loss weights.
    embedding, nce_weights = [
        tf.Variable(tf.random.normal([vocabulary_size, embedding_size]))
        for _ in range(2)
    ]
    # NCE loss biases, one per vocabulary word, initialised to zero.
    nce_biases = tf.Variable(tf.zeros([vocabulary_size]))
def get_embedding(x):
    """Look up the embedding vector for every word id in `x`.

    Args:
        x: word ids, e.g. the center-word batch returned by `next_batch`.

    Returns:
        The rows of the global `embedding` variable selected by `x`.
    """
    with tf.device('/cpu:0'):
        return tf.nn.embedding_lookup(embedding, x)
def nce_loss(x_embed, y):
    """Return the batch-averaged NCE loss for predicting context ids `y` from `x_embed`.

    Args:
        x_embed: embeddings of the center words (output of `get_embedding`).
        y: context-word ids, e.g. the (batch_size, 1) labels from `next_batch`.
           They are cast to int64 before being passed to `tf.nn.nce_loss`.
    """
    with tf.device('/cpu:0'):
        labels = tf.cast(y, tf.int64)
        per_example = tf.nn.nce_loss(
            weights=nce_weights,
            biases=nce_biases,
            labels=labels,
            inputs=x_embed,
            num_sampled=num_sampled,
            num_classes=vocabulary_size,
        )
        return tf.reduce_mean(per_example)
# Evaluation.
def evaluate(x_embed):
    """Return the cosine similarity between each row of `x_embed` and every embedding.

    Args:
        x_embed: batch of embedding vectors (e.g. `get_embedding(x_test)`).

    Returns:
        A tensor whose [i, j] entry is the cosine similarity between row i of
        `x_embed` and row j of the global `embedding` matrix (one column per
        vocabulary word).
    """
    with tf.device('/cpu:0'):
        x_embed = tf.cast(x_embed, tf.float32)
        # Normalise every row (axis=1) to unit L2 norm so the dot products below
        # are true cosine similarities. Dividing by the norm of the whole batch
        # would only rescale the scores.
        x_embed_norm = tf.nn.l2_normalize(x_embed, axis=1)
        embedding_norm = tf.nn.l2_normalize(embedding, axis=1)
        cosine_sim_op = tf.matmul(x_embed_norm, embedding_norm, transpose_b=True)
        return cosine_sim_op
# Plain stochastic gradient descent with the configured learning rate.
optimizer = tf.optimizers.SGD(learning_rate)


def run_optimization(x, y):
    """Apply one SGD step on the NCE loss for a single training batch.

    Args:
        x: center-word ids (the `batch` returned by `next_batch`).
        y: context-word ids (the `labels` returned by `next_batch`).
    """
    trainable = [embedding, nce_weights, nce_biases]
    with tf.device('/cpu:0'):
        # Record the forward pass so the tape can differentiate the loss.
        with tf.GradientTape() as tape:
            loss = nce_loss(get_embedding(x), y)
        grads = tape.gradient(loss, trainable)
        # Update the embedding matrix and the NCE weights/biases.
        optimizer.apply_gradients(zip(grads, trainable))
# Ids of the evaluation words. The vocabulary keys are bytes (the corpus is read
# from the zip archive without decoding), so the str evaluation words are
# encoded before lookup. Words missing from the vocabulary fall back to id 0 ('UNK').
x_test = np.array([word2id.get(w.encode('utf-8'), 0) for w in eval_words])
top_k = 8  # number of most similar words to report per evaluation word

# Train for the configured number of steps.
for step in range(1, num_steps + 1):
    batch_x, batch_y = next_batch(batch_size, num_skips, skip_window)
    run_optimization(batch_x, batch_y)
    if step % display_step == 0 or step == 1:
        loss = nce_loss(get_embedding(batch_x), batch_y)
        print("step: %i, loss: %f" % (step, loss))
    # Evaluation.
    if step % eval_step == 0 or step == 1:
        print("Evaluation...")
        sim = evaluate(get_embedding(x_test)).numpy()
        for i, eval_word in enumerate(eval_words):
            # Rank vocabulary words by descending similarity and skip rank 0,
            # which is normally the evaluation word itself.
            nearest = (-sim[i, :]).argsort()[1:top_k + 1]
            neighbours = ''.join(' %s,' % id2word[idx] for idx in nearest)
            print('"%s" nearest neighbors:%s' % (eval_word, neighbours))
打印结果如下:
step: 1, loss: 44.536224
Evaluation...
"five" nearest neighbors: b'eight', b'four', b'three', b'two', b'six', b'one', b'zero', b'nine',
"of" nearest neighbors: b'and', b'a', b'for', b'in', b'was', b'to', b'the', b'with',
"going" nearest neighbors: b'himself', b'generally', b'name', b'day', b'main', b'original', b'alexander', b'area',
"hardware" nearest neighbors: b'high', b'player', b'well', b'work', b'non', b'end', b'continued', b'sea',
"american" nearest neighbors: b's', b'on', b'in', b'from', b'by', b'some', b'his', b'at',
"britain" nearest neighbors: b'd', b'century', b'south', b'well', b'force', b'history', b'those', b'later',
step: 10000, loss: 17.617069
step: 20000, loss: 32.009899
step: 30000, loss: 39.825680
step: 40000, loss: 46.564484
step: 50000, loss: 19.030632
step: 60000, loss: 33.419670
step: 70000, loss: 22.269041
step: 80000, loss: 24.664249
step: 90000, loss: 28.758041
step: 100000, loss: 18.576723
step: 110000, loss: 17.187302
step: 120000, loss: 20.497646
step: 130000, loss: 15.909262
step: 140000, loss: 18.391491
step: 150000, loss: 7.722824
step: 160000, loss: 9.170986
step: 170000, loss: 18.167021
step: 180000, loss: 17.995640
step: 190000, loss: 9.873861
step: 200000, loss: 5.985056
Evaluation...
"five" nearest neighbors: b'six', b'four', b'three', b'seven', b'eight', b'two', b'nine', b'one',
"of" nearest neighbors: b'and', b'for', b'the', b'were', b'by', b'a', b'however', b'with',
"going" nearest neighbors: b'himself', UNK, b'alexander', b'are', b'd', b'by', b'while', b'and',
"hardware" nearest neighbors: b'continued', b'work', b'not', b'defense', b'apollo', b'with', b'made', b'which',
"american" nearest neighbors: b'b', b'd', b's', b'in', b'after', b'english', UNK, b'about',
"britain" nearest neighbors: UNK, b'many', b'in', b'the', b'by', b'with', b'see', b'all',
step: 210000, loss: 19.672812
step: 220000, loss: 17.169825
step: 230000, loss: 5.966599
step: 240000, loss: 13.860767
step: 250000, loss: 11.289427
step: 260000, loss: 11.314636
step: 270000, loss: 12.816988
step: 280000, loss: 8.976292
step: 290000, loss: 9.926281
step: 300000, loss: 17.872997
step: 310000, loss: 9.930416
step: 320000, loss: 13.454721
step: 330000, loss: 7.589732
step: 340000, loss: 12.779041
step: 350000, loss: 13.483349
step: 360000, loss: 8.728107
step: 370000, loss: 11.296732
step: 380000, loss: 6.875031
step: 390000, loss: 7.842208
step: 400000, loss: 13.643340
Evaluation...
"five" nearest neighbors: b'four', b'three', b'six', b'two', b'seven', b'eight', b'one', b'zero',
"of" nearest neighbors: b'in', b'the', b'and', b'while', b'a', b'from', b'for', b'which',
"going" nearest neighbors: UNK, b'e', b'french', b'include', b'and', b'government', b'where', b'see',
"hardware" nearest neighbors: b'high', b'non', b'include', b'were', b'theory', b'use', b'large', b'or',
"american" nearest neighbors: b'in', b's', b'british', b'john', b'after', b'since', UNK, b'from',
"britain" nearest neighbors: b'south', b'first', b'great', b'modern', b'large', b'at', b'history', b'all',
step: 410000, loss: 6.513102
step: 420000, loss: 5.759387
step: 430000, loss: 6.949007
step: 440000, loss: 10.842569
step: 450000, loss: 7.122824
step: 460000, loss: 7.226061
step: 470000, loss: 12.767937
step: 480000, loss: 11.831079
step: 490000, loss: 7.544044
step: 500000, loss: 9.842602
...
稍微看一下 batch_x 的嵌入结果:
# Display the embeddings of the last training batch together with the shape of
# the full embedding matrix (the tuple is echoed in an interactive session).
(get_embedding(batch_x), embedding.numpy().shape)
(<tf.Tensor: shape=(128, 200), dtype=float32, numpy=
array([[ 1.4036576 , 0.6463658 , -0.63774115, ..., -0.02543012,
1.1143191 , 1.6757603 ],
[ 1.4036576 , 0.6463658 , -0.63774115, ..., -0.02543012,
1.1143191 , 1.6757603 ],
[ 1.0855132 , -0.20949575, -0.3836846 , ..., 0.36046728,
0.613291 , 0.7449575 ],
...,
[ 1.7644083 , 0.3866808 , -0.04017097, ..., 1.3998463 ,
-0.5211887 , 0.44882432],
[ 1.0006328 , -0.09737214, -0.36121845, ..., 0.31400985,
0.6167378 , 0.7291899 ],
[ 1.0006328 , -0.09737214, -0.36121845, ..., 0.31400985,
0.6167378 , 0.7291899 ]], dtype=float32)>,
(47135, 200))