笔记《【TensorFlow】用TFRecord方式对数据进行读取(一)》讲述了如何生成TFRecord文件。那么如何读取生成的文件进行模型训练呢?且往下看。
本文读取TFRecord的方法大致如下:
- 构建.tfrecord文件的文件路径列表filenames
- 使用string_input_producer(filenames)方法生成一个文件名队列filename_queue,再用TFRecordReader从队列中读取样本,并用parse_single_example()解析、decode_jpeg()解码图片
- 使用per_image_standardization()函数对图片进行标准化
- 对label进行onehot编码,有关onehot编码,请参考《TensorFlow tf.one_hot()函数》
- 构建网络模型、选择损失函数、使用梯度下降优化器进行训练
- 启动训练
程序实现:
# coding=utf-8
import tensorflow as tf
import numpy as np
import os
# Directory that the .tfrecord shard file names are joined onto.
DATA_DTR = "./data/"
TRAINING_SET_SIZE = 3670 # number of examples in the training set
BATCH_SIZE = 64 # number of examples consumed per iteration
IMAGE_SIZE = 224 # image edge length; image size and network structure both affect iteration speed, image size the most. Sizes noted for CPU-only environments: 48*48, 32*32, 64*64
def _int64_feature(value):
    """Wrap ``value`` in an int64 ``tf.train.Feature``.

    ``value`` is handed to ``Int64List`` unchanged, so the caller must already
    pass a sequence of integers (it is not wrapped in a list here).
    """
    int64_list = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=int64_list)
def _bytes_feature(value):
    """Wrap a single bytes ``value`` in a bytes ``tf.train.Feature``.

    The value is placed in a one-element list before being handed to
    ``BytesList``.
    """
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
# Holder for the fields of one example parsed from the TFRecord protobuf.
class _image_object:
    """Plain attribute container for one parsed example.

    Each attribute starts out as an empty ``tf.Variable`` of the dtype listed
    below; ``read_and_decode`` then overwrites every attribute with the tensor
    parsed from the record.
    """

    def __init__(self):
        # (attribute name, dtype) pairs, created in this fixed order.
        initial_fields = (
            ("image", tf.string),
            ("height", tf.int64),
            ("width", tf.int64),
            ("filename", tf.string),
            ("label", tf.int32),
        )
        for attr_name, attr_dtype in initial_fields:
            setattr(self, attr_name, tf.Variable([], dtype=attr_dtype))
def read_and_decode(filename_queue):
    """Read one serialized example from ``filename_queue`` and decode it.

    Args:
        filename_queue: queue of TFRecord file names to read from.

    Returns:
        An ``_image_object`` whose ``image`` is the JPEG decoded to 3 channels
        and cropped/padded to ``IMAGE_SIZE`` x ``IMAGE_SIZE``, and whose
        ``height``, ``width``, ``filename`` and ``label`` come from the
        corresponding features of the record.
    """
    # Schema of a single record: every feature is a scalar.
    feature_spec = {
        'image/encoded': tf.FixedLenFeature([], tf.string),
        'image/height': tf.FixedLenFeature([], tf.int64),
        'image/width': tf.FixedLenFeature([], tf.int64),
        'image/filename': tf.FixedLenFeature([], tf.string),
        'image/class/label': tf.FixedLenFeature([], tf.int64),
    }

    record_reader = tf.TFRecordReader()
    _, serialized = record_reader.read(filename_queue)
    parsed = tf.parse_single_example(serialized, features=feature_spec)
    decoded = tf.image.decode_jpeg(parsed['image/encoded'], channels=3)

    example = _image_object()
    # Bring the image to IMAGE_SIZE x IMAGE_SIZE by cropping or padding;
    # no rescaling or random cropping is applied here.
    example.image = tf.image.resize_image_with_crop_or_pad(decoded, IMAGE_SIZE, IMAGE_SIZE)
    example.height = parsed['image/height']
    example.width = parsed['image/width']
    example.filename = parsed['image/filename']
    example.label = tf.cast(parsed['image/class/label'], tf.int64)
    return example
