How to Load the imdb Dataset Locally in TensorFlow

Preface

I recently wanted to get some hands-on experience with AI, but when I tried the simple experiments on the TensorFlow website, I found that the sample code would not run smoothly on my machine due to network and environment issues. The main culprit was the automatic load_data call to the remote server, which kept throwing https- and ssl-related network errors. For the earlier Fashion MNIST experiment I had found a way to import the data locally by following some guides online, so for the next small experiment, imdb (movie-review text classification), I decided to write a similar function myself to load the dataset locally, and that became this post. The code and approach are offered for reference. (I also noticed that the imdb portion of the keras code is quite short, so while I was at it I converted imdb.get_word_index(), which also needs network access, to an offline version as well; after all, when network problems appear, they appear for both...)

My Environment

OS: CentOS 7
Conda version: Anaconda3-5.3.1-Linux-x86_64.sh, downloaded from the TUNA mirror
TensorFlow version: 1.15.4
Keras: the version bundled with TensorFlow

Usage

  1. Download the two imdb-related files (imdb.npz, imdb_word_index.json) to a local directory (e.g. /home/user/datasets/TF_imdb). Baidu Netdisk download link / extraction code: ic54

  2. Copy the [offline loading code block] from the section below into a .py file, and change the paths marked in the code to point to the two files you just downloaded. (If any required packages are missing, install them with pip install.)

  3. In the official sample code, replace imdb.load_data(num_words=10000) with our custom manual_imdb_load_data(num_words=10000); likewise, replace imdb.get_word_index() with manual_imdb_get_word_index(). (See the sketch after this list.)

  4. The rest of the code can stay the same as in the official example; just run it to test.
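For concreteness, here is a minimal sketch of what the swap in step 3 looks like against the official text-classification tutorial (the train_data/train_labels variable names follow the tutorial; the two manual_* functions are defined in the code block below):

# Before (official tutorial, downloads from the network):
#     (train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000)
#     word_index = imdb.get_word_index()
# After (offline versions defined below):
(train_data, train_labels), (test_data, test_labels) = manual_imdb_load_data(num_words=10000)
word_index = manual_imdb_get_word_index()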

[Offline loading code block]: since this is adapted from the official code, I kept some of the original comments. Before using it, remember to put the two downloaded files at the corresponding paths.

import tensorflow as tf
from tensorflow import keras
import numpy as np

print(tf.__version__)  # quick sanity check of the tensorflow version

# Modified versions of some of keras's imdb functions (because of the https problems)
from tensorflow.python.platform import tf_logging as logging
# _remove_long_seq is a private Keras helper that is not exposed through the
# public tf.keras API, so import it from the internal module directly.
from tensorflow.python.keras.preprocessing.sequence import _remove_long_seq
import json
def manual_imdb_load_data(path='imdb.npz',
                          num_words=None,
                          skip_top=0,
                          maxlen=None,
                          seed=113,
                          start_char=1,
                          oov_char=2,
                          index_from=3,
                          **kwargs):
    if 'nb_words' in kwargs:
        logging.warning('The `nb_words` argument in `load_data` '
                        'has been renamed `num_words`.')
        num_words = kwargs.pop('nb_words')
    if kwargs:
        raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))

    path = '/home/user/datasets/TF_imdb/imdb.npz'  # [Edit point 1]: point this at your local imdb.npz
    with np.load(path, allow_pickle=True) as f:
        x_train, labels_train = f['x_train'], f['y_train']
        x_test, labels_test = f['x_test'], f['y_test']

    rng = np.random.RandomState(seed)
    indices = np.arange(len(x_train))
    rng.shuffle(indices)
    x_train = x_train[indices]
    labels_train = labels_train[indices]

    indices = np.arange(len(x_test))
    rng.shuffle(indices)
    x_test = x_test[indices]
    labels_test = labels_test[indices]

    if start_char is not None:
        x_train = [[start_char] + [w + index_from for w in x] for x in x_train]
        x_test = [[start_char] + [w + index_from for w in x] for x in x_test]
    elif index_from:
        x_train = [[w + index_from for w in x] for x in x_train]
        x_test = [[w + index_from for w in x] for x in x_test]

    if maxlen:
        x_train, labels_train = _remove_long_seq(maxlen, x_train, labels_train)
        x_test, labels_test = _remove_long_seq(maxlen, x_test, labels_test)
        if not x_train or not x_test:
            raise ValueError('After filtering for sequences shorter than maxlen=' +
                             str(maxlen) + ', no sequence was kept. '
                                           'Increase maxlen.')

    xs = np.concatenate([x_train, x_test])
    labels = np.concatenate([labels_train, labels_test])

    if not num_words:
        num_words = max(max(x) for x in xs)

    # by convention, use 2 as OOV word
    # reserve 'index_from' (=3 by default) characters:
    # 0 (padding), 1 (start), 2 (OOV)
    if oov_char is not None:
        xs = [
            [w if (skip_top <= w < num_words) else oov_char for w in x] for x in xs
        ]
    else:
        xs = [[w for w in x if skip_top <= w < num_words] for x in xs]

    idx = len(x_train)
    x_train, y_train = np.array(xs[:idx]), np.array(labels[:idx])
    x_test, y_test = np.array(xs[idx:]), np.array(labels[idx:])

    return (x_train, y_train), (x_test, y_test)

def manual_imdb_get_word_index(path='imdb_word_index.json'):
    """Retrieves a dict mapping words to their index in the IMDB dataset.

    Arguments:
        path: where to cache the data (relative to `~/.keras/dataset`).

    Returns:
        The word index dictionary. Keys are word strings, values are their index.
    """
    path = '/home/user/datasets/TF_imdb/imdb_word_index.json'  # [Edit point 2]: point this at your local imdb_word_index.json
    with open(path) as f:
        return json.load(f)
#------------------------------offline loading preparation-------------------------------#
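To verify the offline setup end to end, here is a small smoke test (my own addition, assuming the two files are already in place at the paths above): it loads the data and decodes the first training review back into words, mirroring what the official tutorial does.

# Quick smoke test for the two offline functions above.
(train_data, train_labels), (test_data, test_labels) = manual_imdb_load_data(num_words=10000)
print(len(train_data), len(test_data))  # expect: 25000 25000

# Rebuild an index->word mapping. Indices 0/1/2 are reserved for
# padding/start/OOV by load_data, hence the +3 offset (index_from=3).
word_index = manual_imdb_get_word_index()
reverse_word_index = {value + 3: key for key, value in word_index.items()}
reverse_word_index[0] = '<PAD>'
reverse_word_index[1] = '<START>'
reverse_word_index[2] = '<UNK>'

# Decode the first training review back into readable text.
print(' '.join(reverse_word_index.get(i, '?') for i in train_data[0]))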