TensorFlow 读取图片并写入tfrecord

图片下载地址:
链接:https://pan.baidu.com/s/1gvvr5ovcYT1pQTy0umpzrA
提取码:bh0t

import os
import tensorflow as tf 
from matplotlib import pyplot as plt
import numpy as np
from sklearn.utils import shuffle
from PIL import Image
%matplotlib inline
from tqdm import tqdm

print(tf.__version__)
print(np.__version__)
# 读取文件夹文件与标签
def load_sample(sample_dir):
    
    print("加载图片数据")
    file_name_list = []
    labels_names = []
    
    for(dir_path, dir_names, file_names) in os.walk(sample_dir):
        
        for file_name in file_names:
            file_path = os.path.join(dir_path, file_name)
            # 获取图片路径与文件夹名字(标签)
            file_name_list.append(file_path)    
            labels_names.append(dir_path.split("\\")[-1])
    
    lab = list(sorted(set(labels_names)))
    
    labdict = dict(zip(lab, list(range(len(lab)))))
    labels = [labdict[i] for i in labels_names]
    
    return (np.asarray(file_name_list), np.asarray(labels)), np.asarray(lab)
# 读取文件名与标签 
data_dir = 'man_woman\\'  # 定义文件路径

(images, labels), labelsnames = load_sample(data_dir)  # 载入文件名称与标签
print(len(images), images)  # 文件名 
print(len(labels), labels)  # 标签              
print(labelsnames)  # 标签字符串
def makeTFRec(filenames, labels):  # 定义函数生成TFRecord
    output_dir = "tfrecord_dir"
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    filename = "mydata.tfrecords"
    filename_fullpath = os.path.join(output_dir, filename)
    
    with tf.io.TFRecordWriter(filename_fullpath) as writer:
        for i in tqdm(range(0, len(labels))):
            img = Image.open(filenames[i])
            img = img.resize((256, 256))
            img_raw = img.tobytes()  # 将图片转化为二进制格式
            
            features = tf.train.Features(feature = {
                "label":tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[labels[i]])),
                "img_raw":tf.train.Feature(
                    bytes_list = tf.train.BytesList(value=[img_raw]))
            })
            example = tf.train.Example(features=features)  # example对象对label和image数据进行封装

            writer.write(example.SerializeToString())  # 序列化为字符串
makeTFRec(images, labels)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

廷益--飞鸟

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值