R-CNN原理详解与代码超详细讲解(四)——train_predict代码讲解
config代码
# --- Network input geometry (AlexNet expects 227x227 RGB crops) ---
IMAGE_WIDTH = 227
IMAGE_HEIGHT = 227
IMAGE_CHANNEL = 3
# Number of output classes for the fine-tuned classifier head.
CLASS_NUMBER = 3
# --- Input data locations ---
# Pre-trained AlexNet weights in MatConvNet .mat format.
ALEX_NET_MAT_FILE_PATH = "C:/Users/user/Desktop/05_rcnn/AlexNet预加载模型/imagenet-caffe-alex.mat"
# List file describing the fine-tune image dataset.
ORIGINAL_FINE_TUNE_DATA_FILE_PATH = r'C:\Users\user\Desktop\05_rcnn\img_datas\fine_tune_list.txt'
# Cached, preprocessed training data and the class-name -> index mapping.
TRAIN_DATA_FILE_PATH = './datas/traning_data.npy'
TRAIN_LABEL_DICT_FILE_PATH = './datas/label_dict.pkl'
# --- Output locations for the fine-tune stage ---
FINE_TUNE_SUMMARY_WRITER_LOG_DIR = './output/graph/fine_tune'
FINE_TUNE_CHECKPOINT_DIR = './output/models/fine_tune'
FINE_TUNE_CHECKPOINT_FILENAME = 'models.ckpt'
# --- Fine-tune training loop control ---
FINE_TUNE_MAX_STEP = 10000
FINE_TUNE_SUMMARY_STEP = 10     # write a TensorBoard summary every N steps
FINE_TUNE_CHECKPOINT_STEP = 50  # save a checkpoint every N steps
# --- Learning-rate schedule (exponential decay) ---
FINE_TUNE_INITIAL_LEARNING_RATE = 0.001
FINE_TUNE_DECAY_STEPS = 1000
FINE_TUNE_DECAY_RATE = 0.99
# IoU threshold separating positive from negative region proposals.
FINE_TUNE_IOU_THRESHOLD = 0.5
# Per-batch counts of positive / negative proposals during fine-tuning.
FINE_TUNE_POSITIVE_BATCH_SIZE = 8
FINE_TUNE_NEGATIVE_BATCH_SIZE = 24
# --- SVM stage paths; "{}" is filled with the class index ---
TRAIN_SVM_HIGHER_FEATURES_DATA_FILE_PATH = "./datas/svm/higher_features_{}.npy"
SVM_CHECKPOINT_FILE_PATH = "./output/models/svm/model_{}.pkl"
train_predict代码
class SolverType(object):
    """Integer constants naming the solver's operating modes.

    The values are plain ints assigned sequentially 0..9, matching the
    original declaration order, so they remain safe to compare with ``==``.
    """
    (TRAIN_FINE_TUNE_MODEL,
     GENERATE_TRAIN_SVM_FEATURES,
     TRAIN_SVM_MODEL,
     GENERATE_TRAIN_REGRESSION_FEATURES,
     TRAIN_REGRESSION_MODEL,
     PREDICT_BOUNDING_BOX,
     PREDICT_BOUNDING_BOX_STEP1,
     PREDICT_BOUNDING_BOX_STEP2,
     PREDICT_BOUNDING_BOX_STEP3,
     PREDICT_BOUNDING_BOX_STEP4) = range(10)
