from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import conv_set as c
import read_data as r
from PIL import Image
import conv_set as c
import numpy as np
import matplotlib.pyplot as plt
# --- Data / model locations -------------------------------------------------
TRAIN_DATA_PATH = './train.tfrecords'
# NOTE(review): TEST_DATA_PATH points at the same file as TRAIN_DATA_PATH —
# confirm whether a separate test TFRecord was intended.
TEST_DATA_PATH = './train.tfrecords'
MODEL_DIR = './model/'

# --- Hyper-parameters -------------------------------------------------------
LEARNING_RATE = 1e-5
NUM_EPOCHS = 10
BATCH_SIZE = 1  # images per training batch
BATCH_SIZE_TEST = 10000  # images per evaluation batch
# Input image: exactly one 28x28 single-channel image per feed.
# NOTE(review): the author's notes say the TFRecord images are stored as uint8,
# but this placeholder is float32, so everything computed from it (x_image)
# evaluates to float32 — convert to uint8 before handing results to PIL mode 'L'.
xs = tf.placeholder(tf.float32, [1, 28, 28, 1])  # input image
ys = tf.placeholder(tf.float32, [None, 10])  # input class labels (10 columns; presumably one-hot — confirm)
with tf.variable_scope('reshape'):
    # Drop the batch and channel dims: result has shape (28, 28).
    x_image = tf.reshape(xs, [28, 28])
# Iterator that reads the training data.
# NOTE(review): the original indentation was lost; this scope is assumed to
# cover only the dataset/iterator construction — confirm.
with tf.variable_scope('input_data'):
    # assumes read_and_decode returns a tf.data.Dataset yielding (image, label) — TODO confirm
    train_dataset = r.read_and_decode(TRAIN_DATA_PATH, BATCH_SIZE, NUM_EPOCHS)  # read data
    train_iterator = train_dataset.make_initializable_iterator()
sess = tf.InteractiveSession()
sess.run(train_iterator.initializer)
# Each sess.run on these tensors pulls the next batch from the iterator.
img_batch, label_batch = train_iterator.get_next()
sess.run(tf.global_variables_initializer())
def method1(img, sess):
    """Reshape one image by feeding it through the pre-built ``x_image`` op.

    Args:
        img: value fed to the ``xs`` placeholder (declared float32 with shape
            [1, 28, 28, 1]).
        sess: session used to evaluate ``x_image``.

    Returns:
        The evaluated ``x_image``: a (28, 28) array. Because ``xs`` is declared
        float32, the result is float32 even when ``img`` holds uint8 pixels.
    """
    feed = {xs: img}
    reshaped = sess.run(x_image, feed_dict=feed)
    return reshaped
def method2(img, sess):
    """Reshape one image to (28, 28) and return it as a NumPy array.

    The original built a new ``tf.reshape`` op on every call, so calling this
    in a loop kept adding nodes to the default graph. NumPy array inputs are
    now reshaped directly in NumPy. This gives the same values and dtype with
    no graph growth. Any other input (e.g. a ``tf.Tensor``) still goes through
    the session, exactly as before.

    Args:
        img: image data with 28*28 elements, e.g. a batch of one shaped
            [1, 28, 28, 1]; a NumPy array or anything ``tf.reshape`` accepts.
        sess: session used only when ``img`` is not a NumPy array.

    Returns:
        A new (28, 28) array with ``img``'s dtype.

    Raises:
        ValueError: for a NumPy input whose size is not 784. The TensorFlow
            path raises TensorFlow's own error instead.
    """
    if isinstance(img, np.ndarray):
        # copy() so the result never aliases the caller's buffer, matching the
        # fresh array that sess.run used to return.
        return img.reshape(28, 28).copy()
    return sess.run(tf.reshape(img, [28, 28]))
# Write the first 100 training images to ./Save_test as JPEG files.
import os  # only this script section needs it

save_dir = './Save_test'
# Image.save() does not create missing directories; without this the first
# save raises FileNotFoundError.
os.makedirs(save_dir, exist_ok=True)

with tf.variable_scope('writer'):
    for i in range(100):
        # Pull the next batch from the dataset iterator.
        img, label = sess.run([img_batch, label_batch])
        x_result = method1(img, sess)
        # method1 evaluates the float32 graph built on the xs placeholder, and
        # handing float32 data to PIL's 8-bit grayscale mode reinterprets the
        # raw float bytes as pixels (garbled image). Clip and cast to uint8
        # first. This assumes pixel values are on a 0-255 scale, as the
        # author's notes state for the stored TFRecords; a no-op for uint8.
        pixels = np.clip(x_result, 0, 255).astype(np.uint8)
        # A 2-D uint8 array maps to PIL mode 'L' (8-bit grayscale) automatically.
        pil_image = Image.fromarray(pixels)
        # NOTE(review): `label` is fetched but not used; the file name contains
        # the fixed text '_Label_' — confirm whether the label should be included.
        pil_image.save(save_dir + '/' + str(i) + '_Label_.jpg')  # save the image
Method1的结果
Method2的结果
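原因：xs 占位符的类型是 tf.float32，所以 method1 返回的是 float32 数组。Image.fromarray(x_result, 'L') 会把 float32 数据的原始字节直接当作 8 位灰度像素读取，因此图像错乱。method2 直接对 uint8 的 img 做 reshape，结果仍是 uint8，所以显示正常。
I = Image.fromarray(x_result, 'L')
改为：
I = Image.fromarray(x_result)
但这样 PIL 会把 float32 数组识别为 'F' 模式（32 位浮点）图像。JPEG 不支持该模式，所以保存为 .jpg 时会报错，只能用 plt 显示出来。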
或者：
xs = tf.placeholder(tf.float32, [1, 28, 28, 1]) # 输入图像
改为：
xs = tf.placeholder(tf.uint8, [1, 28, 28, 1]) # 输入图像（写入 tfrecord 时图像是 uint8）
也可以保留 float32 占位符，在保存前把结果转换为 uint8：Image.fromarray(x_result.astype(np.uint8), 'L')