def flower_input(if_random=True, if_training=True):
    """Build the batched input pipeline over the flower TFRecord shards.

    Args:
        if_random: when true, batches are produced by
            ``tf.train.shuffle_batch``; otherwise by ``tf.train.batch``.
        if_training: when true, read the two ``train-*`` shards under
            ``DATA_DTR``; otherwise read the two ``eval-*`` shards.

    Returns:
        A ``(image_batch, label_batch, filename_batch)`` tuple of tensors, each
        batched with ``BATCH_SIZE`` examples. Images are standardized per image.

    Raises:
        ValueError: if one of the expected shard files does not exist.
    """
    shard_pattern = 'train-0000%d-of-00002.tfrecord' if if_training else 'eval-0000%d-of-00002.tfrecord'
    filenames = [os.path.join(DATA_DTR, shard_pattern % shard) for shard in range(0, 2)]

    # Fail early on the first shard that is missing.
    first_missing = next((path for path in filenames if not tf.gfile.Exists(path)), None)
    if first_missing is not None:
        raise ValueError('Failed to find file:' + first_missing)

    # string_input_producer yields a queue of file names for the reader.
    filename_queue = tf.train.string_input_producer(filenames)
    example = read_and_decode(filename_queue)

    # Per-image standardization; the result has the same shape as the input image.
    tensors = [
        tf.image.per_image_standardization(example.image),
        example.label,
        example.filename,
    ]

    if if_random:
        min_fraction_of_examples_in_queue = 0.4
        min_queue_examples = int(TRAINING_SET_SIZE * min_fraction_of_examples_in_queue)
        print(
            'Filling queue with %d images before starting to train.' 'This will take a few minutes.' % min_queue_examples)
        batched = tf.train.shuffle_batch(
            tensors,
            batch_size=BATCH_SIZE,
            num_threads=1,
            capacity=min_queue_examples + 3 * BATCH_SIZE,
            min_after_dequeue=min_queue_examples,
        )
    else:
        batched = tf.train.batch(tensors, batch_size=BATCH_SIZE, num_threads=1)

    image_batch, label_batch, filename_batch = batched
    return image_batch, label_batch, filename_batch
def weight_variable(shape):
    """Create a weight variable of ``shape`` initialized from a truncated normal (stddev 0.05)."""
    return tf.Variable(tf.truncated_normal(shape, stddev=0.05))
def bias_variable(shape):
    """Create a bias variable of ``shape`` with every element initialized to 0.02."""
    return tf.Variable(tf.constant(0.02, shape=shape))
def conv2d(x, W):
    """Convolve ``x`` with kernel ``W`` using stride 1 in every dimension and 'SAME' padding."""
    unit_strides = [1, 1, 1, 1]
    return tf.nn.conv2d(x, W, strides=unit_strides, padding='SAME')
def max_pool_2x2(x):
    """Max-pool ``x`` with a 2x2 window and stride 2, using 'SAME' padding."""
    window = [1, 2, 2, 1]
    return tf.nn.max_pool(x, ksize=window, strides=window, padding='SAME')
def flower_inference(image_batch):
    """Build the CNN classifier graph for a batch of images.

    The network is five stages of 5x5 convolution + ReLU + 2x2 max-pool
    (32, 64, 128, 256 and 256 output channels), followed by fully connected
    layers with 2048, 256, 64 and 5 units.

    Args:
        image_batch: image tensor; it is reshaped to
            ``[-1, IMAGE_SIZE, IMAGE_SIZE, 3]`` before the first convolution.

    Returns:
        A tensor with 5 values per image: the softmax of the last layer, i.e.
        class probabilities rather than raw logits.
    """
    # First conv stage, applied to the reshaped input.
    kernel = weight_variable([5, 5, 3, 32])
    bias = bias_variable([32])
    net = tf.reshape(image_batch, [-1, IMAGE_SIZE, IMAGE_SIZE, 3])
    net = max_pool_2x2(tf.nn.relu(conv2d(net, kernel) + bias))  # 112

    # Remaining conv stages; spatial size per stage for a 224 input: 56, 28, 14, 7.
    for in_channels, out_channels in ((32, 64), (64, 128), (128, 256), (256, 256)):
        kernel = weight_variable([5, 5, in_channels, out_channels])
        bias = bias_variable([out_channels])
        net = max_pool_2x2(tf.nn.relu(conv2d(net, kernel) + bias))

    # First fully connected layer on the flattened 7x7x256 feature map.
    flat_size = 7 * 7 * 256
    fc_weights = weight_variable([flat_size, 2048])
    fc_bias = bias_variable([2048])
    net = tf.reshape(net, [-1, flat_size])
    net = tf.nn.relu(tf.matmul(net, fc_weights) + fc_bias)
    # Second argument is 1.0 (the keep probability in the TF 1.x API).
    net = tf.nn.dropout(net, 1.0)

    # Hidden fully connected layers.
    for in_units, out_units in ((2048, 256), (256, 64)):
        fc_weights = weight_variable([in_units, out_units])
        fc_bias = bias_variable([out_units])
        net = tf.nn.relu(tf.matmul(net, fc_weights) + fc_bias)

    # Output layer: 5 classes, softmax applied here.
    out_weights = weight_variable([64, 5])
    out_bias = bias_variable([5])
    return tf.nn.softmax(tf.matmul(net, out_weights) + out_bias)