class Solver(object):
# Orchestrates the R-CNN pipeline stages. The constructor wires up the
# network, data loader and TF session for the requested SolverType, then
# exposes the chosen behaviour as ``self.run``.
def __init__(self, solver_type):
# Mode flags; consumed by AlexNet construction and by the train methods.
self.is_training = False
self.is_svm = False
if SolverType.TRAIN_FINE_TUNE_MODEL == solver_type:
# NOTE(review): the graph context only spans __init__; later session.run
# calls rely on self.session holding a reference to this graph — confirm.
with tf.Graph().as_default():
print("进行Fine Tune模型训练操作....")
self.is_training = True
# AlexNet initialised from the pre-trained .mat weights, in training mode.
self.net = AlexNet(alexnet_mat_file_path=cfg.ALEX_NET_MAT_FILE_PATH,
is_training=self.is_training)
self.data_loader = FlowerDataLoader()
self.__set_fine_tune_config()
# Ensure output directories exist before any writer/saver touches them.
check_directory(self.summary_writer_log_dir)
check_directory(self.checkpoint_dir)
self.__get_or_create_global_step()
self.__create_tf_train_op()
self.__create_tf_saver()
self.__create_tf_summary()
self.__create_tf_session_and_initial()
self.run = self.__fine_tune_train
elif SolverType.GENERATE_TRAIN_SVM_FEATURES == solver_type:
with tf.Graph().as_default():
print("生成SVM训练用高阶特征属性,并持久化磁盘文件....")
# Inference-only network in SVM mode: outputs high-level features.
self.is_training = False
self.is_svm = True
self.net = AlexNet(alexnet_mat_file_path=cfg.ALEX_NET_MAT_FILE_PATH,
is_training=self.is_training, is_svm=self.is_svm)
self.data_loader = FlowerDataLoader()
self.__set_fine_tune_config()
check_directory(self.summary_writer_log_dir)
# A fine-tuned checkpoint must already exist here (error=True if missing).
check_directory(self.checkpoint_dir, created=False, error=True)
self.__get_or_create_global_step()
self.__create_tf_saver()
self.__create_tf_summary()
self.__create_tf_session_and_initial()
self.run = self.__persistent_svm_higher_features
elif SolverType.TRAIN_SVM_MODEL == solver_type:
# SVM training is scikit-style, no TF graph/session needed.
print("进行SVM模型训练操作....")
self.is_svm = True
self.is_training = True
self.net = SVMModel(is_training=self.is_training)
self.run = self.__svm_train
def __set_fine_tune_config(self):
    """Copy all fine-tune stage settings from the config module onto self."""
    # Learning-rate schedule parameters.
    self.initial_learning_rate, self.decay_steps, self.decay_rate = (
        cfg.FINE_TUNE_INITIAL_LEARNING_RATE,
        cfg.FINE_TUNE_DECAY_STEPS,
        cfg.FINE_TUNE_DECAY_RATE,
    )
    # Output locations for TensorBoard summaries and model checkpoints.
    self.summary_writer_log_dir = cfg.FINE_TUNE_SUMMARY_WRITER_LOG_DIR
    self.checkpoint_dir = cfg.FINE_TUNE_CHECKPOINT_DIR
    self.checkpoint_path = os.path.join(
        self.checkpoint_dir, cfg.FINE_TUNE_CHECKPOINT_FILENAME)
    # Training-loop control: total steps plus summary/checkpoint intervals.
    self.max_steps, self.summary_step, self.checkpoint_step = (
        cfg.FINE_TUNE_MAX_STEP,
        cfg.FINE_TUNE_SUMMARY_STEP,
        cfg.FINE_TUNE_CHECKPOINT_STEP,
    )
def __get_or_create_global_step(self):
    """Look up the graph's global-step variable, creating it if absent."""
    step_variable = tf.train.get_or_create_global_step()
    self.global_step = step_variable
def __create_tf_train_op(self):
# Builds the optimisation op: exponentially decayed LR, Adam minimising the
# network's total loss, then an EMA update chained after the optimizer so a
# single run of self.train_op performs both steps (order matters here).
if self.is_training:
with tf.variable_scope("train"):
# LR decays by decay_rate every decay_steps, driven by global_step.
self.learning_rate = tf.train.exponential_decay(
learning_rate=self.initial_learning_rate,
global_step=self.global_step,
decay_steps=self.decay_steps,
decay_rate=self.decay_rate,
name='learning_rate')
tf.summary.scalar('learning_rate', self.learning_rate)
# minimize() also increments global_step on each run.
self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate) \
.minimize(self.net.total_loss, global_step=self.global_step)
# Exponential moving average (decay 0.99) over all trainable variables.
self.ema = tf.train.ExponentialMovingAverage(0.99)
# control_dependencies guarantees the gradient step runs before the
# EMA shadow variables are updated.
with tf.control_dependencies([self.optimizer]):
self.train_op = self.ema.apply(tf.trainable_variables())
def __create_tf_saver(self):
    """Create the Saver used for checkpoint save and restore."""
    saver = tf.train.Saver()
    self.saver = saver
def __create_tf_summary(self):
    """Merge all registered summaries and open a FileWriter on the log dir."""
    self.summary = tf.summary.merge_all()
    default_graph = tf.get_default_graph()
    self.writer = tf.summary.FileWriter(self.summary_writer_log_dir,
                                        graph=default_graph)
def __create_tf_session_and_initial(self):
    """Open a session, initialise variables, restore the latest checkpoint if any."""
    self.session = tf.Session()
    self.session.run(tf.global_variables_initializer())
    ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir)
    # Nothing to restore when no checkpoint has been written yet.
    if not (ckpt and ckpt.model_checkpoint_path):
        return
    print("进行模型恢复操作...")
    # Restore weights and let the saver know about existing checkpoint files
    # so its max_to_keep rotation continues correctly.
    self.saver.restore(self.session, ckpt.model_checkpoint_path)
    self.saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
