Tensorflow之Basic word2vec代码详解(上)

Tensorflow上关于Vector Representations of Words里给出了word2vec两个源代码,本文解析基础的代码,代码地址为:https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/examples/tutorials/word2vec/word2vec_basic.py

上篇为代码step1-3:讲解数据下载处理,与训练数据的生成。

from __future__ import absolute_import  
from __future__ import division  
from __future__ import print_function  
  
import collections  
import math  
import os  
import random  
import zipfile  
  
import numpy as np  
from six.moves import urllib  
from six.moves import xrange  # pylint: disable=redefined-builtin  
import tensorflow as tf  
  
# Step 1: 下载数据  
url = 'http://mattmahoney.net/dc/'  
  
  
def maybe_download(filename, expected_bytes):  
  """Download a file if not present, and make sure it's the right size."""  
  if not os.path.exists(filename):  
    filename, _ = urllib.request.urlretrieve(url + filename, filename)  
  statinfo = os.stat(filename)  
  if statinfo.st_size == expected_bytes:  
    print('Found and verified', filename)  
  else:  
    print(statinfo.st_size)  
    raise Exception(  
        'Failed to verify ' + filename + '. Can you get to it with a browser?')  
  return filename  
#  
filename = maybe_download('text8.zip', 31344016)  
  
  
# 解压缩并读取数据转化到数组中.  
def read_data(filename):  
  """Extract the first file enclosed in a zip file as a list of words."""  
  with zipfile.ZipFile(filename) as f:  
    data = tf.compat.as_str(f.read(f.namelist()[0])).split() #split(分割成序列) 
  return data  
  
vocabulary = read_data(filename)  
print('Data size', len(vocabulary))  

#建立字典  
# Step 2: Build the dictionary and replace rare words with UNK token.  
vocabulary_size = 50000  

def build_dataset(words, n_words):  
  """Process raw inputs into a dataset."""  
  count = [['UNK', -1]]  
  count.extend(collections.Counter(words).most_common(n_words - 1))#计数,取词频前50000个词,其余的为unk,  
  dictionary = dict()  
  for word, _ in count:  
    dictionary[word] = len(dictionary)#高频词排序给编号  
  data = list()  
  unk_count = 0  
  for word in words:  
    if word in dictionary:  
      index = dictionary[word]#给高频词一个索引  
    else:  
      index = 0  # 低频词索引为0  
      unk_count += 1  #统计低频词的个数
    data.append(index)  
  count[0][1] = unk_count  
  reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys())) #逆词汇,键和值与dictionary相反 
  return data, count, dictionary, reversed_dictionary  
  
data, count, dictionary, reverse_dictionary = build_dataset(vocabulary,  
                                                            vocabulary_size)  
del vocabulary  # Hint to reduce memory.  
print('Most common words (+UNK)', count[:5])  
print('Sample data', data[:10], [reverse_dictionary[i] for i in data[:10]])  

#如vocabulary(daefbmc……)其中:1.a词频:600;2.b词频:500;3.c词频:400;4.d词频:300;5.e词频:200;6.f词频:100;……unk:4148
#count([UNK,4148],[a,600],[b,500],[c,400],[d,300],[e,200,[f,100])
#dictionary([a;1],[b:2],[c:3],[d:4],[e;5],[f;6])
#data:{4,1,5,6,2,0,3}
#reversed_dictionary([1;a],[2:b],[3:c],[4:d],[5;e],[6;f])


#生成训练数据
#从文本总体的第二次开始,每个单词一次作为输入,输出可以是上下文范围内的单词中的任何一个(一般不是取全部而是随机抽取其中几组,增加随机性)  
data_index = 0   
# Step 3: Function to generate a training batch for the skip-gram model.  
def generate_batch(batch_size, num_skips, skip_window):  
#batch_size:每次训练的词长度;num_skips:每个输入词重复的次数(一个输入产生多少个标签数据),skip_window:向左右取多少词
  global data_index  #global:全局变量
  assert batch_size % num_skips == 0  #assert:断言
  assert num_skips <= 2 * skip_window  
  
  batch = np.ndarray(shape=(batch_size), dtype=np.int32)  #一维
  labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)  #二维
  span = 2 * skip_window + 1  # 左右两边各取skip_window,一共span个 
  buffer = collections.deque(maxlen=span)  #防止超长,挤出前面的数据,确保span个训练数据
  if data_index + span > len(data):  #依次取span个词
    data_index = 0  
  buffer.extend(data[data_index:data_index + span])  
  data_index += span  
  for i in range(batch_size // num_skips):  
    target = skip_window  #butter[skip_window]为输入数据
    targets_to_avoid = [skip_window] #去除输入词自己本身 
    for j in range(num_skips):  #输入词重复num_skips次
      while target in targets_to_avoid:  
        target = random.randint(0, span - 1) #随机生成 (0, span - 1)之间整数
      targets_to_avoid.append(target)  
      batch[i * num_skips + j] = buffer[skip_window] #训练输入的序列 
      labels[i * num_skips + j, 0] = buffer[target]  #训练输出的序列(标签,对应词频的排序)
    if data_index == len(data):  #超长时回到开始
      buffer[:] = data[:span]  
      data_index = span  
    else:  
      buffer.append(data[data_index]) #挤掉开始几个,换一组词训练 
      data_index += 1  
  # Backtrack a little bit to avoid skipping words in the end of a batch  
  data_index = (data_index + len(data) - span) % len(data)  
  return batch, labels  
  
batch, labels = generate_batch(batch_size=8, num_skips=2, skip_window=1)  
for i in range(8):  
  print(batch[i], reverse_dictionary[batch[i]],  
        '->', labels[i, 0], reverse_dictionary[labels[i, 0]])  
 
#如:vocabulary(m|daefbm|c……)取batch_size=6,num_skips=2,skip_window=1
#batch = [4,4,1,1,5,5]
#labels= [0,1,4,5,1,6]



  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值