在TensorFlow中,有些数据源使用Python的库,有的需要编写Python脚本下载,还有其他的得手动从网上下载。
1、鸢尾花卉数据集(Iris data)。此样本数据是机器学习和统计分析最经典的数据集,包含鸢尾、变色鸢尾和维吉尼亚鸢尾各自的花萼和花瓣的长度和宽度。总共有150个数据集,每类有50个样本。可以使用Scikit Learn的数据集函数:
用法:
from sklearn import datasets
iris = datasets.load_iris()
print(len(iris.data))
print(len(iris.target))
print(iris.data[0])
print(set(iris.target))
2、出生体重数据(Birth weight data),此样本数据是婴儿出生体重以及母亲和家庭历史人口统计学、医学指标,有189个样本集,包含11个特征变量。
import requests
birthdata_url = 'https://www.umass.edu/statdata/statdata/data/lowbwt.dat'
birth_file = requests.get(birthdata_url)
birth_data = birth_file.text.split('\r\n')[5:]
birth_header = [x for x in birth_data[0].split(' ') if len(x)>=1]
birth_data = [[float(x) for x in y.split(' ') if len(x)>=1] for y in birth_data[1:] if len(y)>=1]
print(len(birth_data))
print(len(birth_data[0]))
3、波士顿房价数据(Boston Housing data)。此样本数据集保存在卡耐基梅隆大学机器学习仓库,总共有506个房价样本,包含14个特征变量。
import requests
housing_url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.data'
housing_header = ['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD', 'TAX', 'PTRATIO', 'B', 'LSTAT', 'MEDV']
housing_file = requests.get(housing_url)
housing_data = [[float(x) for x in y.split(' ') if len(x)>=1] for y in housing_file.text.split('\n') if len(y)>=1]
print(len(housing_data))
print(len(housing_data[0]))
4、MNIST手写体字库:MNIST手写体字库是NIST手写体字库的字样本数据集。包含70000张0到9的图像,其中60000张标注为训练数据样本集,10000张为测试样本数据集。TensorFlow通过内部函数访问它,MNIST手写字体库常用来进行图像识别训练。在机器学习中,提供验证样本数据集来预防过拟合是非常重要的,TensorFlow从训练样本数据集中留出5000张图作为验证样本数据集合。
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
print(len(mnist.train.images))
print(len(mnist.test.images))
print(len(mnist.validation.images))
print(mnist.train.labels[1,:])
5、垃圾短信文本数据集(Spam-ham text data)。
import requests
import io
from zipfile import ZipFile
zip_url = 'http://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip'
r = requests.get(zip_url)
z = ZipFile(io.BytesIO(r.content))
file = z.read('SMSSpamCollection')
text_data = file.decode()
text_data = text_data.encode('ascii',errors='ignore')
text_data = text_data.decode().split('\n')
text_data = [x.split('\t') for x in text_data if len(x)>=1]
[text_data_target, text_data_train] = [list(x) for x in zip(*text_data)]
print(len(text_data_train))
print(set(text_data_target))
print(text_data_train[1])
6、影评样本数据集。此样本数据集是电影观看者的影评,分为好评和差评。
import requests
import io
import tarfile
movie_data_url = 'http://www.cs.cornell.edu/people/pabo/movie-review-data/rt-polaritydata.tar.gz'
r = requests.get(movie_data_url)
stream_data = io.BytesIO(r.content)
tmp = io.BytesIO()
while True:
s = stream_data.read(16384)
if not s:
break
tmp.write(s)
stream_data.close()
tmp.seek(0)
tar_file = tarfile.open(fileobj=tmp, mode="r:gz")
pos = tar_file.extractfile('rt-polaritydata/rt-polarity.pos')
neg = tar_file.extractfile('rt-polaritydata/rt-polarity.neg')
pos_data = []
for line in pos:
pos_data.append(line.decode('ISO-8859-1').encode('ascii',errors='ignore').decode())
neg_data = []
for line in neg:
neg_data.append(line.decode('ISO-8859-1').encode('ascii',errors='ignore').decode())
tar_file.close()
print(len(pos_data))
print(len(neg_data))
print(neg_data[0])
7、CIFAR-10图像数据集。此图像数据集是CIFAR机构发布的8亿张彩色图片(已标注,32x32像素)的子集,总共分为10类,60000张图片。50000张图片训练数据集,10000张测试数据集。
8、莎士比亚著作文本数据集(Shakespeare text data)。
import requests
shakespeare_url = 'http://www.gutenberg.org/cache/epub/100/pg100.txt'
response = requests.get(shakespeare_url)
shakespeare_file = response.content
shakespeare_text = shakespeare_file.decode('utf-8')
shakespeare_text = shakespeare_text[7675:]
print(len(shakespeare_text))
9、英德句子翻译样本集。此数据集由Tatoeba发布,Manythings整理。
import requests
import io
from zipfile import ZipFile
sentence_url = 'http://www.manythings.org/anki/deu-eng.zip'
r = requests.get(sentence_url)
z = ZipFile(io.BytesIO(r.content))
file = z.read('deu.txt')
eng_ger_data = file.decode()
eng_ger_data = eng_ger_data.encode('ascii',errors='ignore')
eng_ger_data = eng_ger_data.decode().split('\n')
eng_ger_data = [x.split('\t') for x in eng_ger_data if len(x)>=1]
[english_sentence, german_sentence] = [list(x) for x in zip(*eng_ger_data)]
print(len(english_sentence))
print(len(german_sentence))
print(eng_ger_data[10])