直接上代码,然后底下补充注意事项。亲测可用
#coding:utf-8
import tensorflow as tf
import os
import os.path
# Restrict CUDA to the GPU with index 1 so TensorFlow only sees that device.
# NOTE(review): on a machine with a single GPU (index 0) this presumably hides
# the only GPU -- remove the line or set the value to "0" there.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
def _int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature`` holding an Int64List.

    Args:
        value: the integer to store (for example a class label id).

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` contains exactly ``value``.
    """
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
    """Wrap a single byte string in a ``tf.train.Feature`` holding a BytesList.

    Args:
        value: the byte string to store (for example raw image pixels).

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` contains exactly ``value``.
    """
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
# Directory that is walked recursively for the source images.
rootdir = "G:\\ZhangSG\\TFRecords\\indoor_scene"
# Path of the TFRecords file that the script writes.
TFRfilename = "G:\\ZhangSG\\TFRecords\\indoor_scene.tfrecords"

# Category names; the position of a name in this list is its integer label id.
defined_label = [
    'airport_inside',
    'artstudio',
    'auditorium',
    'bakery',
    'bar',
    'bathroom',
    'bedroom',
    'bookstore',
    'bowling',
    'buffet',
    'casino',
    'children_room',
    'church_inside',
    'classroom',
    'cloister',
    'closet',
    'clothingstore',
    'computerroom',
    'concert_hall',
    'corridor',
    'deli',
    'dentaloffice',
    'dining_room',
    'elevator',
    'fastfood_restaurant',
    'florist',
    'gameroom',
    'garage',
    'greenhouse',
    'grocerystore',
    'gym',
    'hairsalon',
    'hospitalroom',
    'inside_bus',
    'inside_subway',
    'jewelleryshop',
    'kindergarden',
    'kitchen',
    'laboratorywet',
    'laundromat',
    'library',
    'livingroom',
    'lobby',
    'locker_room',
    'mall',
    'meeting_room',
    'movietheater',
    'museum',
    'nursery',
    'office',
    'operating_room',
    'pantry',
    'poolinside',
    'prisoncell',
    'restaurant',
    'restaurant_kitchen',
    'shoeshop',
    'stairscase',
    'studiomusic',
    'subway',
    'toystore',
    'trainstation',
    'tv_studio',
    'videostore',
    'waitingroom',
    'warehouse',
    'winecellar',
]
def convert_filename_to_labelID(filename, defined_label):
    """Return the label id (index into ``defined_label``) for ``filename``.

    A label matches when its name occurs as a substring of ``filename``.
    When several labels match, the longest one wins, so a file named
    ``restaurant_kitchen (1).jpg`` maps to ``restaurant_kitchen`` rather than
    to the shorter ``kitchen`` or ``restaurant`` that appear inside it.
    Labels of equal length resolve to the one listed first.

    Args:
        filename: name of the image file (matching is case-sensitive).
        defined_label: sequence of label names; the index is the label id.

    Returns:
        The id in ``0 .. len(defined_label) - 1``, or -1 if no label matches.
    """
    labelid = -1
    longest = -1
    for i, label in enumerate(defined_label):
        # Strict '>' keeps the earliest index when two matches tie on length.
        if len(label) > longest and label in filename:
            labelid = i
            longest = len(label)
    return labelid
# Build the decode/resize graph once and feed each file through a placeholder.
# Creating these ops inside the loop would add new nodes to the graph for
# every image, so memory use and run time would grow with the file count.
jpeg_bytes_input = tf.placeholder(tf.string, shape=[])
# channels=3 forces RGB so every record has the same 200*200*3 byte length,
# even when the source JPEG is grayscale.
decoded_image = tf.image.decode_jpeg(jpeg_bytes_input, channels=3)
float_image = tf.image.convert_image_dtype(decoded_image, dtype=tf.float32)
# The size argument is [new_height, new_width].
resized_image = tf.image.resize_images(float_image, [200, 200])
# float_image is scaled to [0, 1]; a plain tf.cast to uint8 would truncate
# every pixel to 0 or 1. convert_image_dtype rescales back to [0, 255].
uint8_image = tf.image.convert_image_dtype(
    resized_image, dtype=tf.uint8, saturate=True)

count = 0
# Both context managers guarantee the record file and the session are closed
# even if a file fails to decode part-way through the walk.
with tf.python_io.TFRecordWriter(TFRfilename) as writer, tf.Session() as sess:
    # os.walk yields (parent directory, sub-directory names, file names).
    for parent, dirnames, filenames in os.walk(rootdir):
        for filename in filenames:
            # Only JPEG files are handled; compare the extension, not a
            # substring, so names such as "jpg_notes.txt" are skipped and
            # ".JPG" / ".jpeg" are accepted.
            if not filename.lower().endswith((".jpg", ".jpeg")):
                continue
            labelID = convert_filename_to_labelID(filename, defined_label)
            if labelID < 0:
                # No known label in the file name: leave the file out.
                continue
            image_dir = os.path.join(parent, filename)
            # Binary mode is required: JPEG data is not text.
            with tf.gfile.FastGFile(image_dir, 'rb') as image_file:
                image_raw_data_jpg = image_file.read()
            image_raw_data = sess.run(
                uint8_image,
                feed_dict={jpeg_bytes_input: image_raw_data_jpg}).tobytes()
            if len(image_raw_data) == 0:
                continue
            # Wrap label and pixels into a trainable Example record.
            example = tf.train.Example(features=tf.train.Features(feature={
                'label': _int64_feature(labelID),
                'image_raw': _bytes_feature(image_raw_data),
            }))
            # Write first, then count and report, so the message is only
            # printed for records that really reached the file.
            writer.write(example.SerializeToString())
            count = count + 1
            print("文件"+filename+"生成成功,已生成%d个文件"%count)
print("TFRecord文件已保存。共%d个文件"%count)
如果想用这段代码,需要改动几个地方:
1. os.environ["CUDA_VISIBLE_DEVICES"] = "1" 表示只让程序使用编号为1的GPU(即第二块)。如果只有一个GPU,这句话要删掉(或者把"1"改成"0"),否则程序找不到可用的GPU。
2. rootdir是存放训练集(或者测试集)图片的地方,也就是待生成tfrecords文件的那些图片所在的目录。TFRfilename是生成的tfrecords文件所在的路径和文件名。注意:这两个路径我都写成了绝对路径,可以按照自己的需求改动。
3. 保证所有图片的文件名都含有标签名。如果不含有,简单的方法是(Windows下)全选该分类下的所有文件,右键选择"重命名",直接输入分类名,如kitchen,可以看到全部文件自动重命名为 kitchen (1)、kitchen (2) 等。
4. 把defined_label改为你自己数据集(训练集或测试集)的分类名称列表,列表中的下标就是写入tfrecords的标签ID。
5. resized_image = tf.image.resize_images(img_data_jpg, [200, 200])把两个200改成你想resize成的高和宽。size参数的顺序是[new_height, new_width],即先高后宽,不要弄反了;如果不放心可以参考上一篇文章 http://blog.csdn.net/zsg2063/article/details/75646394,resize出来看看效果。
6. 有些地方喜欢把 image_raw_data_jpg = tf.gfile.FastGFile(image_dir, 'rb').read() 里面的rb写成r,我测试过程中r有问题,这里建议写成rb。原因是JPEG是二进制文件,用文本模式r打开时内容会被当作文本处理(例如按字符编码解码),读出来的数据可能损坏或者直接报错。
7. 这份代码目前仅限于jpeg文件,png文件没研究。