使用 TensorFlow 读取 SVHN 数据集并转换为 TFRecords 格式

       这里默认将 Python 脚本文件和 SVHN 数据集(train_32x32.mat、test_32x32.mat)放在同一目录下;也可以通过 --directory 参数(即 FLAGS.directory)指定数据集所在的目录。由于 SVHN 没有单独的 validation 数据集,因此从 train 数据中分割出一部分作为 validation。

注:data_set 函数的 num_sample_size 参数默认为 10000,读取的样本数可以根据需要自行设置;这里训练数据设置为 20000,其中前 15000 个样本作为训练集,后 5000 个样本作为验证集(验证集大小由 --validation_size 指定,默认 5000)。

 

import argparse
import os
import sys
import tensorflow as tf
from scipy.io import loadmat

def data_set(data_dir, name, num_sample_size=10000):
    """Load the leading samples of one SVHN split from its ``.mat`` file.

    Args:
        data_dir: Directory that contains ``<name>_32x32.mat``.
        name: Split name used as the file-name prefix, e.g. ``'train'``.
        num_sample_size: Upper bound on the number of leading samples to
            return; it is used directly as a slice end.

    Returns:
        A tuple ``(images, labels)``. ``images`` is the ``'X'`` array with
        its last axis moved to the front (sample index first); ``labels`` is
        the flattened ``'y'`` array with every 10 replaced by 0.

    Raises:
        ValueError: If the ``.mat`` file does not exist.
    """
    filename = os.path.join(data_dir, name + '_32x32.mat')
    if not os.path.isfile(filename):
        # Name the path so the user can see exactly which file is missing.
        raise ValueError('Please supply the data file: ' + filename)

    datadict = loadmat(filename)
    # 'X' is stored with the sample index on the last axis; move it first.
    images = datadict['X'].transpose((3, 0, 1, 2))
    # flatten() returns a copy, so the in-place relabel below is local.
    labels = datadict['y'].flatten()
    # SVHN encodes the digit 0 as class 10; map it back to 0.
    labels[labels == 10] = 0
    return images[:num_sample_size], labels[:num_sample_size]

def _int64_feature(value):
  """Wrap a single integer in a ``tf.train.Feature`` holding an Int64List."""
  int64_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int64_list)


def _bytes_feature(value):
  """Wrap a single bytes value in a ``tf.train.Feature`` holding a BytesList."""
  bytes_list = tf.train.BytesList(value=[value])
  return tf.train.Feature(bytes_list=bytes_list)

def convert_to_tfrecords(images, labels, fileName):
  """Serialize images and their labels into a TFRecords file.

  One ``tf.train.Example`` is written per image, with the features
  ``height``, ``width``, ``depth``, ``label`` and ``image_raw`` (the raw
  bytes of the image array).

  Args:
    images: Array whose ``shape`` unpacks to
      ``(num_examples, rows, cols, depth)``.
    labels: Indexable sequence; ``labels[i]`` must be convertible to ``int``
      for every image index ``i``.
    fileName: Path of the TFRecords file to write.
  """
  num_examples, rows, cols, depth = images.shape

  print('Writing', fileName)
  # The context manager closes the writer even if serialization fails
  # part-way through, so the file handle is never leaked.
  with tf.python_io.TFRecordWriter(fileName) as writer:
    for index in range(num_examples):
      # tobytes() is the supported spelling of the deprecated tostring().
      image_raw = images[index].tobytes()
      example = tf.train.Example(features=tf.train.Features(feature={
          'height': _int64_feature(rows),
          'width': _int64_feature(cols),
          'depth': _int64_feature(depth),
          'label': _int64_feature(int(labels[index])),
          'image_raw': _bytes_feature(image_raw)}))
      writer.write(example.SerializeToString())

def split_dataset(train_x, train_y, validation_size):
    """Split off the trailing ``validation_size`` samples as a validation set.

    Args:
        train_x: Sliceable sequence of samples that supports ``len()``.
        train_y: Sliceable sequence of labels aligned with ``train_x``.
        validation_size: Number of trailing samples to move into the
            validation set; must satisfy ``0 <= validation_size <= len(train_x)``.

    Returns:
        A tuple ``(train_x, train_y, valid_x, valid_y)``.

    Raises:
        ValueError: If ``validation_size`` is negative or larger than the
            number of samples.
    """
    num_samples = len(train_x)
    if not 0 <= validation_size <= num_samples:
        raise ValueError(
            'Validation size should be between 0 and {}. Received: {}.'
            .format(num_samples, validation_size))

    # Use an explicit split index: slicing with ``-validation_size`` breaks
    # for 0, because ``x[:-0]`` is empty and ``x[-0:]`` is the whole sequence.
    split_at = num_samples - validation_size
    return (train_x[:split_at],
            train_y[:split_at],
            train_x[split_at:],
            train_y[split_at:])

def main(unused_argv):
  """Convert the SVHN train/test ``.mat`` files into three TFRecords files.

  Reads ``train_32x32.mat`` and ``test_32x32.mat`` from ``FLAGS.directory``,
  moves the last ``FLAGS.validation_size`` training samples into a validation
  set, and writes ``train.tfrecords``, ``validation.tfrecords`` and
  ``test.tfrecords`` back into the same directory.

  Args:
    unused_argv: Positional arguments forwarded by ``tf.app.run``; ignored.
  """
  # The sample limit is an argument of data_set (the loader), not of
  # convert_to_tfrecords. Load 20000 training samples so that, with the
  # default validation size of 5000, 15000 remain for training.
  train_x, train_y = data_set(FLAGS.directory, 'train', num_sample_size=20000)
  test_x, test_y = data_set(FLAGS.directory, 'test')
  train_x, train_y, valid_x, valid_y = split_dataset(
      train_x, train_y, FLAGS.validation_size)

  trainFileName = os.path.join(FLAGS.directory, 'train.tfrecords')
  validationFileName = os.path.join(FLAGS.directory, 'validation.tfrecords')
  testFileName = os.path.join(FLAGS.directory, 'test.tfrecords')

  # Convert to Examples and write the result to TFRecords; each split goes
  # to the file that carries its own name.
  convert_to_tfrecords(train_x, train_y, trainFileName)
  convert_to_tfrecords(valid_x, valid_y, validationFileName)
  convert_to_tfrecords(test_x, test_y, testFileName)

  print('over')
  
if __name__ == '__main__':
  # Declare the command-line flags; the parsed values are published as the
  # module-level FLAGS that main() reads.
  arg_parser = argparse.ArgumentParser()
  arg_parser.add_argument(
      '--directory',
      default='.',
      type=str,
      help='Directory to download data files and write the converted result')
  arg_parser.add_argument(
      '--validation_size',
      default=5000,
      type=int,
      help="""\
          Number of examples to separate from the training data for the validation
          set.\
          """)

  # Flags not recognised here are handed on to tf.app.run together with the
  # program name.
  FLAGS, unparsed = arg_parser.parse_known_args()
  program_argv = [sys.argv[0]] + unparsed
  tf.app.run(main=main, argv=program_argv)

 

 

 

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值