# 代码抄写自《tensorflow实战》一书,以便大家运行测试学习。
#coding:utf-8
#因为要下载数据,所以导入的依赖库比较多
import collections
import math
import os
import random
import sys
import urllib
import zipfile
from urllib.request import urlretrieve

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.manifold import TSNE
#这边是python版本的一个检查,不同版本对应函数调用的接口是不一样的
# if sys.version_info[0] >= 3:
# from urllib.request import urlretrieve
# else:
# from urllib import urlretrieve
#从网址下载数据并检查数据的准确性
url = 'http://mattmahoney.net/dc/'
def maybe_download(filename, excepted_bytes):
if not os.path.exists(filename):
filename, _ = urlretrieve(url + filename, filename)
statinfo = os.stat(filename)
if statinfo.st_size == excepted_bytes:
print("Found and verified", filename)
else:
print(statinfo.st_size)
raise Exception(
"Failed to verfy" + filename + "Can you get to it with browser?")
return filename
filename = maybe_download('text8.zip', 31344016)
#定义读取数据的函数,并把数据转成列表
def read_data(filename):
with zipfile.ZipFile(filename) as f:
data = tf.compat.as_str(f.read(f.namelist()[0])).split()
return data
words = read_data(filename)
print('Data size', len(words))
#创建词汇表,选取前50000频数的单词,其余单词认定为Unknown,编号为0
vocabulary_size = 50000
def build_dataset(words):
count = [['UNK', -1]]
count.extend(collections.Counter(words).most_common(vocabulary_size - 1