Getting started with TensorFlow: TFRecord and tf.data.TFRecordDataset

Copyright notice: This is an original article by the blog author, licensed under the CC 4.0 BY-SA license. Please include the original source link and this notice when reposting.
Original link: https://blog.csdn.net/yeqiustu/article/details/79793454

1. Creating TFRecord files

A TFRecord file can store three kinds of data: string, int64 and float32. Each value is written as a list into a tf.train.Feature via tf.train.BytesList, tf.train.Int64List or tf.train.FloatList respectively, as shown below:

tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()]))  # feature is usually a multi-dimensional array, so serialize it to a byte string first
tf.train.Feature(int64_list=tf.train.Int64List(value=list(feature.shape)))   # tostring() loses the shape information, so store the shape as well
tf.train.Feature(float_list=tf.train.FloatList(value=[label]))

The features to be written are collected in a dict, which is used to build a tf.train.Features, which in turn is wrapped into a tf.train.Example:

def get_tfrecords_example(feature, label):
    tfrecords_features = {}
    feat_shape = feature.shape
    tfrecords_features['feature'] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()]))
    tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
    tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
    return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))

Once the tf.train.Example has been serialized, it can be written to a TFRecord file with tf.python_io.TFRecordWriter:

tfrecord_wrt = tf.python_io.TFRecordWriter('xxx.tfrecord')  # create the TFRecord writer; the output file is xxx.tfrecord
exmp = get_tfrecords_example(feats[inx], labels[inx])  # pack one sample into an Example
exmp_serial = exmp.SerializeToString()  # serialize the Example
tfrecord_wrt.write(exmp_serial)  # write it to the TFRecord file
tfrecord_wrt.close()  # close the writer when done
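
TFRecordWriter can also be used as a context manager, which closes the file automatically even if the write loop is interrupted. A minimal sketch of the same write loop (assuming feats and labels are the arrays and label vectors prepared as above):

# same write loop as above, but the writer is closed automatically on exit
with tf.python_io.TFRecordWriter('xxx.tfrecord') as tfrecord_wrt:
    for inx in range(len(labels)):
        exmp = get_tfrecords_example(feats[inx], labels[inx])
        tfrecord_wrt.write(exmp.SerializeToString())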

Complete code:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

mnist = read_data_sets("MNIST_data/", one_hot=True)

# pack one sample into an Example
def get_tfrecords_example(feature, label):
    tfrecords_features = {}
    feat_shape = feature.shape
    tfrecords_features['feature'] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()]))
    tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
    tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
    return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))

# write all samples to a TFRecord file
def make_tfrecord(data, outf_nm='mnist-train'):
    feats, labels = data
    outf_nm += '.tfrecord'
    tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm)
    ndatas = len(labels)
    for inx in range(ndatas):
        exmp = get_tfrecords_example(feats[inx], labels[inx])
        exmp_serial = exmp.SerializeToString()
        tfrecord_wrt.write(exmp_serial)
    tfrecord_wrt.close()

import random
nDatas = len(mnist.train.labels)
inx_lst = list(range(nDatas))  # list() so that shuffle also works under Python 3
random.shuffle(inx_lst)
random.shuffle(inx_lst)
ntrains = int(0.85 * nDatas)

# make training set
data = ([mnist.train.images[i] for i in inx_lst[:ntrains]],
        [mnist.train.labels[i] for i in inx_lst[:ntrains]])
make_tfrecord(data, outf_nm='mnist-train')

# make validation set
data = ([mnist.train.images[i] for i in inx_lst[ntrains:]],
        [mnist.train.labels[i] for i in inx_lst[ntrains:]])
make_tfrecord(data, outf_nm='mnist-val')

# make test set
data = (mnist.test.images, mnist.test.labels)
make_tfrecord(data, outf_nm='mnist-test')
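
As a quick sanity check (not part of the original workflow), the number of records in each file can be counted with tf.python_io.tf_record_iterator; the totals should match the 85/15 split of the 55000 MNIST training images plus the 10000 test images:

# count the serialized Examples in each generated file
for fname in ['mnist-train.tfrecord', 'mnist-val.tfrecord', 'mnist-test.tfrecord']:
    n = sum(1 for _ in tf.python_io.tf_record_iterator(fname))
    print('%s: %d records' % (fname, n))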

2. Using TFRecord files: tf.data.TFRecordDataset

Create a TFRecordDataset from a TFRecord file:

dataset = tf.data.TFRecordDataset('xxx.tfrecord')

Each record in a TFRecord file is a serialized tf.train.Example; parse it with tf.parse_single_example:

feats = tf.parse_single_example(serial_exmp, features=data_dict)

Here data_dict is a dict whose keys are the same keys used when writing the TFRecord file; the corresponding values are tf.FixedLenFeature([], tf.string), tf.FixedLenFeature([], tf.int64) or tf.FixedLenFeature([], tf.float32), matching the type of each stored feature. Putting this together:

def parse_exmp(serial_exmp):
    # [10] for 'label' because each label is a list of 10 elements;
    # [x] for 'shape' is the length of the stored shape list
    feats = tf.parse_single_example(serial_exmp, features={'feature': tf.FixedLenFeature([], tf.string),
                                                           'label': tf.FixedLenFeature([10], tf.float32),
                                                           'shape': tf.FixedLenFeature([x], tf.int64)})
    image = tf.decode_raw(feats['feature'], tf.float32)
    label = feats['label']
    shape = tf.cast(feats['shape'], tf.int32)
    return image, label, shape
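
Note that tf.decode_raw returns a flat 1-D tensor. If the original layout is needed, the stored shape can drive a reshape; a small sketch (the wrapper name parse_exmp_reshaped is made up for illustration):

def parse_exmp_reshaped(serial_exmp):
    # same parsing as above, then restore the original array layout;
    # for the flattened MNIST vectors the stored shape is simply [784]
    image, label, shape = parse_exmp(serial_exmp)
    image = tf.reshape(image, shape)
    return image, label, shape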

To parse every record in the TFRecord file, use the dataset's map method:

dataset = dataset.map(parse_exmp)

The map method accepts an arbitrary function to process the elements of the dataset. In addition, the repeat, shuffle and batch methods repeat, shuffle and batch the dataset; repeat duplicates the dataset so it can be iterated over multiple epochs:

dataset = dataset.repeat(epochs).shuffle(buffer_size).batch(batch_size)
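
The order of these calls matters. A tiny toy example (not from the original post) shows the difference between batching before and after repeat:

# 5 elements, batch_size 2, 2 epochs
ds = tf.data.Dataset.range(5)
ds_a = ds.batch(2).repeat(2)  # yields [0 1], [2 3], [4], [0 1], [2 3], [4] -- a short batch at the end of every epoch
ds_b = ds.repeat(2).batch(2)  # yields [0 1], [2 3], [4 0], [1 2], [3 4]   -- batches may cross epoch boundaries, at most one short batch overall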
     
     

Once parsing is set up, the data can be consumed by creating an iterator:

iterator = dataset.make_one_shot_iterator()
batch_image, batch_label, batch_shape = iterator.get_next()
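
With a one-shot iterator, the data is pulled inside a session until the dataset (all epochs) is exhausted, at which point get_next raises tf.errors.OutOfRangeError. A minimal consumption loop might look like this (a sketch, assuming the tensors defined above):

with tf.Session() as sess:
    try:
        while True:
            img, lbl, shp = sess.run([batch_image, batch_label, batch_shape])
            # ... use the batch here, e.g. run a training step ...
    except tf.errors.OutOfRangeError:
        pass  # raised once all epochs have been consumed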

To feed data from different datasets into the same model, first create an iterator handle, i.e. an iterator placeholder:

handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle,
                                               dataset_train.output_types, dataset_train.output_shapes)
image, label, shape = iterator.get_next()

Then create a handle for each dataset and pass it to the placeholder through feed_dict:

with tf.Session() as sess:
    handle_train, handle_val, handle_test = sess.run(
        [x.string_handle() for x in [iter_train, iter_val, iter_test]])
    sess.run([loss, train_op], feed_dict={handle: handle_train})

Putting it all together:

import tensorflow as tf

train_f, val_f, test_f = ['mnist-%s.tfrecord' % i for i in ['train', 'val', 'test']]

def parse_exmp(serial_exmp):
    feats = tf.parse_single_example(serial_exmp, features={'feature': tf.FixedLenFeature([], tf.string),
                                                           'label': tf.FixedLenFeature([10], tf.float32),
                                                           'shape': tf.FixedLenFeature([], tf.int64)})
    image = tf.decode_raw(feats['feature'], tf.float32)
    label = feats['label']
    shape = tf.cast(feats['shape'], tf.int32)
    return image, label, shape

def get_dataset(fname):
    dataset = tf.data.TFRecordDataset(fname)
    return dataset.map(parse_exmp)  # use the padded_batch method if padding is needed

epochs = 16
batch_size = 50  # when nDatas is not divisible by batch_size (e.g. 56),
                 # one batch will contain fewer than batch_size samples

# training dataset
nDatasTrain = 46750
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).batch(batch_size)  # make sure repeat comes before batch
# this is different from dataset.shuffle(1000).batch(batch_size).repeat(epochs):
# the latter produces a short batch at the end of every epoch
# whenever nDatas is not divisible by batch_size
nBatchs = nDatasTrain * epochs // batch_size

# evaluation dataset
nDatasVal = 8250
dataset_val = get_dataset(val_f)
dataset_val = dataset_val.batch(nDatasVal).repeat(nBatchs // 100 * 2)

# test dataset
nDatasTest = 10000
dataset_test = get_dataset(test_f)
dataset_test = dataset_test.batch(nDatasTest)

# make dataset iterators
iter_train = dataset_train.make_one_shot_iterator()
iter_val = dataset_val.make_one_shot_iterator()
iter_test = dataset_test.make_one_shot_iterator()

# make feedable iterator
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle,
                                               dataset_train.output_types, dataset_train.output_shapes)
x, y_, _ = iterator.get_next()
train_op, loss, eval_op = model(x, y_)  # model() builds the network (including keep_prob); see the full example in section 3
init = tf.initialize_all_variables()

# summary
logdir = './logs/m4d2a'
def summary_op(datapart='train'):
    tf.summary.scalar(datapart + '-loss', loss)
    tf.summary.scalar(datapart + '-eval', eval_op)
    return tf.summary.merge_all()
summary_op_train = summary_op()
summary_op_test = summary_op('val')

with tf.Session() as sess:
    sess.run(init)
    handle_train, handle_val, handle_test = sess.run(
        [x.string_handle() for x in [iter_train, iter_val, iter_test]])
    _, cur_loss, cur_train_eval, summary = sess.run([train_op, loss, eval_op, summary_op_train],
                                                    feed_dict={handle: handle_train, keep_prob: 0.5})
    cur_val_loss, cur_val_eval, summary = sess.run([loss, eval_op, summary_op_test],
                                                   feed_dict={handle: handle_val, keep_prob: 1.0})
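
An alternative to repeating the validation dataset nBatchs//100*2 times (which is what the code above does) is an initializable iterator that is re-initialized before every validation pass. This is only a sketch of that alternative, not what the original code uses:

# sketch: re-initializable validation iterator instead of dataset_val.repeat(...)
dataset_val = get_dataset(val_f).batch(nDatasVal)
iter_val = dataset_val.make_initializable_iterator()
# inside the session, before each validation evaluation:
#     sess.run(iter_val.initializer)   # the string handle of iter_val remains valid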

3. MNIST experiment

import tensorflow as tf

train_f, val_f, test_f = ['mnist-%s.tfrecord' % i for i in ['train', 'val', 'test']]

def parse_exmp(serial_exmp):
    feats = tf.parse_single_example(serial_exmp, features={'feature': tf.FixedLenFeature([], tf.string),
                                                           'label': tf.FixedLenFeature([10], tf.float32),
                                                           'shape': tf.FixedLenFeature([], tf.int64)})
    image = tf.decode_raw(feats['feature'], tf.float32)
    label = feats['label']
    shape = tf.cast(feats['shape'], tf.int32)
    return image, label, shape

def get_dataset(fname):
    dataset = tf.data.TFRecordDataset(fname)
    return dataset.map(parse_exmp)  # use the padded_batch method if padding is needed

epochs = 16
batch_size = 50  # when nDatas is not divisible by batch_size (e.g. 56),
                 # one batch will contain fewer than batch_size samples

# training dataset
nDatasTrain = 46750
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).batch(batch_size)  # make sure repeat comes before batch
# this is different from dataset.shuffle(1000).batch(batch_size).repeat(epochs):
# the latter produces a short batch at the end of every epoch
# whenever nDatas is not divisible by batch_size
nBatchs = nDatasTrain * epochs // batch_size

# evaluation dataset
nDatasVal = 8250
dataset_val = get_dataset(val_f)
dataset_val = dataset_val.batch(nDatasVal).repeat(nBatchs // 100 * 2)

# test dataset
nDatasTest = 10000
dataset_test = get_dataset(test_f)
dataset_test = dataset_test.batch(nDatasTest)

# make dataset iterators
iter_train = dataset_train.make_one_shot_iterator()
iter_val = dataset_val.make_one_shot_iterator()
iter_test = dataset_test.make_one_shot_iterator()

# make feedable iterator, i.e. the iterator placeholder
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle,
                                               dataset_train.output_types, dataset_train.output_shapes)
x, y_, _ = iterator.get_next()

# cnn
x_image = tf.reshape(x, [-1, 28, 28, 1])
w_init = tf.truncated_normal_initializer(stddev=0.1, seed=9)
b_init = tf.constant_initializer(0.1)
cnn1 = tf.layers.conv2d(x_image, 32, (5, 5), padding='same', activation=tf.nn.relu,
                        kernel_initializer=w_init, bias_initializer=b_init)
mxpl1 = tf.layers.max_pooling2d(cnn1, 2, strides=2, padding='same')
cnn2 = tf.layers.conv2d(mxpl1, 64, (5, 5), padding='same', activation=tf.nn.relu,
                        kernel_initializer=w_init, bias_initializer=b_init)
mxpl2 = tf.layers.max_pooling2d(cnn2, 2, strides=2, padding='same')
mxpl2_flat = tf.reshape(mxpl2, [-1, 7*7*64])
fc1 = tf.layers.dense(mxpl2_flat, 1024, activation=tf.nn.relu,
                      kernel_initializer=w_init, bias_initializer=b_init)
keep_prob = tf.placeholder('float')
fc1_drop = tf.nn.dropout(fc1, keep_prob)
logits = tf.layers.dense(fc1_drop, 10, kernel_initializer=w_init, bias_initializer=b_init)
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_))
optmz = tf.train.AdamOptimizer(1e-4)
train_op = optmz.minimize(loss)

def get_eval_op(logits, labels):
    corr_prd = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    return tf.reduce_mean(tf.cast(corr_prd, 'float'))
eval_op = get_eval_op(logits, y_)

init = tf.initialize_all_variables()

# summary
logdir = './logs/m4d2a'
def summary_op(datapart='train'):
    tf.summary.scalar(datapart + '-loss', loss)
    tf.summary.scalar(datapart + '-eval', eval_op)
    return tf.summary.merge_all()
summary_op_train = summary_op()
summary_op_val = summary_op('val')

# whether to restore or not
ckpts_dir = 'ckpts/'
ckpt_nm = 'cnn-ckpt'
saver = tf.train.Saver(max_to_keep=50)  # defaults to saving all variables; pass a dict {'x': x, ...} to save specific ones
restore_step = ''
start_step = 0
train_steps = nBatchs
best_loss = 1e6
best_step = 0

# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# config = tf.ConfigProto()
# config.gpu_options.per_process_gpu_memory_fraction = 0.9
# config.gpu_options.allow_growth = True  # allocate GPU memory only when needed
# with tf.Session(config=config) as sess:
with tf.Session() as sess:
    sess.run(init)
    handle_train, handle_val, handle_test = sess.run(
        [x.string_handle() for x in [iter_train, iter_val, iter_test]])
    if restore_step:
        ckpt = tf.train.get_checkpoint_state(ckpts_dir)
        if ckpt and ckpt.model_checkpoint_path:  # ckpt.model_checkpoint_path is the latest checkpoint
            if restore_step == 'latest':
                ckpt_f = tf.train.latest_checkpoint(ckpts_dir)
                start_step = int(ckpt_f.split('-')[-1]) + 1
            else:
                ckpt_f = ckpts_dir + ckpt_nm + '-' + restore_step
            print('loading wgt file: ' + ckpt_f)
            saver.restore(sess, ckpt_f)
    summary_wrt = tf.summary.FileWriter(logdir, sess.graph)
    if restore_step in ['', 'latest']:
        for i in range(start_step, train_steps):
            _, cur_loss, cur_train_eval, summary = sess.run([train_op, loss, eval_op, summary_op_train],
                                                            feed_dict={handle: handle_train, keep_prob: 0.5})
            # log to stdout and evaluate the validation set
            if i % 100 == 0 or i == train_steps - 1:
                saver.save(sess, ckpts_dir + ckpt_nm, global_step=i)  # save variables
                summary_wrt.add_summary(summary, global_step=i)
                cur_val_loss, cur_val_eval, summary = sess.run([loss, eval_op, summary_op_val],
                                                               feed_dict={handle: handle_val, keep_prob: 1.0})
                if cur_val_loss < best_loss:
                    best_loss = cur_val_loss
                    best_step = i
                summary_wrt.add_summary(summary, global_step=i)
                print('step %5d: loss %.5f, acc %.5f --- loss val %0.5f, acc val %.5f'
                      % (i, cur_loss, cur_train_eval, cur_val_loss, cur_val_eval))
        # sess.run(init_train)
        with open(ckpts_dir + 'best.step', 'w') as f:
            f.write('best step is %d\n' % best_step)
        print('best step is %d' % best_step)
    # evaluate the test set
    test_loss, test_eval = sess.run([loss, eval_op], feed_dict={handle: handle_test, keep_prob: 1.0})
    print('eval test: loss %.5f, acc %.5f' % (test_loss, test_eval))

Experimental results:





