Class today moved on to neural networks. Given my Python 3.6 setup, I skipped Caffe and went with TensorFlow instead; after several hours of fiddling the GPU build still wouldn't install, so I gave up and settled for the CPU version for now.
I started with the Chinese TensorFlow documentation, but it is not much help; even the first example is a pile of trouble:
http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html
To begin with, the input_data.py it references simply won't download. Fortunately, a quick search turns up plenty of mirrored copies, so here is mine as well:
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for downloading and reading MNIST data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gzip
import os
import tempfile
import numpy
from six.moves import urllib
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
I tried it right away:
After waiting a while... it could not reach the server at all. Blocked? I was on a proxy the whole time! Maybe just bad luck; I tried again and it still could not connect. Ridiculous.
Can we skip the server download entirely? The four dataset files can, after all, be downloaded by hand.
After some hesitation I decided to patch the code myself.
First, locate the real implementation: the read_data_sets function in tensorflow/contrib/learn/python/learn/datasets/mnist.py.
Sure enough, it is full of download logic. Strip all of that out and read the files directly.
Below is my modified version; copy it if you need it.
Keep the .gz suffix on the four dataset files and put them directly in the project root.
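One failure mode worth ruling out first (my own suggestion, not from the tutorial): a half-finished download or a saved HTML error page sitting under a .gz filename. Every real gzip file starts with the magic bytes 0x1f 0x8b, so a quick pre-flight check is easy. check_mnist_files is a made-up helper name:

```python
import os

# The four gzipped MNIST files, expected in the project root
# (same filenames the modified read_data_setss looks for).
FILES = [
    'train-images-idx3-ubyte.gz',
    'train-labels-idx1-ubyte.gz',
    't10k-images-idx3-ubyte.gz',
    't10k-labels-idx1-ubyte.gz',
]


def check_mnist_files(filenames):
    """Return the names that are missing or not valid gzip files.

    Gzip files always begin with the two magic bytes 0x1f 0x8b, so a
    truncated download or a saved error page shows up immediately.
    """
    bad = []
    for name in filenames:
        if not os.path.exists(name):
            bad.append(name)
            continue
        with open(name, 'rb') as f:
            if f.read(2) != b'\x1f\x8b':
                bad.append(name)
    return bad
```

If check_mnist_files(FILES) returns a non-empty list, re-download those files before going any further.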
from tensorflow.contrib.learn.python.learn.datasets.mnist import *


def read_data_setss(fake_data=False,
                    one_hot=False,
                    dtype=dtypes.float32,
                    reshape=True,
                    validation_size=5000,
                    seed=None):
  if fake_data:

    def fake():
      return DataSet(
          [], [], fake_data=True, one_hot=one_hot, dtype=dtype, seed=seed)

    train = fake()
    validation = fake()
    test = fake()
    return base.Datasets(train=train, validation=validation, test=test)

  # Read the four gzipped MNIST files from the project root
  # instead of downloading them from the server.
  TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
  TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
  TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
  TEST_LABELS = 't10k-labels-idx1-ubyte.gz'

  with gfile.Open(TRAIN_IMAGES, 'rb') as f:
    train_images = extract_images(f)
  with gfile.Open(TRAIN_LABELS, 'rb') as f:
    train_labels = extract_labels(f, one_hot=one_hot)
  with gfile.Open(TEST_IMAGES, 'rb') as f:
    test_images = extract_images(f)
  with gfile.Open(TEST_LABELS, 'rb') as f:
    test_labels = extract_labels(f, one_hot=one_hot)

  if not 0 <= validation_size <= len(train_images):
    raise ValueError('Validation size should be between 0 and {}. Received: {}.'
                     .format(len(train_images), validation_size))

  # Split the validation set off the front of the training data.
  validation_images = train_images[:validation_size]
  validation_labels = train_labels[:validation_size]
  train_images = train_images[validation_size:]
  train_labels = train_labels[validation_size:]

  options = dict(dtype=dtype, reshape=reshape, seed=seed)
  train = DataSet(train_images, train_labels, **options)
  validation = DataSet(validation_images, validation_labels, **options)
  test = DataSet(test_images, test_labels, **options)

  return base.Datasets(train=train, validation=validation, test=test)
mnists = read_data_setss(one_hot=True)
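For anyone who would rather not depend on tf.contrib at all, the same four .gz files can be parsed with nothing but gzip and numpy, since MNIST uses the simple IDX layout: a big-endian magic number (two zero bytes, a type code, and the number of dimensions), one 4-byte size per dimension, then raw uint8 data. This is a minimal sketch of my own, not TensorFlow's API; load_idx and one_hot are made-up helper names:

```python
import gzip
import struct

import numpy as np


def load_idx(path):
    """Parse a gzipped IDX file (the MNIST on-disk format).

    Header: two zero bytes, a 1-byte type code (0x08 for uint8),
    a 1-byte dimension count, then one big-endian 4-byte size per
    dimension; the rest of the file is the raw data.
    """
    with gzip.open(path, 'rb') as f:
        _zeros, _dtype_code, ndim = struct.unpack('>HBB', f.read(4))
        shape = struct.unpack('>' + 'I' * ndim, f.read(4 * ndim))
        return np.frombuffer(f.read(), dtype=np.uint8).reshape(shape)


def one_hot(labels, num_classes=10):
    """Turn a vector of digit labels into one-hot rows (like one_hot=True)."""
    out = np.zeros((len(labels), num_classes), dtype=np.float32)
    out[np.arange(len(labels)), labels] = 1.0
    return out
```

With the files in place, load_idx('train-images-idx3-ubyte.gz') yields an array of shape (60000, 28, 28), and one_hot(load_idx('train-labels-idx1-ubyte.gz')) matches the one_hot=True labels used above.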