Tensorflow生成自己的图片数据集TFrecords

Tensorflow生成自己的图片数据集TFrecords

train.txt保存图片的路径和标签信息

  1. test_image/dog/1.jpg 0
  2. test_image/dog/2.jpg 0
  3. test_image/dog/3.jpg 0
  4. test_image/dog/4.jpg 0
  5. test_image/cat/1.jpg 1
  6. test_image/cat/2.jpg 1
  7. test_image/cat/3.jpg 1
  8. test_image/cat/4.jpg 1

使用下面完整代码,可以生成自己的图片数据集TFrecords,并解析:

  1. # -*- coding: utf-8 -*-
  2. # !/usr/bin/python3.5
  3. # ref url : https://blog.csdn.net/guyuealian/article/details/80857228
  4. # Author : pan_jinquan
  5. # Date : 2018.6.29
  6. # Function: image convert to tfrecords
  7. #############################################################################################
  8. import tensorflow as tf
  9. import numpy as np
  10. import matplotlib.pyplot as plt
  11. from PIL import Image
  12. # Parameter settings
  13. ###############################################################################################
  14. train_file = 'train.txt' # text file listing one "<image path> <label>" pair per line
  15. output_record_dir= './tfrecords/my_record.tfrecords' # path of the TFRecord file to write (a file, despite the name)
  16. resize_height = 100 # height the stored images are resized to
  17. resize_width = 100 # width the stored images are resized to
  18. ###############################################################################################
  19. # 生成整数型的属性
  20. def _int64_feature(value):
  21. return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
  22. # 生成字符串型的属性
  23. def _bytes_feature(value):
  24. return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
  25. # 生成实数型的属性
  26. def float_list_feature(value):
  27. return tf.train.Feature(float_list=tf.train.FloatList(value=value))
  28. # 显示图片
  29. def show_image(image_name,image):
  30. # plt.figure("show_image") # 图像窗口名称
  31. plt.imshow(image)
  32. plt.axis( 'on') # 关掉坐标轴为 off
  33. plt.title(image_name) # 图像题目
  34. plt.show()
  35. #载图txt文件,文件中每行为一个图片信息,且以空格隔开:图像路径 标签,如:test_image/1.jpg 0
  36. def load_txt_file(examples_list_file):
  37. lines = np.genfromtxt(examples_list_file, delimiter= " ", dtype=[( 'col1', 'S120'), ( 'col2', 'i8')])
  38. examples = []
  39. labels = []
  40. for example, label in lines:
  41. examples.append(example.decode( 'ascii'))
  42. labels.append(label)
  43. return np.asarray(examples), np.asarray(labels), len(lines)
  44. # 读取原始图像数据
  45. def read_image(filename, resize_height, resize_width):
  46. # image = cv2.imread(filename)
  47. # image = cv2.resize(image, (resize_height, resize_width))
  48. # b, g, r = cv2.split(image)
  49. # rgb_image = cv2.merge([r, g, b])
  50. rgb_image=Image.open(filename)
  51. rgb_image=rgb_image.resize((resize_width,resize_height))
  52. image=np.asanyarray(rgb_image)
  53. # show_image("src resize image",image)
  54. return image
  55. #保存record文件
  56. def create_record(train_file, output_record_dir, resize_height, resize_width):
  57. _examples, _labels, examples_num = load_txt_file(train_file)
  58. writer = tf.python_io.TFRecordWriter(output_record_dir)
  59. for i, [example, label] in enumerate(zip(_examples, _labels)):
  60. print( 'No.%d' % (i))
  61. image = read_image(example, resize_height, resize_width)
  62. print( 'shape: %d, %d, %d, label: %d' % (image.shape[ 0], image.shape[ 1], image.shape[ 2], label))
  63. image_raw = image.tostring()
  64. example = tf.train.Example(features=tf.train.Features(feature={
  65. 'image_raw': _bytes_feature(image_raw),
  66. 'height': _int64_feature(image.shape[ 0]),
  67. 'width': _int64_feature(image.shape[ 1]),
  68. 'depth': _int64_feature(image.shape[ 2]),
  69. 'label': _int64_feature(label)
  70. }))
  71. writer.write(example.SerializeToString())
  72. writer.close()
  73. #解析record文件,并显示,主要用于验证
  74. def disp_records(tfrecord_list_file):
  75. filename_queue = tf.train.string_input_producer([tfrecord_list_file])
  76. reader = tf.TFRecordReader()
  77. _, serialized_example = reader.read(filename_queue)
  78. features = tf.parse_single_example(
  79. serialized_example,
  80. features={
  81. 'image_raw': tf.FixedLenFeature([], tf.string),
  82. 'height': tf.FixedLenFeature([], tf.int64),
  83. 'width': tf.FixedLenFeature([], tf.int64),
  84. 'depth': tf.FixedLenFeature([], tf.int64),
  85. 'label': tf.FixedLenFeature([], tf.int64)
  86. }
  87. )
  88. image = tf.decode_raw(features[ 'image_raw'], tf.uint8)
  89. # print(repr(image))
  90. height = features[ 'height']
  91. width = features[ 'width']
  92. depth = features[ 'depth']
  93. label = tf.cast(features[ 'label'], tf.int32)
  94. init_op = tf.initialize_all_variables()
  95. resultImg = []
  96. resultLabel = []
  97. with tf.Session() as sess:
  98. sess.run(init_op)
  99. coord = tf.train.Coordinator()
  100. threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  101. for i in range( 4):
  102. image_eval = image.eval()
  103. resultLabel.append(label.eval())
  104. image_eval_reshape = image_eval.reshape([height.eval(), width.eval(), depth.eval()])
  105. print( "image.shape=",height.eval(), width.eval(), depth.eval())
  106. resultImg.append(image_eval_reshape)
  107. # pilimg = Image.fromarray(np.asarray(image_eval_reshape))
  108. # pilimg.show()
  109. show_image( "decode_from_tfrecords",image_eval_reshape)
  110. coord.request_stop()
  111. coord.join(threads)
  112. sess.close()
  113. return resultImg, resultLabel
  114. #解析record文件
  115. def read_record(filename_queuetemp):
  116. filename_queue = tf.train.string_input_producer([filename_queuetemp])
  117. reader = tf.TFRecordReader()
  118. _, serialized_example = reader.read(filename_queue)
  119. features = tf.parse_single_example(
  120. serialized_example,
  121. features={
  122. 'image_raw': tf.FixedLenFeature([], tf.string),
  123. 'width': tf.FixedLenFeature([], tf.int64),
  124. 'depth': tf.FixedLenFeature([], tf.int64),
  125. 'label': tf.FixedLenFeature([], tf.int64)
  126. }
  127. )
  128. image = tf.decode_raw(features[ 'image_raw'], tf.uint8) #获得图像原始的数据
  129. image=tf.reshape(image, [ 100, 100, 3]) # 设置图像的维度
  130. # image = tf.cast(image, tf.float32) * (1. / 255) - 0.5 # 归一化
  131. label = tf.cast(features[ 'label'], tf.int32) # label
  132. return image, label
  133. def test():
  134. create_record(train_file, output_record_dir, resize_height, resize_width) # 产生record文件
  135. # img, label = disp_records(output_record_dir) # 显示函数
  136. image_data, label_data= read_record(output_record_dir) # 读取函数
  137. #使用shuffle_batch可以随机打乱输入
  138. image_batch, label_batch = tf.train.shuffle_batch([image_data, label_data], batch_size= 30, capacity= 2000,min_after_dequeue= 1000)
  139. init = tf.global_variables_initializer()
  140. with tf.Session() as sess: #开始一个会话
  141. sess.run(init)
  142. coord = tf.train.Coordinator()
  143. threads = tf.train.start_queue_runners(coord=coord)
  144. for i in range( 4):
  145. images, labels= sess.run([image_batch, label_batch]) #在会话中取出image和label
  146. #我们也可以根据需要对val, l进行处理
  147. #l = to_categorical(l, 12)
  148. show_image( "image",images[i,:,:,:])
  149. print(images.shape, labels)
  150. #停止所有线程
  151. coord.request_stop()
  152. coord.join(threads)
  153. sess.close() #关闭会话
  154. if __name__ == '__main__':
  155. test()

