mnist手写数字识别python_tensorflow实现mnist手写数字识别报错?

代码:

import tensorflow as tf

import numpy as np

tf.enable_eager_execution()

class DataLoader():

def __init__(self):

mnist = tf.keras.datasets.mnist.load_data(path = 'mnist.npz')

self.train_data = mnist[0][0]

self.train_data = np.reshape(self.train_data,(self.train_data.shape[0],28*28))

self.train_labels = mnist[0][1]

self.eval_data = mnist[1][0]

self.train_data = np.reshape(self.train_data,(self.train_data.shape[0],28*28))

self.eval_labels = mnist[1][1]

def get_batch(self, batch_size):

indexs = np.random.randint(0,self.train_data.shape[0],batch_size)

return self.train_data[indexs, :], self.train_labels[indexs]

'''

class MLP(tf.keras.Modle):

'''

class MLP(tf.keras.Model):

def __init__(self):

super().__init__()

self.dense1 = tf.keras.layers.Dense(units=100, activation= tf.nn.relu)

self.dense2 = tf.keras.layers.Dense(units=10,activation =None)

def call(self, inputs):

x = self.dense1(inputs)

y = self.dense2(x)

return y

def predict(self, inputs):

logits = self(inputs)

return tf.argmax(logits, axis=-1)

num_batches = 1000

batch_size = 50

learning_rate = 0.001

model = MLP()

data_loader = DataLoader()

optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)

for batch_index in range(num_batches):

X , y = data_loader.get_batch(batch_size)

print(np.shape(X))

with tf.GradientTape() as tape:

X = tf.convert_to_tensor(X, dtype = tf.int64, name = 'X')

print(X)

y_logit_pred = model(X)

loss = tf.losses.sparse_softmax_cross_entropy(labels = y, logits = y_logit_pred)

print('batch %d: loss %f' % (batch_index, loss.numpy()))

grads = tape.gradient(loss, model.variables)

optimizer.apply_gradients(grads_and_vars = zip(grads, model.variables))

num_eval_samples = np.shape(data_loader.eval_labels)[0]

y_pred = model.predict(data_loader.eval_data).numpy()

print("test accuracy: %f" % (sum(y_pred == data_loader.eval_labels) / num_eval_samples))

错误信息:

/home/kalarea/.conda/envs/py35/bin/python /home/kalarea/PycharmProjects/start_tensorflow/start_tensorflow.py

/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from float to np.floating is deprecated. In future, it will be treated as np.float64 == np.dtype(float).type.

from ._conv import register_converters as _register_converters

(50, 784)

2018-10-14 18:28:18.977966: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA

tf.Tensor(

[[0 0 0 ... 0 0 0]

[0 0 0 ... 0 0 0]

[0 0 0 ... 0 0 0]

...

[0 0 0 ... 0 0 0]

[0 0 0 ... 0 0 0]

[0 0 0 ... 0 0 0]], shape=(50, 784), dtype=int64)

Traceback (most recent call last):

File "/home/kalarea/PycharmProjects/start_tensorflow/start_tensorflow.py", line 55, in

y_logit_pred = model(X)

File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/keras/engine/base_layer.py", line 769, in call

outputs = self.call(inputs, *args, **kwargs)

File "/home/kalarea/PycharmProjects/start_tensorflow/start_tensorflow.py", line 30, in call

x = self.dense1(inputs)

File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/keras/engine/base_layer.py", line 759, in call

self.build(input_shapes)

File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/keras/layers/core.py", line 921, in build

trainable=True)

File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/keras/engine/base_layer.py", line 586, in add_weight

aggregation=aggregation)

File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/training/checkpointable/base.py", line 591, in _add_variable_with_custom_getter

**kwargs_for_getter)

File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/keras/engine/base_layer.py", line 1986, in make_variable

aggregation=aggregation)

File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/variables.py", line 145, in call

return cls._variable_call(*args, **kwargs)

File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/variables.py", line 141, in _variable_call

aggregation=aggregation)

File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/variables.py", line 120, in

previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)

File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/variable_scope.py", line 2434, in default_variable_creator

import_scope=import_scope)

File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/variables.py", line 147, in call

return super(VariableMetaclass, cls).__call__(*args, **kwargs)

File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/resource_variable_ops.py", line 297, in init

constraint=constraint)

File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/resource_variable_ops.py", line 420, in _init_from_args

initial_value = initial_value()

File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/keras/engine/base_layer.py", line 1970, in

shape, dtype=dtype, partition_info=partition_info)

File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/init_ops.py", line 483, in call

shape, -limit, limit, dtype, seed=self.seed)

File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/random_ops.py", line 240, in random_uniform

shape, minval, maxval, seed=seed1, seed2=seed2, name=name)

File "/home/kalarea/.conda/envs/py35/lib/python3.5/site-packages/tensorflow/python/ops/gen_random_ops.py", line 848, in random_uniform_int

_six.raise_from(_core._status_to_exception(e.code, message), None)

File "", line 3, in raise_from

tensorflow.python.framework.errors_impl.InvalidArgumentError: Need minval < maxval, got 0 >= 0 [Op:RandomUniformInt] name: mlp/dense/kernel/random_uniform/

Process finished with exit code 1

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值