解决Tensorflow中文社区MNIST机器学习入门里使用示例代码无法连接服务器的问题

今天上课开始讲神经网络,看着我的python3.6,caffe不装了,换tensorflow,gpu版折腾几个小时也装不上——放弃,cpu版先用着。

跟着tensorflow的中文文档先学下,但是很不给力啊,示例就一堆事:

http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html

首先这个input_data.py就根本下载不下来,好在百度一下就发现有很多小伙伴备份过了,我也贴一下:

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Functions for downloading and reading MNIST data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os
import tempfile

import numpy
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
我马上试了试:

等了一会儿....根本连不上服务器...被墙啦?我架着梯子呢!可能人品不好,再试一次,还是连不上服务器,这不是坑爹吗!

能不能不从服务器下载啊,毕竟这四个训练集文件是可以手动下载的。

我纠结了一下,自己改代码吧。
先找到主代码的位置:

tensorflow.contrib.learn.python.learn.datasets.mnist.py 里面的:read_data_sets这个函数
果然呐,里面有各种下载的信息,那就把它们删了,直接读文件。

以下是我自己改的,大家有需求就直接复制吧:

四个数据集文件后缀改成gz,直接放在工程根目录里就行。

from tensorflow.contrib.learn.python.learn.datasets.mnist import *


def read_data_setss(fake_data=False,
                    one_hot=False,
                    dtype=dtypes.float32,
                    reshape=True,
                    validation_size=5000,
                    seed=None, ):
    if fake_data:
        def fake():
            return DataSet(
                [], [], fake_data=True, one_hot=one_hot, dtype=dtype, seed=seed)

        train = fake()
        validation = fake()
        test = fake()
        return base.Datasets(train=train, validation=validation, test=test)
    TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
    TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
    TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
    TEST_LABELS = 't10k-labels-idx1-ubyte.gz'

    with gfile.Open(TRAIN_IMAGES, 'rb') as f:
        train_images = extract_images(f)

    with gfile.Open(TRAIN_LABELS, 'rb') as f:
        train_labels = extract_labels(f, one_hot=one_hot)

    with gfile.Open(TEST_IMAGES, 'rb') as f:
        test_images = extract_images(f)

    with gfile.Open(TEST_LABELS, 'rb') as f:
        test_labels = extract_labels(f, one_hot=one_hot)

    if not 0 <= validation_size <= len(train_images):
        raise ValueError('Validation size should be between 0 and {}. Received: {}.'
                         .format(len(train_images), validation_size))

    validation_images = train_images[:validation_size]
    validation_labels = train_labels[:validation_size]
    train_images = train_images[validation_size:]
    train_labels = train_labels[validation_size:]

    options = dict(dtype=dtype, reshape=reshape, seed=seed)

    train = DataSet(train_images, train_labels, **options)
    validation = DataSet(validation_images, validation_labels, **options)
    test = DataSet(test_images, test_labels, **options)

    return base.Datasets(train=train, validation=validation, test=test)


mnists = read_data_setss(one_hot=True)



评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值