word2vec Principles (5): Implementing the skip-gram and CBOW Models

Contents

Code Part 1

Code Part 2


     The first part of the code is shared by the skip-gram and CBOW models; the second part implements the skip-gram model.

Code Part 1:

import os
from six.moves.urllib.request import urlretrieve
import zipfile
import collections


# http://mattmahoney.net/dc/textdata.html
dataset_link = 'http://mattmahoney.net/dc/'
zip_file = 'text8.zip'

# Display the download progress
def cbk(a, b, c):
    '''Callback for urlretrieve
    @a: number of data blocks downloaded so far
    @b: size of each data block
    @c: total size of the remote file
    '''
    per = 100.0 * a * b / c
    if per > 100:
        per = 100
    print('%.2f%%' % per)

def data_download(zip_file):
    '''Download the dataset if it is not already present locally'''
    if not os.path.exists(zip_file):
        # urlretrieve() downloads the remote file straight to a local path
        zip_file, _ = urlretrieve(dataset_link + zip_file, zip_file, cbk)
        print('File downloaded successfully!')
    return None


def extracting(extracted_folder, zip_file):
    '''Unzip the archive'''
    if not os.path.isdir(extracted_folder):
        with zipfile.ZipFile(zip_file) as zf:
            # Extract all files in the zip archive into extracted_folder
            zf.extractall(extracted_folder)


def text_processing(ft8_text):
    # Replace punctuation with special word tokens
    ft8_text = ft8_text.lower()
    ft8_text = ft8_text.replace('.', ' <period> ')
    ft8_text = ft8_text.replace(',', ' <comma> ')
    ft8_text = ft8_text.replace('"', ' <quotation> ')
    ft8_text = ft8_text.replace(';', ' <semicolon> ')
    ft8_text = ft8_text.replace('!', ' <exclamation> ')
    ft8_text = ft8_text.replace('?', ' <question> ')
    ft8_text = ft8_text.replace('(', ' <paren_l> ')
    ft8_text = ft8_text.replace(')', ' <paren_r> ')
    ft8_text = ft8_text.replace('--', ' <hyphen> ')
    ft8_text = ft8_text.replace(':', ' <colon> ')
    ft8_text_tokens = ft8_text.split()
    return ft8_text_tokens
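
# For example, text_processing('Hello, world!') returns
# ['hello', '<comma>', 'world', '<exclamation>']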

def remove_lowerfreword(ft_tokens):
    '''Remove word-level noise: drop tokens whose corpus frequency is 7 or lower'''
    word_cnt = collections.Counter(ft_tokens)  # count occurrences; an unordered dict-like container mapping element -> count
    shortlisted_words = [w for w in ft_tokens if word_cnt[w] > 7]
    print(shortlisted_words[:15])  # show the first 15 tokens of the filtered corpus
    print('Total number of shortlisted_words', len(shortlisted_words))  # 16616688
    print('Unique number of shortlisted_words', len(set(shortlisted_words)))  # 53721
    return shortlisted_words

def dict_creation(shortlisted_words):
    '''Build the vocabulary: word <-> integer index, ordered by descending frequency'''
    counts = collections.Counter(shortlisted_words)
    vocabulary = sorted(counts, key=counts.get, reverse=True)
    rev_dictionary = {ii: word for ii, word in enumerate(vocabulary)}  # integer -> word
    dictionary = {word: ii for ii, word in rev_dictionary.items()}     # word -> integer
    return dictionary, rev_dictionary
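
As a quick illustration of dict_creation, a toy run (the tied words may come out in either order):

d, rd = dict_creation(['the', 'the', 'the', 'cat', 'sat'])
# d  -> {'the': 0, 'cat': 1, 'sat': 2}
# rd -> {0: 'the', 1: 'cat', 2: 'sat'}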
Notes on the libraries used:
1. six is a compatibility library for Python 2 and 3.
   six.moves handles the functions whose location changed between Python 2 and 3; importing them through six.moves hides those differences.

2. zipfile.ZipFile(zip_file) opens the archive zip_file.
   ZipFile.extractall([path[, members[, pwd]]]) extracts all files in the archive.
   Parameters:
     path      target folder for the extracted files (defaults to the current directory)
     members   names (or ZipInfo objects) of the specific files to extract
     pwd       password for encrypted archives
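
For instance, a minimal sketch of the extraction call (file and folder names as used above):

import zipfile

# extract every file from 'text8.zip' into the folder 'dataset'
with zipfile.ZipFile('text8.zip') as zf:
    zf.extractall('dataset')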

Code Part 2:

import collections
import time
import numpy as np
import random
import tensorflow as tf
from text_processing import *
from sklearn.manifold import TSNE

def subsampling(words_cnt):
    # Subsample the corpus to thin out stop words (very frequent tokens)
    thresh = 0.00005
    word_counts = collections.Counter(words_cnt)
    total_count = len(words_cnt)
    freqs = {word: count / total_count for word, count in word_counts.items()}
    p_drop = {word: 1 - np.sqrt(thresh / freqs[word]) for word in word_counts}
    train_words = [word for word in words_cnt if p_drop[word] < random.random()]
    return train_words
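
# Sanity check on the drop probability p_drop(w) = 1 - sqrt(thresh / freq(w)):
# with thresh = 0.00005, a token covering 1% of the corpus (freq = 0.01) is
# dropped with probability 1 - sqrt(0.005) ≈ 0.93, while tokens with
# freq <= 0.00005 get p_drop <= 0 and are effectively always kept.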

def skipG_target_set_generation(batch_, batch_index, word_window):
    # Build skip-gram inputs in the required format: the words surrounding the centre word
    random_num = np.random.randint(1, word_window + 1)   # randomly choose how many words to take on each side, up to word_window
    words_start = batch_index - random_num if (batch_index - random_num) > 0 else 0
    words_stop = batch_index + random_num
    window_target = set(batch_[words_start:batch_index] + batch_[batch_index + 1:words_stop + 1])
    return list(window_target)
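
# Example: with batch_ = ['the', 'quick', 'brown', 'fox', 'jumps'], batch_index = 2
# (centre word 'brown') and a draw of random_num = 2, the context returned is
# ['the', 'quick', 'fox', 'jumps'] (order may vary because a set is used).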

def skipG_batch_creation(short_words, batch_length, word_window):
    # Create the (centre word, surrounding words) pairs, one batch at a time
    batch_cnt = len(short_words)//batch_length
    print('batch_cnt=', batch_cnt)
    short_words = short_words[:batch_cnt*batch_length]  # drop the tail so the batches divide evenly

    for word_index in range(0, len(short_words), batch_length):
        input_words, label_words = [], []
        word_batch = short_words[word_index:word_index+batch_length]

        for index_ in range(len(word_batch)):  # iterate over every centre word in the batch
            batch_input = word_batch[index_]
            batch_label = skipG_target_set_generation(word_batch, index_, word_window)   # fetch the surrounding words
            label_words.extend(batch_label)
            input_words.extend([batch_input]*len(batch_label))   # skip-gram input format: every context word is paired with its centre word
        yield input_words, label_words  # yield once per batch; yielding inside the inner loop would emit growing partial batches
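
# Example: for word_batch = ['a', 'b', 'c'] and word_window = 1, one possible
# yield is input_words = ['a', 'b', 'b', 'c'] and label_words = ['b', 'a', 'c', 'b'];
# each label (context word) lines up with its centre word in input_words.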


# extracted_folder = 'dataset'
# full_text = extracting(extracted_folder, zip_file)

with open('dataset/text8') as ft_:
    full_text = ft_.read()

ft_tokens = text_processing(full_text)   # list of tokens
shortlisted_words = remove_lowerfreword(ft_tokens)
dictionary, rev_dictionary = dict_creation(shortlisted_words)
words_cnt = [dictionary[word] for word in shortlisted_words]   # map every word to its integer index via the dictionary
train_words = subsampling(words_cnt)
print('train_words=',len(train_words))

# 1. Define the graph and its input placeholders
tf_graph = tf.Graph()
with tf_graph.as_default():
    input_ = tf.placeholder(tf.int32, [None], name='input_')
    label_ = tf.placeholder(tf.int32, [None, None], name='label_')

# 2. Build the embedding matrix
with tf_graph.as_default():
    word_embed = tf.Variable(tf.random_uniform((len(rev_dictionary), 300), -1, 1))
    embedding = tf.nn.embedding_lookup(word_embed, input_)  # map each input word index to its vector
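
# Note: embedding_lookup is just row indexing into word_embed, whose shape is
# (vocab_size, 300); feeding input_ = [4, 7] returns rows 4 and 7 as a
# (2, 300) tensor of word vectors.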

# 3. Define the loss and the optimizer
vocabulary_size = len(rev_dictionary)
with tf_graph.as_default():
    sf_weights = tf.Variable(tf.truncated_normal((vocabulary_size,300),stddev=0.1))
    sf_bias = tf.Variable(tf.zeros(vocabulary_size))
    # compute the loss via negative sampling (sampled softmax)
    loss_fn = tf.nn.sampled_softmax_loss(weights=sf_weights,
                                         biases=sf_bias,
                                         labels=label_,
                                         inputs=embedding,
                                         num_sampled=100,
                                         num_classes=vocabulary_size)
    cost_fn = tf.reduce_mean(loss_fn)
    optim = tf.train.AdamOptimizer().minimize(cost_fn)
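
# Note: sampled_softmax_loss approximates the full softmax over the vocabulary by
# contrasting each true label with num_sampled (here 100) randomly drawn negative
# classes, keeping the per-step cost independent of vocabulary size. At evaluation
# time the full logits would be tf.matmul(embedding, tf.transpose(sf_weights)) + sf_bias.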

# 4. Validation: pick a mix of common and less common words from the corpus and,
#    based on cosine similarity between word vectors, report the words closest to each
with tf_graph.as_default():
    validation_cnt = 16
    validation_dict = 100

    validation_words = np.array(random.sample(range(validation_dict), validation_cnt//2))  # sample 8 indices at random from range(validation_dict), i.e. from the most frequent words
    validation_words = np.append(validation_words, random.sample(range(1000, 1000+validation_dict), validation_cnt//2))
    validation_data = tf.constant(validation_words, dtype=tf.int32)
    normalization_embed = word_embed / (tf.sqrt(tf.reduce_sum(tf.square(word_embed),1,keep_dims=True)))
    validation_embed = tf.nn.embedding_lookup(normalization_embed, validation_data)
    word_similarity = tf.matmul(validation_embed,tf.transpose(normalization_embed))
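
# Note: once the rows of word_embed are L2-normalised, the dot product of two rows
# equals their cosine similarity, so word_similarity[i, j] is the cosine similarity
# between validation word i and vocabulary word j.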

epochs = 2
batch_length = 1000
word_window = 10

# Define a saver for the model checkpoint directory model_checkpoint
with tf_graph.as_default():
    saver = tf.train.Saver()

with tf.Session(graph=tf_graph) as sess:
    iteration = 1
    loss = 0
    sess.run(tf.global_variables_initializer())

    print("Begin training-----------")
    for e in range(1, epochs+1):
        batches = skipG_batch_creation(train_words, batch_length, word_window)
        start = time.time()
        for x, y in batches:

            train_loss, _ = sess.run([cost_fn, optim],
                                     feed_dict={input_:x, label_:np.array(y)[:,None]})
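            # sampled_softmax_loss expects labels of shape [batch_size, num_true],
            # hence np.array(y)[:, None] turns the label list into a column vector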
            loss += train_loss

            if iteration % 100 ==0:
                end = time.time()
                print('Epoch {}/{}'.format(e,epochs),
                      ', Iteration:{}'.format(iteration),
                      ', Avg.Training loss:{:.4f}'.format(loss/100),
                      ', Processing:{:.4f} sec/batch'.format((end-start)/100))
                loss = 0
                start = time.time()

            if iteration % 2000 ==0:
                similarity_ = word_similarity.eval()  # evaluate the similarity matrix in the current session
                for i in range(validation_cnt):
                    validated_words = rev_dictionary[validation_words[i]]
                    top_k = 8
                    nearest = (-similarity_[i,:]).argsort()[1:top_k+1]  # argsort returns indices in ascending order; negating sorts descending, and [1:top_k+1] skips the word itself
                    log = 'Nearest to %s:' % validated_words
                    for k in range(top_k):
                        close_word = rev_dictionary[nearest[k]]
                        log = '%s %s,' % (log, close_word)
                    print(log)
            iteration += 1  # increment once per processed batch

    save_path = saver.save(sess, "model_checkpoint/skipGram_text8.ckpt")
    embed_mat = sess.run(normalization_embed)

with tf_graph.as_default():
    saver = tf.train.Saver()

with tf.Session(graph=tf_graph) as sess:
    saver.restore(sess, tf.train.latest_checkpoint('model_checkpoint'))
    embed_mat = sess.run(word_embed)

# Visualise the embeddings with t-distributed Stochastic Neighbor Embedding (t-SNE)
word_graph = 250
tsne = TSNE()
word_embedding_tsne = tsne.fit_transform(embed_mat[:word_graph,:])
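
To reproduce the scatter plot, a minimal plotting sketch (assuming matplotlib is installed; the styling here is illustrative):

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 10))
for idx in range(word_graph):
    x, y = word_embedding_tsne[idx, :]
    ax.scatter(x, y, color='steelblue')
    ax.annotate(rev_dictionary[idx], (x, y), alpha=0.7)
plt.show()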

Visualization result:

[t-SNE scatter plot of the first 250 word embeddings]