def flower_train():
    """Train the flower classifier on the training TFRecord shards.

    Each step first pulls one batch (images, labels, one-hot labels, file
    names) from the input queue in a single ``sess.run`` and then feeds that
    same batch to the model through placeholders, so images and labels always
    belong together. Progress is printed every 100 steps.

    Checkpoint saving is disabled; ``flower_eval`` restores from
    './model/checkpoint-train.ckpt', so enable the ``saver.save`` call noted
    below before running evaluation.
    """
    num_classes = 5
    num_epochs = 100
    # Ceil division: number of batches needed to see the training set once.
    steps_per_epoch = -(-TRAINING_SET_SIZE // BATCH_SIZE)

    image_batch_out, label_batch_out, filename_batch = flower_input(if_random=False, if_training=True)

    image_batch_placeholder = tf.placeholder(tf.float32, shape=[BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3])
    image_batch = tf.reshape(image_batch_out, (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3))

    label_batch_placeholder = tf.placeholder(tf.float32, shape=[BATCH_SIZE, num_classes])
    label_offset = -tf.ones([BATCH_SIZE], dtype=tf.int64, name='label_batch_offset')
    # Shift the stored labels by -1 so that they start at 0 before one-hot encoding.
    label_batch_one_hot = tf.one_hot(tf.add(label_batch_out, label_offset), depth=num_classes,
                                     on_value=1.0, off_value=0.0)

    # flower_inference already applies softmax, so its output is treated as
    # probabilities and the cross-entropy is computed from them directly
    # (feeding them to softmax_cross_entropy_with_logits would apply softmax twice).
    probabilities_out = flower_inference(image_batch_placeholder)
    clipped_probabilities = tf.clip_by_value(probabilities_out, 1e-10, 1.0)  # avoid log(0)
    # The labels come from the placeholder, i.e. from the same dequeued batch
    # as the fed images; reading the queue tensor here would dequeue a new batch.
    loss = -tf.reduce_sum(label_batch_placeholder * tf.log(clipped_probabilities))

    train_step = tf.train.GradientDescentOptimizer(0.0005).minimize(loss)
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord, sess=sess)
        try:
            for i in range(num_epochs * steps_per_epoch):
                # One dequeue per step: everything fetched here is the same batch.
                image_out, label_out, label_batch_one_hot_out, filename_out = sess.run(
                    [image_batch, label_batch_out, label_batch_one_hot, filename_batch])
                _, infer_out, loss_out = sess.run(
                    [train_step, probabilities_out, loss],
                    feed_dict={image_batch_placeholder: image_out,
                               label_batch_placeholder: label_batch_one_hot_out})
                if i % 100 == 0:
                    print(i)
                    print(image_out.shape)
                    print('label_out: ')
                    print(filename_out)
                    print(label_out)
                    print(label_batch_one_hot_out)  # ground truth
                    print('infer_out: ')
                    print(infer_out)  # prediction
                    print('loss: ')
                    print(loss_out)  # loss value
                # To checkpoint periodically:
                #     saver.save(sess, './model/checkpoint-train.ckpt')
        finally:
            # Always stop the queue-runner threads, even if a step raised.
            coord.request_stop()
            coord.join(threads)