假设我们已经生成了output.tfrecords,其中保存有:

  1. image_raw:图像的特征向量,长度为 784(28×28 像素展平)
  2. pixels:图像的边长(分辨率),为 28
  3. label:图像的标签

 可以使用下面的方法训练网络:

  1. #coding=utf-8
  2. import tensorflow as tf
  3. # 模型相关的参数
  4. INPUT_NODE = 784
  5. OUTPUT_NODE = 10
  6. LAYER1_NODE = 500
  7. REGULARAZTION_RATE = 0.0001
  8. TRAINING_STEPS = 5000
  9. files = tf.train.match_filenames_once( "./output.tfrecords")
  10. filename_queue = tf.train.string_input_producer(files, shuffle= False)
  11. # 读取文件。
  12. reader = tf.TFRecordReader()
  13. _,serialized_example = reader.read(filename_queue)
  14. # 解析读取的样例。
  15. features = tf.parse_single_example(
  16. serialized_example,
  17. features={
  18. 'image_raw':tf.FixedLenFeature([],tf.string),
  19. 'pixels':tf.FixedLenFeature([],tf.int64),
  20. 'label':tf.FixedLenFeature([],tf.int64)
  21. })
  22. decoded_images = tf.decode_raw(features[ 'image_raw'],tf.uint8)
  23. retyped_images = tf.cast(decoded_images, tf.float32)
  24. labels = tf.cast(features[ 'label'],tf.int32)
  25. #pixels = tf.cast(features['pixels'],tf.int32)
  26. images = tf.reshape(retyped_images, [ 784])
  27. min_after_dequeue = 10000
  28. batch_size = 100
  29. capacity = min_after_dequeue + 3 * batch_size
  30. image_batch, label_batch = tf.train.shuffle_batch([images, labels],
  31. batch_size=batch_size,
  32. capacity=capacity,
  33. min_after_dequeue=min_after_dequeue)
  34. def inference(input_tensor, weights1, biases1, weights2, biases2):
  35. layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1)
  36. return tf.matmul(layer1, weights2) + biases2
  37. weights1 = tf.Variable(tf.truncated_normal([INPUT_NODE, LAYER1_NODE], stddev= 0.1))
  38. biases1 = tf.Variable(tf.constant( 0.1, shape=[LAYER1_NODE]))
  39. weights2 = tf.Variable(tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE], stddev= 0.1))
  40. biases2 = tf.Variable(tf.constant( 0.1, shape=[OUTPUT_NODE]))
  41. y = inference(image_batch, weights1, biases1, weights2, biases2)
  42. # 计算交叉熵及其平均值
  43. cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=label_batch)
  44. cross_entropy_mean = tf.reduce_mean(cross_entropy)
  45. # 损失函数的计算
  46. regularizer = tf.contrib.layers.l2_regularizer(REGULARAZTION_RATE)
  47. regularaztion = regularizer(weights1) + regularizer(weights2)
  48. loss = cross_entropy_mean + regularaztion
  49. # 优化损失函数
  50. train_step = tf.train.GradientDescentOptimizer( 0.01).minimize(loss)
  51. # 初始化回话并开始训练过程。
  52. with tf.Session() as sess:
  53. tf.global_variables_initializer().run()
  54. coord = tf.train.Coordinator()
  55. threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  56. # 循环的训练神经网络。
  57. for i in range(TRAINING_STEPS):
  58. if i % 1000 == 0:
  59. print( "After %d training step(s), loss is %g " % (i, sess.run(loss)))
  60. sess.run(train_step)
  61. coord.request_stop()
  62. coord.join(threads)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值