TFRecord + Keras + Estimator

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from tensorflow.compat.v1.keras import backend as K   # use tf.keras's backend; mixing the standalone keras package with tf.keras leads to graph/session mismatches

_BATCH_SIZE = 10

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(256, activation='relu'),
  tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              optimizer='adam', metrics=['accuracy'])


def _tf_example_parser(record):
    feature = {"image/encoded": tf.FixedLenFeature([], tf.string),
               "image/class_id": tf.FixedLenFeature([], tf.int64)}
    features = tf.parse_single_example(record, features=feature)
    image = tf.decode_raw(features["image/encoded"], out_type=tf.uint8)   # the tfrecords were written from the raw ndarray read with misc/cv2, so decode the raw bytes here
    image = tf.cast(image, dtype=tf.float32)
    image = tf.reshape(image, shape=(28, 28, 1))    # without resizing to a common size, images of different sizes cannot go into the same batch
    label = tf.cast(features["image/class_id"], dtype=tf.int64)
    return image, label
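
The parser above assumes each record stores the raw pixel bytes of a 28x28 uint8 image under "image/encoded" and an integer class id under "image/class_id". For reference, a minimal writer sketch that produces records in that layout (write_mnist_tfrecord and its arguments are hypothetical, not part of the original code):

import numpy as np

def write_mnist_tfrecord(images, labels, out_path):
    # images: ndarray of shape (N, 28, 28), dtype uint8; labels: ndarray of shape (N,) with integer class ids
    with tf.python_io.TFRecordWriter(out_path) as writer:
        for img, lbl in zip(images, labels):
            example = tf.train.Example(features=tf.train.Features(feature={
                "image/encoded": tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[img.astype(np.uint8).tobytes()])),
                "image/class_id": tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[int(lbl)])),
            }))
            writer.write(example.SerializeToString())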

def input_fn(data_path, batch_size=64, is_training=True):
    with K.name_scope("input_pipeline"):
        if not isinstance(data_path, (tuple, list)):
            data_path = [data_path]
        dataset = tf.data.TFRecordDataset(data_path)
        dataset = dataset.map(_tf_example_parser)
        # dataset = dataset.repeat(25)            # number of epochs
        if is_training:
            dataset = dataset.shuffle(1000)       # shuffle examples before batching; a larger buffer_size costs more memory and time, so writing the tfrecords in random order lets you skip this step
        dataset = dataset.batch(batch_size)       # batch size
        iterator = dataset.make_one_shot_iterator()
        images, labels = iterator.get_next()
        # SparseCategoricalCrossentropy expects integer class ids, so the labels are left as-is (no one-hot conversion)
        # preprocess image: the parser already cast to float32, so scale the 0-255 pixel values to the range -1..1 manually
        images /= 255.
        images -= 0.5
        images *= 2.
        # return dict({"input_1": images}), labels
        return images, labels
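
To debug the input pipeline in isolation, it can be run once through a plain TF1 session before handing it to the estimator; a sketch, where "/path/to/train.tfrecord" is a placeholder path:

imgs, lbls = input_fn("/path/to/train.tfrecord", batch_size=_BATCH_SIZE, is_training=False)
with tf.Session() as sess:
    img_batch, lbl_batch = sess.run([imgs, lbls])
    print(img_batch.shape, img_batch.dtype)   # expected: (10, 28, 28, 1) float32
    print(lbl_batch.shape, lbl_batch.dtype)   # expected: (10,) int64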

## Getting this program to run end to end on a company project

source_dir = "/Users/chengzheng/GraphRelated/tf_repos/keras_estimator"
if tf.gfile.Exists(source_dir):
    print(source_dir+"/train*.tfrecord")
    train_data_paths = tf.gfile.Glob(source_dir+"/train*")  # every tfrecord whose name starts with "train" is used for training
    # val_data_paths = tf.gfile.Glob(source_dir+"/valid*")  # every tfrecord whose name starts with "valid" is used for evaluation
    val_data_paths = train_data_paths
    if not len(train_data_paths):
        raise Exception("[Train Error]: unable to find train*.tfrecord file")
    # if not len(val_data_paths):
    #     raise Exception("[Eval Error]: unable to find val*.tfrecord file")
else:
    raise Exception("[Train Error]: unable to find input directory!")
est_model = tf.keras.estimator.model_to_estimator(keras_model=model)
train_spec = tf.estimator.TrainSpec(input_fn=lambda: input_fn(data_path=train_data_paths,
                                                              batch_size=_BATCH_SIZE,
                                                              is_training=True),
                                    max_steps=300000)

eval_spec = tf.estimator.EvalSpec(input_fn=lambda: input_fn(val_data_paths,
                                                            batch_size=_BATCH_SIZE,
                                                            is_training=False))
# tf.reset_default_graph()   # not needed: the estimator builds its own graph, and resetting here can detach the Keras model created above from its backend session
# train and evaluate model
tf.estimator.train_and_evaluate(estimator=est_model,
                                train_spec=train_spec,
                                eval_spec=eval_spec)
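
An optional one-off evaluation pass over the validation files reports the loss and accuracy compiled into the Keras model (a sketch, not in the original code; the exact metric keys vary across TF versions):

metrics = est_model.evaluate(input_fn=lambda: input_fn(val_data_paths,
                                                       batch_size=_BATCH_SIZE,
                                                       is_training=False))
print(metrics)   # e.g. {'loss': ..., 'accuracy': ..., 'global_step': ...}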
print("----------------------")
result = est_model.predict(input_fn=lambda: input_fn(val_data_paths,
                                                     batch_size=_BATCH_SIZE,
                                                     is_training=False))
for r in result:
    print(r)    # each r is a dict keyed by the model's output layer name, holding the predicted class probabilities
