import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from keras import backend as K
# Batch size passed as `batch_size` to input_fn for training and evaluation.
_BATCH_SIZE = 10

# Small fully connected classifier: the 28x28 input is flattened, goes
# through one 256-unit ReLU layer, and ends in a 10-way softmax.
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
model.add(tf.keras.layers.Dense(256, activation='relu'))
model.add(tf.keras.layers.Dense(10, activation='softmax'))

# Sparse categorical cross-entropy: labels are integer class ids.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'],
)
def _tf_example_parser(record):
    """Decode one serialized tf.Example into an (image, label) pair.

    Args:
        record: scalar string tensor holding a serialized tf.Example with
            an "image/encoded" bytes feature and an "image/class_id"
            int64 feature.

    Returns:
        Tuple ``(image, label)``: ``image`` is a float32 tensor of shape
        (28, 28, 1) holding the raw byte values (not rescaled), ``label``
        is an int64 scalar tensor.
    """
    schema = {
        "image/encoded": tf.FixedLenFeature([], tf.string),
        "image/class_id": tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(record, features=schema)

    # The TFRecords were written from an ndarray read via misc/cv2, so the
    # payload is decoded as raw uint8 bytes rather than as an image format.
    pixels = tf.decode_raw(parsed["image/encoded"], out_type=tf.uint8)

    # A fixed shape is required: images of different sizes cannot be placed
    # in the same batch.
    image = tf.reshape(tf.cast(pixels, dtype=tf.float32), shape=(28, 28, 1))
    label = tf.cast(parsed["image/class_id"], dtype=tf.int64)
    return image, label
def input_fn(data_path, batch_size=64, is_training=True):
    """Build the TFRecord input pipeline and return one batch of tensors.

    Args:
        data_path: a TFRecord file path, or a list/tuple of such paths.
        batch_size: number of examples per batch.
        is_training: when true, examples are shuffled before batching.

    Returns:
        Tuple ``(images, labels)``. ``images`` is float32, rescaled from
        [0, 255] to [-1, 1]. ``labels`` holds the integer class ids
        produced by ``_tf_example_parser`` (not one-hot encoded).
    """
    # tf.name_scope rather than the standalone-Keras backend helper, so the
    # pipeline depends only on TensorFlow.
    with tf.name_scope("input_pipeline"):
        if not isinstance(data_path, (tuple, list)):
            data_path = [data_path]

        dataset = tf.data.TFRecordDataset(data_path)
        dataset = dataset.map(_tf_example_parser)

        if is_training:
            # Shuffle individual examples *before* batching; shuffling after
            # batch() would only reorder whole batches. A larger buffer
            # costs more memory and start-up time; if the TFRecords were
            # written in random order this step could be dropped.
            dataset = dataset.shuffle(1000)

        # No repeat(): the data is consumed once and the input then ends.
        dataset = dataset.batch(batch_size)

        iterator = dataset.make_one_shot_iterator()
        images, labels = iterator.get_next()

        # Labels are deliberately left as integer class ids. The previous
        # tf.one_hot(labels, 1) collapsed every class to a single 0/1 column,
        # which is not a valid target for a sparse categorical loss.

        # Scale pixel values 0..255 -> 0..1 -> -1..1. The cast is a no-op
        # when the parser already yields float32.
        images = tf.cast(images, dtype=tf.float32)
        images = (images / 255. - 0.5) * 2.
        return images, labels
# Goal: get this program running together with the company project.
source_dir = "/Users/chengzheng/GraphRelated/tf_repos/keras_estimator"
if not tf.gfile.Exists(source_dir):
    # IOError is an Exception subclass, so existing handlers still catch it.
    raise IOError("[Train Error]: unable to find input directory!")

# Every file whose name starts with "train" is used for training. The same
# pattern is printed, globbed and reported so the three cannot disagree.
train_pattern = source_dir + "/train*"
print(train_pattern)
train_data_paths = tf.gfile.Glob(train_pattern)
if not train_data_paths:
    raise IOError(
        "[Train Error]: unable to find any file matching " + train_pattern)

# Evaluation and prediction currently reuse the training files; no separate
# validation glob is configured.
val_data_paths = train_data_paths

est_model = tf.keras.estimator.model_to_estimator(keras_model=model)

train_spec = tf.estimator.TrainSpec(
    input_fn=lambda: input_fn(data_path=train_data_paths,
                              batch_size=_BATCH_SIZE,
                              is_training=True),
    max_steps=300000)
eval_spec = tf.estimator.EvalSpec(
    input_fn=lambda: input_fn(val_data_paths,
                              batch_size=_BATCH_SIZE,
                              is_training=False))

# NOTE(review): kept from the original. The estimator presumably builds its
# own graphs, so this reset may be unnecessary - confirm before removing.
tf.reset_default_graph()

# Train and evaluate the model.
tf.estimator.train_and_evaluate(estimator=est_model,
                                train_spec=train_spec,
                                eval_spec=eval_spec)

print("----------------------")
# Predict with is_training=False so the inputs are not shuffled and the
# predictions come out in the same order as the records.
result = est_model.predict(
    input_fn=lambda: input_fn(val_data_paths,
                              batch_size=_BATCH_SIZE,
                              is_training=False))
for r in result:
    print(r)
Tensorflow读写TFRecords文件
TensorFlow 1.4利用Keras+Estimator API进行训练和预测
tensorflow estimator 与 model_fn 是这样沟通的
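ValueError: This model has not yet been built. Build the model first by calling `build()` or calling `fit()` with some data, or specify an `input_shape` argument in the first layer(s) for automatic build.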
TensorFlow Estimator 教程之----快速入门
keras调试的正确打开方式: 一句话让你把tensorflow当pytorch用
将Keras模型转化为Estimator模型
Tensorflow-keras实战(九):Estimator实战