1.加载数据MNIST_data,按照tensorflow官网的:
import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
总是报错。查了一下发现,安装tensorflow之后,input_data.py这个文件位于tensorflow包的examples/tutorials/mnist目录下,因此应按如下方式import:
from tensorflow.examples.tutorials.mnist import input_data
由于在线下载mnist总是显示下载超时,所以建议在http://yann.lecun.com/exdb/mnist/上直接下载数据集(包括训练集和测试集),格式为gz:
train-images-idx3-ubyte.gz: training set images (9912422 bytes)
train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)
然后查看input_data.py中的源代码:
# CVDF mirror of http://yann.lecun.com/exdb/mnist/
DEFAULT_SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'


def read_data_sets(train_dir,
                   fake_data=False,
                   one_hot=False,
                   dtype=dtypes.float32,
                   reshape=True,
                   validation_size=5000,
                   seed=None,
                   source_url=DEFAULT_SOURCE_URL):
    """Assemble the MNIST train / validation / test ``DataSet`` triple.

    Args:
        train_dir: Directory handed to ``base.maybe_download`` for each of
            the four ``.gz`` archives.
        fake_data: When true, nothing is read from disk; three ``DataSet``
            objects built with ``fake_data=True`` are returned instead.
        one_hot: Forwarded to ``extract_labels`` (and to the fake datasets).
        dtype: Forwarded to every ``DataSet``.
        reshape: Forwarded to every real ``DataSet``.
        validation_size: Number of leading training examples split off as
            the validation set.
        seed: Forwarded to every ``DataSet``.
        source_url: URL prefix prepended to each archive name. Any falsy
            value (such as an empty string) selects ``DEFAULT_SOURCE_URL``.

    Returns:
        ``base.Datasets(train=..., validation=..., test=...)``.

    Raises:
        ValueError: If ``validation_size`` is not within
            ``[0, len(train_images)]``.
    """
    if fake_data:
        # Three independent placeholder datasets, created in
        # train / validation / test order.
        stand_ins = [
            DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype,
                    seed=seed)
            for _ in range(3)
        ]
        return base.Datasets(
            train=stand_ins[0], validation=stand_ins[1], test=stand_ins[2])

    # Empty-string (or any other falsy) URL falls back to the default mirror.
    url_prefix = source_url if source_url else DEFAULT_SOURCE_URL

    def _fetch_and_parse(archive_name, parse):
        """Locate (or download) one archive and parse its opened stream."""
        archive_path = base.maybe_download(
            archive_name, train_dir, url_prefix + archive_name)
        with gfile.Open(archive_path, 'rb') as stream:
            return parse(stream)

    def _parse_labels(stream):
        return extract_labels(stream, one_hot=one_hot)

    # Order matters only for reproducing the original sequence of downloads.
    all_train_images = _fetch_and_parse(
        'train-images-idx3-ubyte.gz', extract_images)
    all_train_labels = _fetch_and_parse(
        'train-labels-idx1-ubyte.gz', _parse_labels)
    test_images = _fetch_and_parse(
        't10k-images-idx3-ubyte.gz', extract_images)
    test_labels = _fetch_and_parse(
        't10k-labels-idx1-ubyte.gz', _parse_labels)

    available = len(all_train_images)
    if not 0 <= validation_size <= available:
        raise ValueError(
            'Validation size should be between 0 and {}. Received: {}.'
            .format(available, validation_size))

    # The first `validation_size` examples form the validation split; the
    # remainder stays in the training split.
    validation = DataSet(
        all_train_images[:validation_size],
        all_train_labels[:validation_size],
        dtype=dtype, reshape=reshape, seed=seed)
    train = DataSet(
        all_train_images[validation_size:],
        all_train_labels[validation_size:],
        dtype=dtype, reshape=reshape, seed=seed)
    test = DataSet(
        test_images, test_labels, dtype=dtype, reshape=reshape, seed=seed)

    return base.Datasets(train=train, validation=validation, test=test)
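默认的下载地址DEFAULT_SOURCE_URL='https://storage.googleapis.com/cvdf-datasets/mnist/'总是打不开,因此这里传入source_url=False,希望直接从本地加载mnist data。需要注意:从上面的源代码可以看到,source_url为False(或空字符串)时会被重新设为DEFAULT_SOURCE_URL,所以这个参数本身并不能真正关闭下载;要想不联网直接加载,代码中MNIST_data/文件夹里必须已经有下载好的四个gz格式数据文件(训练集和测试集),这样maybe_download应该就不会再去下载: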
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True,source_url = False)
这样就加载完成了。
2.下面是简单模型softmax regression建立的源代码:
# coding: utf-8
# Softmax-regression classifier for MNIST using the TensorFlow 1.x graph API.

# In[14]:
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

# In[6]:
# source_url=False is falsy, so read_data_sets falls back to its default URL.
# NOTE(review): presumably no download is attempted only when the four .gz
# archives already exist in MNIST_data/ -- confirm against maybe_download.
mnist = input_data.read_data_sets(
    "MNIST_data/", one_hot=True, source_url=False)

# In[23]:
import numpy

print(mnist.train.images.shape)

# In[26]:
# Input images: float placeholder, each image flattened to a 784-dim vector.
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))  # weights
b = tf.Variable(tf.zeros([10]))  # biases
logits = tf.matmul(x, W) + b
y = tf.nn.softmax(logits)  # predicted class probabilities
y_ = tf.placeholder(tf.float32, [None, 10])  # ground-truth one-hot labels

# Cross-entropy summed over the batch. Computing it from the logits avoids
# the NaN that -sum(y_ * log(softmax)) produces once a probability
# underflows to exactly 0.
cross_entropy = tf.reduce_sum(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))

# Minimise the cost (cross-entropy) with plain gradient descent.
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

# tf.initialize_all_variables() is deprecated in favour of this initializer.
init = tf.global_variables_initializer()

# The context manager guarantees the session's resources are released.
with tf.Session() as sess:
    sess.run(init)  # initialise the variables inside the session

    # Train the model for 1000 mini-batch steps of 100 examples each.
    for _ in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

    # In[29]:
    # A prediction is correct when the index of the largest predicted
    # probability matches the index of the 1 in the one-hot label.
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print(sess.run(
        accuracy,
        feed_dict={x: mnist.test.images, y_: mnist.test.labels}))