def flower_eval():
    """Evaluate the checkpointed classifier on the evaluation TFRecord shards.

    Restores variables from './model/checkpoint-train.ckpt', runs a fixed
    number of batches of ``BATCH_SIZE`` examples, prints per-batch details and
    finally prints the accuracy averaged over those batches.
    """
    num_classes = 5
    num_eval_batches = 29

    image_batch_out, label_batch_out, filename_batch = flower_input(if_random=False, if_training=False)

    image_batch_placeholder = tf.placeholder(tf.float32, shape=[BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3])
    image_batch = tf.reshape(image_batch_out, (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3))

    label_tensor_placeholder = tf.placeholder(tf.int64, shape=[BATCH_SIZE])
    label_offset = -tf.ones([BATCH_SIZE], dtype=tf.int64, name='label_batch_offset')
    # Shift the stored labels by -1 so they are comparable with argmax indices.
    label_batch = tf.add(label_batch_out, label_offset)

    model_out = tf.reshape(flower_inference(image_batch_placeholder), [BATCH_SIZE, num_classes])
    # Index of the highest-scoring class per example.
    predicted_class = tf.argmax(model_out, axis=1)
    correct_prediction = tf.equal(predicted_class, label_tensor_placeholder)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, './model/checkpoint-train.ckpt')
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord, sess=sess)
        accuracy_accu = 0
        try:
            for i in range(num_eval_batches):
                # Fetch one batch, then feed that same batch to the model.
                image_out, label_out, filename_out = sess.run([image_batch, label_batch, filename_batch])
                accuracy_out, predicted_out = sess.run(
                    [accuracy, predicted_class],
                    feed_dict={image_batch_placeholder: image_out,
                               label_tensor_placeholder: label_out})
                accuracy_accu += accuracy_out
                print(i)
                print(image_out.shape)
                print('label_out: ')
                print(filename_out)
                print(label_out)
                print(predicted_out)
            print('Accuracy: ')
            print(accuracy_accu / num_eval_batches)
        finally:
            # Always stop the queue-runner threads, even if a batch raised.
            coord.request_stop()
            coord.join(threads)
