tensorflow 2.0-根据图片制作数据集

33 篇文章 0 订阅
	#coding=utf-8
	
	import tensorflow as tf
	import numpy as np
	import matplotlib.pyplot as plt
	from PIL import Image
	import datetime
	import os
	
	
	root=os.path.abspath(os.path.dirname(__file__))
	os.chdir(root)
	
	
	image_train_path="train"
	#label_train_path="train.txt"
	tfRecord_train="data/train.tfrecords"
	
	image_test_path="test"
	#label_test_path="test.txt"
	tfRecord_test="data/test.tfrecords"
	
	resize_height=32
	resize_width=32
	data_path="data"
	
	
	def write_record(tf_record_name,image_path,label_data):
	    writer = tf.io.TFRecordWriter(tf_record_name)
	    pic_id=0
	 
	    for con in label_data:
	        img_path=image_path+"/"+con
	        img= Image.open(img_path)
	        img_str=img.tobytes()
	        label=label_data[con]
	
	        feature={
	            'img_str':tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_str])),
	            'label':tf.train.Feature(bytes_list=tf.train.BytesList(value=[label.encode()]))
	            }
	
	        example=tf.train.Example(features=tf.train.Features(feature=feature))
	        writer.write(example.SerializeToString())
	        pic_id+=1
	        print ("the pic is:",pic_id)
	    writer.close()
	    print ("The data is writed successful!")
	            
	        
	
	
	def get_label_data(img_path):
	    if not os.path.exists(img_path):
	        print ("input data error!")
	        return
	    root =  os.listdir(img_path)
	    list_data={}
	    for r in root :
	        list_data[r]=r[0].upper()
	
	    return list_data
	        
	    
	def generate_tfRecord():
	
	    if not os.path.exists(data_path):
	        os.makedirs(data_path)
	        print ("root create successfully!")
	    else:
	        print ("root already exists!")
	
	    train_label_data=get_label_data(image_train_path)
	    #print (train_label_data)
	    write_record(tfRecord_train,image_train_path,train_label_data)
	    
	    test_label_data=get_label_data(image_test_path)
	    write_record(tfRecord_test,image_test_path,test_label_data)
	
	    
	
	
	
	
	
	def main():
	    generate_tfRecord()
	
	
	if __name__ == "__main__":
	    main()

result:
root already exists!
the pic is: 1
the pic is: 2
the pic is: 3
the pic is: 4
the pic is: 5
the pic is: 6
the pic is: 7
the pic is: 8
the pic is: 9
the pic is: 10
the pic is: 11
the pic is: 12
the pic is: 13
the pic is: 14
the pic is: 15
the pic is: 16
the pic is: 17
the pic is: 18
the pic is: 19
the pic is: 20
the pic is: 21
the pic is: 22
the pic is: 23
the pic is: 24
the pic is: 25
the pic is: 26
the pic is: 27
the pic is: 28
the pic is: 29
the pic is: 30
The data is writed successful!
the pic is: 1
the pic is: 2
the pic is: 3
the pic is: 4
the pic is: 5
the pic is: 6
The data is writed successful!

project root
project root
train data
在这里插入图片描述

test data

在这里插入图片描述
output file
在这里插入图片描述

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

佐倉

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值