Keras 训练使用 TFRecord。

1. 写入 TFRecord


import tensorflow as tf
import os
import numpy as np

import PIL.Image as Image

def _get_path_label(image_dir):
    """Collect image paths and integer class labels from a two-level tree.

    The expected layout is ``image_dir/<group>/<identity>/<image file>``.
    Every ``<group>/<identity>`` directory is one class.  Classes are sorted
    by their relative path and numbered ``0 .. cat_num - 1`` in that order;
    the (path, label) pairs are then shuffled together with NumPy's global
    random state.

    Args:
        image_dir: Root directory of the dataset; a leading ``~`` is expanded.

    Returns:
        A tuple ``(shuffle_paths, shuffle_labels, cat_num)`` where
        ``shuffle_paths`` is a list of file paths, ``shuffle_labels`` is the
        list of matching integer class ids, and ``cat_num`` is the number of
        classes found.
    """
    image_dir = os.path.expanduser(image_dir)

    # Build "<group>/<identity>" with os.path.join so the code also works on
    # platforms whose separator is not a backslash.  Entries that are not
    # directories (stray files next to the class folders) are skipped rather
    # than crashing os.listdir further down.
    class_dirs = []
    for group in os.listdir(image_dir):
        group_dir = os.path.join(image_dir, group)
        if not os.path.isdir(group_dir):
            continue
        for identity in os.listdir(group_dir):
            if os.path.isdir(os.path.join(group_dir, identity)):
                class_dirs.append(os.path.join(group, identity))

    # Sorting makes the class-id assignment independent of listdir order.
    class_dirs.sort()
    cat_num = len(class_dirs)

    paths = []
    labels = []
    for label, class_dir in enumerate(class_dirs):
        cur_dir = os.path.join(image_dir, class_dir)
        fns = os.listdir(cur_dir)
        paths.extend(os.path.join(cur_dir, fn) for fn in fns)
        labels.extend([label] * len(fns))

    # One permutation applied to both lists keeps paths and labels aligned.
    perm = np.random.permutation(np.arange(len(paths)))
    shuffle_paths = [paths[i] for i in perm]
    shuffle_labels = [labels[i] for i in perm]
    return shuffle_paths, shuffle_labels, cat_num


def _byteslist(value):
    """Wrap a single bytes value in a ``tf.train.Feature`` holding a BytesList."""
    wrapped = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=wrapped)

def _int64list(value):
    """Wrap a single integer in a ``tf.train.Feature`` holding an Int64List."""
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)

def creat_train_record(train_dir , train_record_path):
    """Serialize every image under ``train_dir`` into one TFRecord file.

    Each record is a ``tf.train.Example`` with two features:
    ``'label'`` (int64 class id) and ``'img_raw'`` (the bytes returned by
    ``PIL.Image.tobytes()`` for the opened image).

    Args:
        train_dir: Dataset root, passed to ``_get_path_label``.
        train_record_path: Destination path of the TFRecord file.
    """
    # Scan the dataset first so a bad input directory fails before an
    # output file is created.
    shuffle_paths, shuffle_labels, cat_num = _get_path_label(train_dir)
    total = len(shuffle_labels)

    # The context manager closes (and flushes) the writer even if an image
    # fails to load part-way through.
    with tf.io.TFRecordWriter(train_record_path) as writer:
        for index, (image_name, label) in enumerate(zip(shuffle_paths, shuffle_labels)):
            # NOTE(review): pixels are stored as-is, with no size or mode
            # metadata.  The companion reader reshapes the decoded bytes to a
            # fixed image shape (128x128x3), so inputs presumably must already
            # be 128x128 RGB -- confirm against the dataset.
            with Image.open(image_name) as img:  # releases the file handle
                img_raw = img.tobytes()
            example = tf.train.Example(
                features=tf.train.Features(feature={
                    'label': _int64list(label),
                    'img_raw': _byteslist(img_raw)}))
            writer.write(example.SerializeToString())
            if index % 1000 == 0:
                print("current i is:", index, " all data is:", total)
    print('creat_train_record success !')

    print("cat_num:", cat_num)

if __name__ == "__main__":
    # Build the training TFRecord only when run as a script, so importing
    # this module does not trigger a full dataset conversion.
    creat_train_record(r'E:\dataset\face\train\temp', r'E:\dataset\face\traindata.tfrecords')

2. 读取 TFRecord

import tensorflow as tf
import os
from PIL import Image
import matplotlib.pyplot as plt


class ReadTfRecord:
    """Batched reader for TFRecord files of ``label`` / ``img_raw`` examples.

    After construction, ``self.batch`` is a ``tf.data.Dataset`` that yields
    dicts of batched ``'label'`` (int64) and ``'img_raw'`` (string) tensors.
    Pass each dict to ``get_data`` to obtain decoded image tensors.

    Args:
        filename: Path (or list of paths) of the TFRecord file(s) to read.
        batch_size: Number of examples per batch.
        cat_num: Number of classes, stored on the instance for callers.
            Defaults to 1000.
        image_shape: ``(height, width, channels)`` of the raw uint8 images
            stored in ``img_raw``.  Defaults to ``(128, 128, 3)``.
    """

    def __init__(self, filename, batch_size, cat_num=1000, image_shape=(128, 128, 3)):
        self.filename = filename
        self.cat_num = cat_num
        self.image_shape = tuple(image_shape)
        # Schema of one serialized example: a scalar label and one raw
        # image byte string.
        self.features = {
            'label': tf.io.FixedLenFeature([], tf.int64),
            'img_raw': tf.io.FixedLenFeature([], tf.string)}

        dataset = tf.data.TFRecordDataset(filename)
        dataset = dataset.repeat(1)  # a single pass over the data per iteration
        dataset = dataset.map(self._parse_function)  # parse serialized examples
        # Group every `batch_size` examples into one batch.
        self.batch = dataset.batch(batch_size=batch_size)

    def _parse_function(self, exam_proto):
        """Parse one serialized ``tf.train.Example`` using ``self.features``."""
        return tf.io.parse_single_example(exam_proto, self.features)

    def get_data(self, item):
        """Decode one batch dict into ``(images, labels)``.

        Args:
            item: One element of ``self.batch``: a dict with ``'label'`` and
                ``'img_raw'`` entries.

        Returns:
            ``(img, labels)`` where ``img`` is a float32 tensor reshaped to
            ``(-1,) + image_shape`` and scaled by ``1 / 255``, and ``labels``
            is ``item['label']`` unchanged.
        """
        labels = item['label']
        raw_images = item['img_raw']

        img = tf.io.decode_raw(raw_images, tf.uint8)
        img = tf.reshape(img, (-1,) + self.image_shape)
        img = tf.cast(img, tf.float32) / 255.0
        return img, labels




# Example usage (left commented out): iterate over the batched dataset 5 times.
# TFRecord = ReadTfRecord(r'E:\dataset\face_recognition_train\data.tfrecords', 96)
#
# for i in range(5):
#     for item in TFRecord.batch:
#         img, label = TFRecord.get_data(item)
#         print(img.shape, " ", label.shape)

 

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

NineDays66

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值