这里写自定义目录标题
# TFRecord的结构和原理
我看到一篇文章写得很好,链接:https://xie.infoq.cn/article/70a763d42850678f8acfa38e9
import os
import glob
from datetime import datetime
import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import tensorflow as tf
# Directory that is searched for the source JPEG images.
image_path = './img/'
# glob returns matches in arbitrary, filesystem-dependent order; sorting keeps
# the printed listing (and anything built from it) reproducible between runs.
images = sorted(glob.glob(os.path.join(image_path, '*.jpg')))
print(images)
# Mapping from a name to an integer label id.
# NOTE(review): not referenced by the rest of the visible code -- confirm
# whether labels are meant to come from here or from the file names.
image_labels = {
    'dog': 0,
    'kangaroo': 1,
}
#制作tfrecord
def _bytes_feature(value):
    """Return a ``tf.train.Feature`` holding ``value`` in a ``BytesList``.

    Args:
        value: A byte string, a text string (encoded as UTF-8 before being
            stored), or a tensor of the type produced by ``tf.constant``
            (unwrapped through ``.numpy()``).

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` contains the single value.
    """
    if isinstance(value, type(tf.constant(0))):
        # BytesList won't unpack a string from an EagerTensor.
        value = value.numpy()
    if isinstance(value, str) and not isinstance(value, bytes):
        # BytesList only accepts byte strings, so text has to be encoded
        # first. The extra ``bytes`` check keeps this a no-op on interpreters
        # where ``str`` already is a byte string.
        value = value.encode('utf-8')
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _float_feature(value):
    """Wrap a single float / double in a ``tf.train.Feature``.

    Args:
        value: The number to store.

    Returns:
        A ``tf.train.Feature`` whose ``float_list`` contains only ``value``.
    """
    wrapped = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=wrapped)
def _int64_feature(value):
    """Wrap a single bool / enum / int / uint in a ``tf.train.Feature``.

    Args:
        value: The integer-like value to store.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` contains only ``value``.
    """
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)
def image_example(image_string, label):
    """Build a ``tf.train.Example`` describing one encoded JPEG image.

    Args:
        image_string: Encoded JPEG bytes. They are decoded here only to read
            the image shape; the encoded bytes themselves are what is stored.
        label: Integer label stored next to the image.

    Returns:
        A ``tf.train.Example`` with the features ``height``, ``width``,
        ``depth``, ``label`` and ``iamge_raw``.
    """
    decoded_shape = tf.image.decode_jpeg(image_string).shape
    # NOTE: the 'iamge_raw' spelling is kept as-is because the parsing schema
    # elsewhere in this file looks the feature up under the same key.
    features = tf.train.Features(feature={
        'height': _int64_feature(decoded_shape[0]),
        'width': _int64_feature(decoded_shape[1]),
        'depth': _int64_feature(decoded_shape[2]),
        'label': _int64_feature(label),
        'iamge_raw': _bytes_feature(image_string),
    })
    return tf.train.Example(features=features)
# Output file that receives one serialized tf.train.Example per image.
record_file = 'image_tfrecord'
# Number of images written so far; stays 0 when there are no images.
counter = 0
with tf.io.TFRecordWriter(record_file) as writer:
    for counter, fname in enumerate(images, 1):
        # Read the encoded JPEG bytes and close the file before serializing.
        # The bytes are not decoded here: image_example() already decodes
        # them once to read the shape.
        with open(fname, 'rb') as f:
            image_bytes = f.read()
        # The label is taken from the file name: its stem must be an integer.
        stem = os.path.splitext(os.path.basename(fname))[0]
        try:
            label = int(stem)
        except ValueError:
            raise ValueError(
                'cannot derive an integer label from file name {!r}'.format(fname))
        tf_example = image_example(image_string=image_bytes, label=label)
        writer.write(tf_example.SerializeToString())
        print('process {:d} of {:d} images.'.format(counter, len(images)))
# Load the TFRecord file created above.
dataset = tf.data.TFRecordDataset('image_tfrecord')
print(dataset)
# Every record is a serialized Example string, so it still has to be parsed.
# The parsing schema must mirror the features that were written into each
# Example: four scalar int64 features plus the encoded image bytes.
image_feature_description = {
    key: tf.io.FixedLenFeature([], tf.int64)
    for key in ('height', 'width', 'depth', 'label')
}
image_feature_description['iamge_raw'] = tf.io.FixedLenFeature([], tf.string)
def parse_tf_example(example_proto, image_size=(416, 416)):
    """Parse one serialized Example into a preprocessed image and its label.

    Args:
        example_proto: A serialized ``tf.train.Example``, parsed with
            ``image_feature_description``.
        image_size: ``(height, width)`` the decoded image is resized to.
            Defaults to ``(416, 416)``.

    Returns:
        A tuple ``(x_train, y_train)``: the image decoded with 3 channels,
        resized to ``image_size`` and divided by 255, and the ``label``
        feature of the record.
    """
    # Parse the record according to the schema.
    parsed = tf.io.parse_single_example(example_proto, image_feature_description)
    # Preprocessing: decode, resize, then scale the pixel values by 1/255.
    x_train = tf.image.decode_jpeg(parsed['iamge_raw'], channels=3)
    x_train = tf.image.resize(x_train, image_size)
    x_train = x_train / 255
    y_train = parsed['label']
    return x_train, y_train
# Map the parser over every serialized record of the dataset.
train_dataset = dataset.map(parse_tf_example)
print(train_dataset)
# Display the dataset: each printed item is an (index, element) tuple.
for index, sample in enumerate(train_dataset):
    print((index, sample))