最近接触TensorFlow,需要训练自己的数据集,看到很多博客资料,了解到TensorFlow中自带的tfrecord文件,但是自己具体实现起来发现自己的情况与资料的一些不太一样,所以把自己遇到的问题归纳整理出来。新手一枚,水平有限,有许多问题的解决可能仅限于解决,代码并有优化,有些思路可能走了弯路,希望能跟大家交流。
1.问题1:对于多分类情况,怎么确定标签?
(1)多分类:大多资料中给出的是针对两种分类的情况,采用的是直接用class={class1 , class2}这种格式,但是对于很多类的话,依次写出类别有点麻烦,那么可以采用先定义一个列表classes = []来存储目录中所有的分类,比如对于字符识别,那么classes = {1,2,3,4...,A,B,C},然后用for index, name in enumerate(classes)将对应文件夹的名字与整数一一对应起来。
其中enumerate是python中的一个函数,目的会将index和classes中的name对应起来,比如classes = {1,2,3,A,...}那么index = {1, 2, 3, 4,..}并且与classes中的1,2,3 ,A这些对应。
为什么要这样做?因为在存入tfrecord的时候,标签一般用的是整型,当目录文件中包含A,B,或者字符串的时候要将其变为整型,我尝试过读入tfrecord的时候用tobyte格式,也就是直接用字符串的形式读入,但是会报错,也可能是我知识水平不够,没有找到正确的方法。
classes = [] for class1 in os.listdir(cwd): classes.append(class1) for index, name in enumerate(classes): class_path = cwd + name + '\\' for img_name in os.listdir(class_path): img_path = class_path + img_name # 每一个图片的地址
2.问题二:如何在读入的时候分数据集和测试集(其中测试集占50%)
(1)我采用的是在逐层访问文件夹的时候用两个字典(一对多)存入图片的标签和对应图片地址。
for class1 in os.listdir(cwd): classes.append(class1) for index, name in enumerate(classes): class_path = cwd + name + '\\' for img_name in os.listdir(class_path): img_path = class_path + img_name # 每一个图片的地址 m += 1 if(m % 5) == 0: len_testing_dataset += 1 testing_dataset[index].append(img_path) else: len_training_dataset += 1 training_dataset[index].append(img_path)
3.问题三:数据格式的变换
(1)从tfrecord中读出的数据格式是tensor格式,我之前跟着教程构建的带有计算图的CNN,它输入的数据格式和mnist数据集是一
样的,那么要将输出的tensor格式转化为与mnist数据集一样的格式,并且标签采用one-hot编码格式
def to_one_hot(classes, label): num_classes = len(classes) # print(num_classes) # print("label-----------",label) label_arr = np.zeros((num_classes)) # print("label_arr---------",label_arr) label_arr[label] += 1.0 # print("after change label_arr",label_arr) return label_arr
def importimg(imagepath,m,classes): #imagepath为读入的图片tfrecord的地址 #imagepath = "data_train.tfrecords" # min_after_dequeue = 15 # batch_size = 1 # capacity = min_after_dequeue + 3 * batch_size # print(imagepath) # print("m------------",m) print("开始读入数据----------------------------------") filename_queue = tf.train.string_input_producer([imagepath]) #读入流中 reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) #返回文件名和文件 features = tf.parse_single_example(serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'img_raw': tf.FixedLenFeature([], tf.string), }) # 取出包含image和label的feature对象 image = tf.decode_raw(features['img_raw'], tf.uint8) # print("从tfrecord文件中读取数据image", image) image = tf.reshape(image, [-1]) # print("after reshape of image-----------------",image) label = tf.cast(features['label'], tf.int32) # 在流中抛出label张量 with tf.Session() as sess: # 开始一个会话 init_op = tf.global_variables_initializer() sess.run(init_op) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) labels = [] images = [] for i in range(m): print("第", i, "个",imagepath,"数据正在读取中") image1, label1 = sess.run([image, label]) # 在会话中取出image和label image = tf.cast(image1, tf.float32) label_arr = to_one_hot(classes, label1) labels.append(label_arr) images.append(image1) labels_arr = np.array(labels) images_arr = np.array(images) # print("labels_arr------------",labels_arr) # print("images_arr------------",images_arr) coord.request_stop() coord.join(threads) return images_arr, labels_arr
总的代码:
import os import tensorflow as tf from PIL import Image from collections import defaultdict from itertools import groupby #import matplotlib.pyplot as plt os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' import numpy as np #读图片地址,CNN,已经测试正确 def read_image(cwd): #m记录样本数 m = 0 classes = [] len_testing_dataset = 0 len_training_dataset = 0 training_dataset = defaultdict(list) testing_dataset = defaultdict(list) for class1 in os.listdir(cwd): classes.append(class1) for index, name in enumerate(classes): class_path = cwd + name + '\\' for img_name in os.listdir(class_path): img_path = class_path + img_name # 每一个图片的地址 m += 1 if(m % 5) == 0: len_testing_dataset += 1 testing_dataset[index].append(img_path) else: len_training_dataset += 1 training_dataset[index].append(img_path) print("training_dataset testing_dataset END ------------------------------------------------------") return m, classes, training_dataset, testing_dataset, len_testing_dataset, len_training_dataset # m, classes, training_dataset, testing_dataset, len_testing_dataset, len_training_dataset = read_image( # 'E:\datafortest\Testlib1\\' # ) #CNN,写数据,已经测试正确 def write_data(dataset, newfilepath): writer = tf.python_io.TFRecordWriter(newfilepath) # 要生成的文件 for label, img in dataset.items(): for img_path in img: print("img_path------------",img_path) img = Image.open(img_path) img = img.resize((15, 15)) img_raw = img.tobytes() # 将图片转化为二进制格式,uint8 #img_decode = img_raw.decode('utf-8') #print(img_decode) example = tf.train.Example(features=tf.train.Features(feature={ "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])), 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])) })) # example对象对label和image数据进行封装 writer.write(example.SerializeToString()) # 序列化为字符串 print("成功存入tfrecord文件") writer.close() # 返回总的类别数,和所有的类别标号 #CNN,写数据,已经测试正确 def write_mul_data(dataset, newfilepath, record_location): writer = None current_index = 0 for label, img in dataset.items(): for img_path in img: print("img_path------------", img_path) #每隔10000个就存入一个文件 if current_index % 10000 == 0: if writer: writer.close() record_filename = "{record_location} - {current_index}.tfrecords".format( record_location = record_location, current_index = current_index ) print(record_filename + "----------------------------") current_index += 1 image_file = tf.read_file(newfilepath) try: image = tf.image.decode_jpeg(newfilepath) except: print(image_file) continue # write_data(training_dataset,"train_set.tfrecords") def to_one_hot(classes, label): num_classes = len(classes) # print(num_classes) # print("label-----------",label) label_arr = np.zeros((num_classes)) # print("label_arr---------",label_arr) label_arr[label] += 1.0 # print("after change label_arr",label_arr) return label_arr #CNN,这个方法就是将tensor张量转化为images转化为int数组和label转化为ont-hot编码 def importimg(imagepath,m,classes): #imagepath为读入的图片tfrecord的地址 #imagepath = "data_train.tfrecords" # min_after_dequeue = 15 # batch_size = 1 # capacity = min_after_dequeue + 3 * batch_size # print(imagepath) # print("m------------",m) print("开始读入数据----------------------------------") filename_queue = tf.train.string_input_producer([imagepath]) #读入流中 reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) #返回文件名和文件 features = tf.parse_single_example(serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'img_raw': tf.FixedLenFeature([], tf.string), }) # 取出包含image和label的feature对象 image = tf.decode_raw(features['img_raw'], tf.uint8) # print("从tfrecord文件中读取数据image", image) image = tf.reshape(image, [-1]) # print("after reshape of image-----------------",image) label = tf.cast(features['label'], tf.int32) # 在流中抛出label张量 with tf.Session() as sess: # 开始一个会话 init_op = tf.global_variables_initializer() sess.run(init_op) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) labels = [] images = [] for i in range(m): print("第", i, "个",imagepath,"数据正在读取中") image1, label1 = sess.run([image, label]) # 在会话中取出image和label image = tf.cast(image1, tf.float32) label_arr = to_one_hot(classes, label1) labels.append(label_arr) images.append(image1) labels_arr = np.array(labels) images_arr = np.array(images) # print("labels_arr------------",labels_arr) # print("images_arr------------",images_arr) coord.request_stop() coord.join(threads) return images_arr, labels_arr
参考资料:http://blog.csdn.net/xierhacker/article/details/72357651