在tensorflow2.x的keras中内置了7种类型的数据集:
数据集名称 | 数据集描述 |
---|---|
boston_housing | 波士顿房价数据 |
cifar10 | 10种类别图片集 |
cifar100 | 100种类别图片集 |
fashion_mnist | 10种时尚类别图片集 |
imdb | 电影评论情感分类数据集 |
mnist | 手写数字图片集 |
reuters | 路透社新闻主题分类数据集 |
这些数据的读取都可以使用load_data()方法。不过2种关于文本的数据集imdb和reuters比较特殊,他们的load_data中包含了过滤参数。本文将介绍imdb的load_data参数以及用法。
imdb.load_data的定义如下:
tf.keras.datasets.imdb.load_data(
path='imdb.npz', num_words=None, skip_top=0, maxlen=None, seed=113,
start_char=1, oov_char=2, index_from=3, **kwargs
)
-
path
此参数定义的文件的名称。一般使用默认值 -
num_words
整数。定义的是大于该词频的单词会被读取。如果单词的词频小于该整数,会用oov_char定义的数字代替。默认是用2代替。
需要注意的是,词频越高的单词,其在单词表中的位置越靠前,也就是索引越小,所以实际上保留下来的是索引小于此数值的单词。 -
skip_top
整数。词频低于此整数的单词会被读入。高于此整数的会被oov_char定义的数字代替。 -
maxlen
整数。评论单词数小于此数值的会被读入。比如一条评论包含的单词数是120,如果maxlen=100,则该条评论不会被读入。 -
seed
整数。定义了随机打乱数据时候的初始化种子。跟生成随机数的种子是一样的。 -
start_char
整数。定义了每条评论的起始索引。默认值是1。 -
oov_char
整数。定义了不满足条件单词的替代值。凡是不满足过滤条件的单词的索引都用此数值代替。 -
index_from
整数。单词索引大于此数值的会被读入。 -
**kwargs
兼容用途。 -
num_words使用
示例如下:
from tensorflow.keras import datasets
(x,y),(tx,ty) = datasets.imdb.load_data()
print("全部数据:",len(x),' 第一个评论:',len(x[0]))
print('第一个评论内容:',x[0][0:10])
(x100,y100),(tx100,ty100) = datasets.imdb.load_data(num_words=100)
print("前100词频:",len(x100),' 第一个评论【100】:',len(x100[0]))
print('第一个评论内容【100】:',x100[0][0:10])
结果如下:
全部数据: 25000 第一个评论: 218
第一个评论内容: [1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65]
前100词频: 25000 第一个评论【100】: 218
第一个评论内容【100】: [1, 14, 22, 16, 43, 2, 2, 2, 2, 65]
对比可以发现,索引大于100的都被2替代了。
- skip_top使用
示例如下:
from tensorflow.keras import datasets
(x,y),(tx,ty) = datasets.imdb.load_data()
print("全部数据:",len(x),' 第一个评论:',len(x[0]))
print('第一个评论内容:',x[0]