# Script entry: run training. The evaluation call is left commented out.
flower_train()
# flower_eval()
执行结果:
0
(64, 224, 224, 3)
label_out:
[b'18010259565_d6aae33ca7_n.jpg' b'5139977423_d413b23fde_m.jpg'
b'1306119996_ab8ae14d72_n.jpg' b'175106495_53ebdef092_n.jpg'
b'4664737020_b4c61aacd3_n.jpg' b'7184780734_3baab127c2_m.jpg'
b'8623170936_83f4152431.jpg' b'17357636476_1953c07aa4_n.jpg'
b'8021540573_c56cf9070d_n.jpg' b'488849503_63a290a8c2_m.jpg'
b'2501297526_cbd66a3f7e_m.jpg' b'16670921315_0fc48d7ab2_n.jpg'
b'8174970894_7f9a26be7e.jpg' b'8853083579_dd1dfa3188.jpg'
b'229488796_21ac6ee16d_n.jpg' b'5067864967_19928ca94c_m.jpg'
b'16645809126_613b1e3ebe_m.jpg' b'4546316433_202cc68c55.jpg'
b'14054827391_139fb54432.jpg' b'17420983523_2e32d70359.jpg'
b'4248222578_b4d5868b32.jpg' b'23894449029_bf0f34d35d_n.jpg'
b'2706736074_b0fba20b3e.jpg' b'8716513637_2ba0c4e6cd_n.jpg'
b'8713398906_28e59a225a_n.jpg' b'8437935944_aab997560a_n.jpg'
b'19438516548_bbaf350664.jpg' b'2349640101_212c275aa7.jpg'
b'16025261368_911703a536_n.jpg' b'3518608454_c3fd3c311c_m.jpg'
b'3798841385_38142ea3c6_n.jpg' b'2319777940_0cc5476b0d_n.jpg'
b'7790614422_4557928ab9_n.jpg' b'175686816_067a8cb4c5.jpg'
b'19177263840_6a316ea639.jpg' b'2995221296_a6ddaccc39.jpg'
b'23891393761_155af6402c.jpg' b'8757486380_90952c5377.jpg'
b'40410686_272bc66faf_m.jpg' b'18843967474_9cb552716b.jpg'
b'2481428401_bed64dd043.jpg' b'9617087594_ec2a9b16f6.jpg'
b'3500121696_5b6a69effb_n.jpg' b'7066602021_2647457985_m.jpg'
b'9595857626_979c45e5bf_n.jpg' b'13279526615_a3b0059bec.jpg'
b'5721768347_2ec4d2247b_n.jpg' b'3530500952_9f94fb8b9c_m.jpg'
b'10437652486_aa86c14985.jpg' b'2723995667_31f32294b4.jpg'
b'6687138903_ff6ae12758_n.jpg' b'4697206799_19dd2a3193_m.jpg'
b'3568114325_d6b1363497.jpg' b'22255608949_172d7c8d22_m.jpg'
b'19813618946_93818db7aa_m.jpg' b'164670176_9f5b9c7965.jpg'
b'9339697826_88c9c4dc50.jpg' b'17101762155_2577a28395.jpg'
b'7630511450_02d3292e90.jpg' b'4993492878_11fd4f5d12.jpg'
b'4721773235_429acdf496_n.jpg' b'3472437817_7902b3d984_n.jpg'
b'142235914_5419ff8a4a.jpg' b'5948835387_5a98d39eff_m.jpg']
[2 4 1 1 4 2 5 1 1 3 3 3 4 3 3 4 5 5 5 2 3 4 4 2 5 3 2 1 1 2 4 2 1 5 1 2 2
5 4 4 2 2 3 1 1 3 3 2 2 4 3 1 4 4 1 4 4 1 1 1 2 2 5 1]
[[0. 1. 0. 0. 0.]
[0. 0. 0. 1. 0.]
[1. 0. 0. 0. 0.]
[1. 0. 0. 0. 0.]
[0. 0. 0. 1. 0.]
[0. 1. 0. 0. 0.]
[0. 0. 0. 0. 1.]
[1. 0. 0. 0. 0.]
[1. 0. 0. 0. 0.]
[0. 0. 1. 0. 0.]
[0. 0. 1. 0. 0.]
[0. 0. 1. 0. 0.]
[0. 0. 0. 1. 0.]
[0. 0. 1. 0. 0.]
[0. 0. 1. 0. 0.]
[0. 0. 0. 1. 0.]
[0. 0. 0. 0. 1.]
[0. 0. 0. 0. 1.]
[0. 0. 0. 0. 1.]
[0. 1. 0. 0. 0.]
[0. 0. 1. 0. 0.]
[0. 0. 0. 1. 0.]
[0. 0. 0. 1. 0.]
[0. 1. 0. 0. 0.]
[0. 0. 0. 0. 1.]
[0. 0. 1. 0. 0.]
[0. 1. 0. 0. 0.]
[1. 0. 0. 0. 0.]
[1. 0. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 0. 0. 1. 0.]
[0. 1. 0. 0. 0.]
[1. 0. 0. 0. 0.]
[0. 0. 0. 0. 1.]
[1. 0. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 0. 0. 0. 1.]
[0. 0. 0. 1. 0.]
[0. 0. 0. 1. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 0. 1. 0. 0.]
[1. 0. 0. 0. 0.]
[1. 0. 0. 0. 0.]
[0. 0. 1. 0. 0.]
[0. 0. 1. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 0. 0. 1. 0.]
[0. 0. 1. 0. 0.]
[1. 0. 0. 0. 0.]
[0. 0. 0. 1. 0.]
[0. 0. 0. 1. 0.]
[1. 0. 0. 0. 0.]
[0. 0. 0. 1. 0.]
[0. 0. 0. 1. 0.]
[1. 0. 0. 0. 0.]
[1. 0. 0. 0. 0.]
[1. 0. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 0. 0. 0. 1.]
[1. 0. 0. 0. 0.]]
infer_out:
[[0.13056415 0.04913541 0.32594264 0.03438465 0.45997307]
[0.26384547 0.2623756 0.2435487 0.04329655 0.1869336 ]
[0.0706915 0.12705286 0.08502517 0.64274526 0.07448522]
[0.11050481 0.06549674 0.1426018 0.34517446 0.33622214]
[0.25743067 0.07589907 0.35129794 0.09244049 0.2229318 ]
[0.1374493 0.15353125 0.21229312 0.2894446 0.20728172]
[0.09434555 0.02388783 0.7310694 0.07601975 0.07467746]
[0.0838695 0.0934897 0.33276933 0.08913088 0.40074065]
[0.04950906 0.1499536 0.60577244 0.06054341 0.13422155]
[0.14419325 0.1415817 0.48861662 0.05337767 0.17223081]
[0.15547645 0.13317487 0.30910924 0.06604663 0.33619282]
[0.1821311 0.08428234 0.30044657 0.21865055 0.21448937]
[0.20056832 0.02380027 0.5013732 0.03798183 0.23627648]
[0.14194377 0.08518288 0.39406604 0.11473632 0.26407102]
[0.48186043 0.03133626 0.20779409 0.03820071 0.24080853]
[0.13544849 0.12823503 0.22201277 0.06939054 0.44491315]
[0.1127345 0.13301061 0.4368849 0.13075082 0.18661909]
[0.122599 0.0050048 0.41727737 0.06205699 0.39306185]
[0.12478087 0.05266788 0.5534555 0.11985254 0.14924324]
[0.21396564 0.09996541 0.17408894 0.18133828 0.33064175]
[0.13904135 0.03576617 0.1610776 0.39143124 0.2726836 ]
[0.16825004 0.044227 0.04287214 0.4528682 0.29178265]
[0.3567176 0.03447685 0.1967859 0.05021394 0.36180568]
[0.14819665 0.03879297 0.11420825 0.07196447 0.6268377 ]
[0.06664324 0.09225077 0.5089801 0.05105572 0.2810702 ]
[0.47721437 0.08342645 0.20137496 0.08103408 0.15695006]
[0.24899334 0.0937962 0.40340105 0.13583842 0.11797093]
[0.1910153 0.1018641 0.33720532 0.04594139 0.32397395]
[0.37828308 0.04301307 0.09033397 0.0235413 0.46482858]
[0.06822239 0.08883941 0.6621396 0.09125091 0.08954761]
[0.14358594 0.09292904 0.6177641 0.03208194 0.11363894]
[0.1044371 0.07252962 0.5346857 0.08311398 0.20523357]
[0.39982766 0.08607291 0.25110722 0.02854534 0.23444691]
[0.1001133 0.02308131 0.5584598 0.05035937 0.2679862 ]
[0.44755536 0.03560201 0.22415678 0.14573278 0.1469531 ]
[0.05461171 0.03521081 0.57107073 0.03843536 0.30067134]
[0.20951813 0.10207474 0.2779998 0.11641366 0.29399362]
[0.04378606 0.027622 0.7215292 0.08199167 0.1250711 ]
[0.16532472 0.08239221 0.08965182 0.34766296 0.31496826]
[0.45645338 0.08051163 0.26655343 0.12141783 0.07506368]
[0.1287272 0.08703651 0.5341651 0.13985255 0.11021858]
[0.13924025 0.06041571 0.2111614 0.04206975 0.5471129 ]
[0.06036361 0.13316424 0.56349164 0.07972634 0.1632541 ]
[0.1646403 0.14190002 0.30503014 0.15408823 0.23434131]
[0.12360985 0.0701086 0.5298012 0.15381391 0.1226664 ]
[0.19767626 0.12351054 0.27424613 0.11784815 0.28671893]
[0.04888719 0.09419045 0.28086978 0.13631783 0.4397347 ]
[0.08643341 0.04461662 0.1592851 0.15064448 0.55902034]
[0.28541085 0.02239579 0.13165314 0.10752396 0.4530163 ]
[0.16877909 0.0521294 0.46971634 0.02459331 0.28478178]
[0.16809784 0.17493002 0.29431188 0.10929338 0.25336695]
[0.23965484 0.06915538 0.4016183 0.10721285 0.18235867]
[0.16836813 0.04364016 0.4770574 0.03853995 0.27239436]
[0.20559838 0.12010937 0.07704528 0.0739807 0.52326626]
[0.09065755 0.11342864 0.6540009 0.04456346 0.09734945]
[0.0948393 0.07684919 0.45677623 0.14801382 0.22352141]
[0.13934647 0.0108922 0.5130212 0.06993146 0.26680872]
[0.2995163 0.04558311 0.0994138 0.01326568 0.54222107]
[0.1164759 0.06499377 0.2876498 0.03181984 0.4990606 ]
[0.03242943 0.00875211 0.899797 0.02933092 0.02969051]
[0.19746183 0.05923576 0.56459033 0.08405775 0.09465437]
[0.08568227 0.09119201 0.44635725 0.10182896 0.2749395 ]
[0.08776123 0.0846834 0.649502 0.04741209 0.13064125]
[0.1287426 0.09562506 0.22470206 0.32401603 0.22691426]]
loss:
83.281944