1. Preloaded data(预加载数据):
在 TensorFlow 图中定义常量或变量来保存所有数据。
2. Feeding(填充数据):
由 Python 代码产生数据,在运行会话时通过 feed_dict 把数据填充到后端。
3. Reading from file(从文件读取):
直接从文件中读取数据,由队列管理器(QueueRunner)负责从文件中读取并填充数据。
下面是具体实现的代码示例。示例 1(预加载数据与填充数据):
import os

# Must be set before TensorFlow is imported: level '2' hides TensorFlow's
# C++ INFO and WARNING log messages.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf

# Method 1 -- preloaded data: the values are stored in the graph as constants.
x1 = tf.constant([2, 3, 4])
x2 = tf.constant([7, 2, 1])
y = tf.add(x1, x2)

# Method 2 -- feeding: placeholders reserve slots in the graph whose values
# are supplied through the ``feed_dict`` argument of ``sess.run()``.
a1 = tf.placeholder(tf.int16)
a2 = tf.placeholder(tf.int16)
b = tf.add(a1, a2)

# Python-side data that will be fed into the placeholders at run time.
# (The original used a C-style "//" comment on the line above, which Python
# parses as floor division by an undefined name and fails with NameError.)
li1 = [2, 3, 4]
li2 = [4, 0, 1]

with tf.Session() as sess:
    print(sess.run(y))
    # feed_dict maps each placeholder to the value it takes for this run only.
    print(sess.run(b, feed_dict={a1: li1, a2: li2}))
示例 2(从文件中读取数据):先将 MNIST 数据集转换为 TFRecords 文件,再通过队列从文件中读取数据进行训练。
from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import time import tensorflow as tf from tensorflow.contrib.learn.python.learn.datasets import mnist from tensorflow.examples.tutorials.mnist import input_data from tensorflow.examples.tutorials.mnist import mnist # basic model parameters as external flags. FLAGS = None # 构建常量用处理文件和匹配转换记录 Constants used for dealing with the files, matches convert_to_records. TRAIN_FILE = 'train.tfrecords' VALIDATAION_FILE = 'validation.tfrecords' # 编码函数 coding function def _int64_feature(value): return tf.train.Feature(int64_list = tf.train.Int64List(value=[value])) def _bytes_feature(value): return tf.train.Feature(bytes_list = tf.train.BytesList(value=[value])) #转换函数 (convert function),将数据集转换为tfrecords格式 def convert_to(data_set, name): images = data_set.images labels = data_set.labels num_examples = data_set.num_examples if images.shape[0] != num_examples: raise ValueError('Image size %d does not match label size %d.' 
% (images.shape[0],num_examples)) rows = images.shape[1] cols = images.shape[2] depth = images.shape[3] filename = os.path.join(FLAGS.directory, name+'.tfrecords') print('Writing', filename) writer = tf.python_io.TFRecordWriter(filename) for index in range(num_examples) : images_raw = images[index].tostring() example = tf.train.Example(features = tf.train.Features(feature={ 'height': _int64_feature(rows), 'width': _int64_feature(cols), 'depth': _int64_feature(depth), 'label': _int64_feature(int(labels[index])), 'image_raw': _bytes_feature(images_raw)})) writer.write(example.SerializeToString()) writer.close() # 主函数main function # Get the data def main(unused_argv): data_sets =mnist.read_data_sets(FLAGS.directory, dtype = tf.uint8, reshape = False, validation_size = FLAGS.validation_size) convert_to(data_sets.train, 'train') convert_to(data_sets.validation, 'validation') convert_to(data_sets.test, 'test') def read_and_decode(filename_queue): reader = tf.TFRecordReader() serialized_example = reader.read(filename_queue) features = tf.parse_single_example(serialized_example , features = { 'image_raw': tf.FixedLenFeature([], tf.string), 'label' : tf.FixedLenFeature([], tf.int64) }) image = tf.decode_raw(features['image_raw', tf.uint8]) image.set_shape([mnist.IMAGE_PIXELS]) image = tf.cast(image, tf.float32)*(1. 
/ 255) - 0.5 label = tf.cast(features['label'], tf.int32) return image, label def inputs(train, batch_size, num_epochs) : if not num_epochs: num_epochs = None filename = os.path.join(FLAGS.train_dir, TRAIN_FILE if train else VALIDATAION_FILE) with tf.name_scope('input'): filename_queue = tf.train.string_input_producer( [filename], num_epochs = num_epochs) image, label = read_and_decode(filename_queue) images, sparse_labels = tf.train.shuffle_batch( [image, label], batch_size = batch_size, num_threads=2, capacity=1000+3*batch_size, min_after_dequeue= 1000) return images, sparse_labels def run_training(): with tf.Graph().as_default() : images, labels = inputs(train="True", batch_size= FLAGS.batch_size, num_epochs=FLAGS.num_epochs) logits = mnist.inference(images, FLAGS.hidden1, FLAGS.hidden2) loss = mnist.loss(logits, labels) train_op = mnist.training(loss, FLAGS.learning_rate) init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) sess = tf.Session() sess.run(init_op) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) try: step = 0 while not coord.should_stop(): start_time = time.time() loss_value = sess.run([train_op], loss) duration = time.time() - start_time if step % 100 == 0 : print('Step %d: loss = %.2f(%.3f sec' % (step, loss_value, duration)) step += 1 except tf.errors.OutOfRangeError: print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step)) finally: coord.request_stop() coord.join(threads) sess.close()