def __svm_train(self):
    """Train the SVM model by delegating to the underlying SVMModel.

    :return: None
    """
    self.net.train()
def __fine_tune_train(self):
    """Run the fine-tune training loop.

    Every ``summary_step`` steps a summary record plus loss/accuracy is
    logged; every ``checkpoint_step`` steps the model is checkpointed.
    Raises when the solver was not built in training mode.
    """
    if not self.is_training:
        raise Exception("Train method request set 'is_training' parameter is True.")
    # Resume counting from the persisted global step so restarts continue.
    first_step = self.session.run(self.global_step)
    for step in range(first_step, first_step + self.max_steps):
        batch_images, batch_labels = self.data_loader.get_fine_tune_batch()
        feed = {self.net.input_data: batch_images, self.net.label: batch_labels}
        if step % self.summary_step != 0:
            # Plain optimisation step without logging.
            self.session.run(self.train_op, feed_dict=feed)
        else:
            # Logged step: fetch loss/accuracy and write a summary record too.
            fetches = [self.summary, self.net.total_loss,
                       self.net.accuracy, self.train_op]
            summary_, loss_, accuracy_, _ = self.session.run(fetches, feed_dict=feed)
            self.writer.add_summary(summary_, global_step=step)
            print("Training Step:{}, Loss:{}, Accuracy:{}".format(step, loss_, accuracy_))
        if step % self.checkpoint_step == 0:
            print("Saving model to {}".format(self.checkpoint_dir))
            self.saver.save(sess=self.session, save_path=self.checkpoint_path,
                            global_step=step)
def __fine_tune_predict(self, images):
    """Run a forward pass and return the fine-tuned model's logits.

    :param images: batch of preprocessed input images
    :return: logits tensor values for the batch
    """
    feed = {self.net.input_data: images}
    return self.session.run(self.net.logits, feed_dict=feed)
def __persistent_svm_higher_features(self):
    """Persist the high-level CNN features used to train the per-class SVMs.

    One SVM is trained per class, so a separate ``.npy`` feature file is
    produced for every class index in the label dictionary. Each row of the
    saved array is the feature vector with the label appended as last column.

    Fixes vs. original: the label-dict file handle is now closed via a
    ``with`` block (previous ``pickle.load(open(...))`` leaked it), and the
    dead ``X = None`` / ``Y = None`` pre-assignments were removed.
    :return: None
    """
    check_directory(cfg.TRAIN_LABEL_DICT_FILE_PATH, created=False, error=True)
    # Mapping: class name -> integer class index.
    with open(cfg.TRAIN_LABEL_DICT_FILE_PATH, 'rb') as label_dict_file:
        class_name_2_index_dict = pickle.load(label_dict_file)
    for class_name, index in class_name_2_index_dict.items():
        print("Start process type '{}/{}' datas...".format(index, class_name))
        images, labels = self.data_loader.get_structure_higher_features(label=index)
        if images is None or labels is None:
            print("没办法获取标签:{}对应的数据集!!!".format(index))
            continue
        print(np.shape(images), np.shape(labels))
        # Forward pass through the fine-tuned network -> feature vectors.
        higher_features = self.__fine_tune_predict(images)
        print("Final Feature Attribute Structure:{} - {}".format(
            np.shape(higher_features), np.shape(labels)))
        print("Number of occurrences of each category:{}".format(
            collections.Counter(labels)))
        # Append the label as the last column so features+label are stored together.
        data = np.concatenate((higher_features, np.reshape(labels, (-1, 1))), axis=1)
        svm_higher_features_save_path = \
            cfg.TRAIN_SVM_HIGHER_FEATURES_DATA_FILE_PATH.format(index)
        check_directory(os.path.dirname(svm_higher_features_save_path))
        np.save(svm_higher_features_save_path, data)
def run_solver():
    """Build the solver selected by the local ``flag`` and run it.

    flag 0 -> fine-tune training, 1 -> SVM feature generation,
    2 -> SVM training.

    Fix vs. original: an unsupported ``flag`` previously fell through the
    if/elif chain and crashed with an unbound-``solver`` NameError; it now
    raises an explicit ValueError instead.
    """
    flag = 2
    # Dispatch table: flag value -> SolverType mode.
    flag_to_solver_type = {
        0: SolverType.TRAIN_FINE_TUNE_MODEL,
        1: SolverType.GENERATE_TRAIN_SVM_FEATURES,
        2: SolverType.TRAIN_SVM_MODEL,
    }
    if flag not in flag_to_solver_type:
        raise ValueError("Unsupported flag value: {}".format(flag))
    solver = Solver(flag_to_solver_type[flag])
    solver.run()


if __name__ == '__main__':
    run_solver()