# 输入为中文文本,采用网上下载的小说,训练词向量。
# 首先处理文本:输入小说原 txt 文件,生成 data、count、dictionary 和 reversed_dictionary 文件。
import jieba
import re
import collections
import os
# Path to the input novel. A raw string avoids the invalid "\w" escape in the
# original (it only produced the right path because "\w" is not a recognized
# escape and was kept literally; Python 3.12+ warns about it).
filename = r"D:\wordforchinese.txt"
# Keep the (vocabulary_size - 1) most frequent words; every other word maps
# to the shared out-of-vocabulary bucket at index 0.
vocabulary_size = 100
def getfiletxt():
    """Read the novel at `filename`, strip punctuation/whitespace, and
    segment it into a list of Chinese word tokens with jieba.

    Returns:
        list[str]: the segmented tokens, or an empty list when the source
        file is missing (the original returned an implicit None here, which
        crashed the downstream build_dataset() call).
    """
    if not os.path.exists(filename):
        print("can not find " + filename)
        return []
    # NOTE(review): 'ANSI' is a Windows-only codec alias; this open() fails on
    # other platforms — confirm the file's actual encoding (gbk?).
    with open(filename, encoding='ANSI') as f:
        raw = f.read()
    # Remove whitespace and ASCII/CJK punctuation before segmentation.
    # The pattern is kept byte-identical to the original: converting it to a
    # raw string naively would change where the character class closes.
    cleaned = re.sub("[\s+’!“”\"#$%&\'(())*,,\-.。·/::;;《》、<=>?@[\\]【】^_`{|}…~]+", "", raw)
    # jieba.cut yields a generator; materialize it so callers get a list.
    return list(jieba.cut(cleaned))
def build_dataset(words, n_words=None):
    """Map a token list to integer ids, keeping the most frequent words.

    Args:
        words: list of string tokens (e.g. from getfiletxt()).
        n_words: vocabulary size, including the out-of-vocabulary slot.
            Defaults to the module-level ``vocabulary_size`` so existing
            callers are unaffected.

    Returns:
        data: list of int ids, one per input token (0 = out-of-vocabulary).
        count: list whose first entry is the mutable ['Unknow', oov_count]
            pair, followed by (word, frequency) tuples for the kept words.
        dictionary: word -> id mapping.
        reversed_dictionary: id -> word mapping.
    """
    if n_words is None:
        n_words = vocabulary_size
    # Slot 0 is the out-of-vocabulary counter. The 'Unknow' spelling is kept
    # byte-identical so previously written dictionary files stay compatible.
    count = [['Unknow', 0]]
    count.extend(collections.Counter(words).most_common(n_words - 1))
    # Assign ids in frequency order (0 reserved for the unknown bucket).
    dictionary = dict()
    for word, _ in count:
        dictionary[word] = len(dictionary)
    data = list()
    for word in words:
        if word in dictionary:
            index = dictionary[word]
        else:
            index = 0
            count[0][1] += 1  # tally how many tokens fell out of vocabulary
        data.append(index)
    reversed_dictionary = {idx: word for word, idx in dictionary.items()}
    return data, count, dictionary, reversed_dictionary
if __name__ == "__main__":
    # Preprocess the novel and persist the four artifacts the training
    # script loads later.
    words = getfiletxt()
    data, count, dictionary, reversed_dictionary = build_dataset(words)
    del words  # the raw token list is large; release it before writing
    # NOTE(review): these paths end in "...1.txt" but the training script
    # below reads "D:/data.txt" etc. — confirm the filenames are meant to
    # differ.
    outputs = [
        ("D:/data1.txt", data, "data list has been written"),
        ("D:/count1.txt", count, "count list has been written"),
        ("D:/dictionary1.txt", dictionary, "dictionary has been written"),
        ("D:/reversed_dictionary1.txt", reversed_dictionary,
         "reversed_dictionary has been written"),
    ]
    # `with` closes each file on exit; the explicit close() calls in the
    # original were redundant.
    for path, obj, message in outputs:
        with open(path, 'w') as out:
            out.write(str(obj))
        print(message)
# 然后训练:以下第二段脚本加载上面生成的文件,训练词向量。
# encoding=utf8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import math
import os
import random
import zipfile
import numpy as np
from six.moves import urllib
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
import ast  # needed only by this script section: safe literal parsing


def _load_literal(path):
    """Parse a file produced via f.write(str(obj)) back into a Python object.

    ast.literal_eval replaces the original eval(): it accepts only Python
    literals (lists, dicts, tuples, strings, numbers), which is exactly what
    str() of the preprocessing artifacts produces, so a tampered data file
    cannot execute arbitrary code. The redundant close() calls inside the
    original `with` blocks are also gone — `with` already closes the file.
    """
    with open(path, 'r') as f:
        return ast.literal_eval(f.read())


# NOTE(review): the preprocessing step above writes "D:/data1.txt" etc.;
# these paths omit the "1" — confirm which filenames are intended.
data = _load_literal("D:/data.txt")
count = _load_literal("D:/count.txt")
dictionary = _load_literal("D:/dictionary.txt")
reverse_dictionary = _load_literal("D:/reversed_dictionary.txt")
print("数据已加载")
# Cursor into `data` marking the next token to consume when generating
# training batches (advanced by generate_batch below).
data_index = 0
# NOTE(review): the preprocessing script above built its dictionary with
# vocabulary_size = 100, but 10000 is used here for training — these
# presumably should match; confirm before training.
vocabulary_size = 10000
def generate_batch(batch_size, num_skips, skip_window):
global data_index
assert batch_size % num_skips == 0
assert num_skips <= 2 * skip_window
batch = np.ndarray(shape=(batch_size), dtype=np.int32)
labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)
span = 2 * skip_window + 1 # [ skip_window target skip_window ]
buffer = collections.deque(maxlen=span)
# [ skip_window target skip_window ]
# [ skip_window target skip_window ]
# [ skip_win