最后是第四部分：最上层的 train_imagenet.py，它从顶层控制整个神经网络的运行。
import skimage.io # bug. need to import this before tensorflow
import skimage.transform # bug. need to import this before tensorflow
from resnet_train import train
import tensorflow as tf
import time
import os
import sys
import re
import numpy as np
import pdb
from synset import *
from image_processing import image_preprocessing
from resnet import inference
# Directory holding the training images; file_list() reads the image index
# from the sibling file "<data_dir>.txt".
tf.app.flags.DEFINE_string('data_dir', '/home/modric/Downloads/train',
                           'imagenet dir')
FLAGS = tf.app.flags.FLAGS
def file_list(data_dir):
    """Return the image paths listed in the index file ``data_dir + ".txt"``.

    The index file sits next to ``data_dir`` (e.g. ``train`` -> ``train.txt``).
    Each remaining line is stripped of trailing whitespace and joined onto
    ``data_dir``.  Lines starting with ``.`` and blank/whitespace-only lines
    are skipped.

    Args:
        data_dir: image directory; also the stem of the index file name.

    Returns:
        List of paths in the order they appear in the index file.
    """
    index_path = data_dir + ".txt"
    filenames = []
    with open(index_path, 'r') as f:
        for line in f:
            # Entries such as "." / "./" or hidden files are not images.
            if line.startswith('.'):
                continue
            line = line.rstrip()
            # A blank line would otherwise turn into data_dir itself.
            if not line:
                continue
            filenames.append(os.path.join(data_dir, line))
    return filenames
def load_data(data_dir):
    """Build one labelled record per ``.jpg`` image listed for ``data_dir``.

    Paths come from ``file_list(data_dir)``.  The synset id (``n`` followed
    by digits) is taken from the part of each path below ``data_dir`` and
    mapped to a class index through ``synset_map``.

    Args:
        data_dir: image directory (see ``file_list``).

    Returns:
        List of dicts with keys ``filename``, ``label_name``,
        ``label_index`` and ``desc`` (``synset[label_index]``).

    Raises:
        ValueError: if a listed ``.jpg`` path contains no synset id.
        KeyError: if the synset id is not present in ``synset_map``.
    """
    # Single formatted string: identical output on Python 2 and Python 3.
    print("listing files in %s" % data_dir)
    start_time = time.time()
    files = file_list(data_dir)
    duration = time.time() - start_time
    print("took %f sec" % duration)  # time spent listing the files

    data = []
    for img_fn in files:
        # Only files with a (lowercase) ".jpg" extension are loaded.
        if os.path.splitext(img_fn)[1] != '.jpg':
            continue
        # Search only below data_dir so that digits in data_dir itself
        # (e.g. "/data/run1") are never mistaken for a synset id.
        rel_path = os.path.relpath(img_fn, data_dir)
        match = re.search(r'(n\d+)', rel_path)
        if match is None:
            raise ValueError("no synset id (n<digits>) in image path: %s" % img_fn)
        label_name = match.group(1)
        label_index = synset_map[label_name]["index"]
        data.append({
            # file_list() already prefixed data_dir; joining it again would
            # double the prefix whenever data_dir is a relative path.
            "filename": img_fn,
            "label_name": label_name,
            "label_index": label_index,
            "desc": synset[label_index],
        })
    return data
def distorted_inputs():
    """Build a queue-based pipeline that yields shuffled, distorted batches.

    Every ``.jpg`` record found by ``load_data(FLAGS.data_dir)`` is read,
    passed through ``image_preprocessing`` (opaque here; presumably decoding
    plus augmentation), resized to 300x300 and then center-cropped/padded to
    ``FLAGS.input_size`` x ``FLAGS.input_size``.

    Returns:
        images: float32 tensor reshaped to
            ``[FLAGS.batch_size, FLAGS.input_size, FLAGS.input_size, 3]``.
        labels: tensor reshaped to ``[FLAGS.batch_size]`` holding class indices.

    Raises:
        ValueError: if no ``.jpg`` images were found under ``FLAGS.data_dir``.
    """
    data = load_data(FLAGS.data_dir)
    if not data:
        # slice_input_producer would otherwise fail with an obscure error.
        raise ValueError("no .jpg images found for data_dir %s" % FLAGS.data_dir)
    filenames = [d['filename'] for d in data]
    label_indexes = [d['label_index'] for d in data]
    # Adds a QueueRunner for this queue to the graph's QUEUE_RUNNER collection.
    filename, label_index = tf.train.slice_input_producer(
        [filenames, label_indexes], shuffle=True)

    height = FLAGS.input_size
    width = FLAGS.input_size
    depth = 3
    num_preprocess_threads = 1  # number of parallel preprocessing branches
    images_and_labels = []
    for thread_id in range(num_preprocess_threads):
        image_buffer = tf.read_file(filename)  # raw bytes of the image file
        bbox = []
        # Not named `train`, which would shadow the imported train() function.
        is_train = True
        image = image_preprocessing(image_buffer, bbox, is_train, thread_id)
        # Source images are not network-sized: resize to 300x300, then crop or
        # pad to the network input size.  Using FLAGS.input_size (instead of a
        # hard-coded 224) keeps this consistent with the reshape below.
        image = tf.image.resize_images(image, [300, 300])
        image = tf.image.resize_image_with_crop_or_pad(image, height, width)
        images_and_labels.append([image, label_index])

    images, label_index_batch = tf.train.batch_join(
        images_and_labels,
        batch_size=FLAGS.batch_size,
        capacity=2 * num_preprocess_threads * FLAGS.batch_size)
    images = tf.cast(images, tf.float32)
    images = tf.reshape(images, shape=[FLAGS.batch_size, height, width, depth])
    return images, tf.reshape(label_index_batch, [FLAGS.batch_size])
def main(_):
    """Entry point for tf.app.run(): build the inputs and network, then train."""
    batch_images, batch_labels = distorted_inputs()
    # NOTE(review): this placeholder is handed to train(), but the network
    # below is built with the constant is_training=True, so feeding it
    # presumably does not change how the model behaves.  Confirm whether
    # inference() accepts a bool tensor before wiring the placeholder through.
    training_flag = tf.placeholder('bool', [], name='is_training')
    # 1000 classes, non-bottleneck blocks, [2, 2, 2, 2] blocks per stage
    # (presumably a ResNet-18 layout).
    net_kwargs = dict(
        num_classes=1000,
        is_training=True,
        bottleneck=False,
        num_blocks=[2, 2, 2, 2],
    )
    logits = inference(batch_images, **net_kwargs)
    train(training_flag, logits, batch_images, batch_labels)


if __name__ == '__main__':
    tf.app.run()
通过对 ResNet 的学习，可以对整个神经网络的组成结构和主要操作有一个大致的了解，这有助于下一步的学习。