【目标检测算法实现系列】Keras实现Faster R-CNN算法(一)
【目标检测算法实现系列】Keras实现Faster R-CNN算法(二)
【目标检测算法实现系列】Keras实现Faster R-CNN算法(三)
讲过上面几篇文章,实现了Fater RCNN中的所有模块,这次来具体看下训练和测试过程
一、模型训练
from keras_faster_rcnn import config, data_generators, data_augment, losses
from keras_faster_rcnn import net_model, roi_helper, RoiPoolingConv, voc_data_parser
from keras.optimizers import Adam, SGD, RMSprop
from keras.utils import generic_utils
from keras.layers import Input
from keras.models import Model
from keras import backend as K
import numpy as np
import time
import pprint
import pickle
#获取原始数据集
all_imgs, classes_count, class_mapping = voc_data_parser.get_data("data")
if 'bg' not in classes_count:
classes_count['bg'] = 0
class_mapping['bg'] = len(class_mapping)
pprint.pprint(classes_count)
print('类别数 (包含背景) = {}'.format(len(classes_count)))
num_imgs = len(all_imgs)
train_imgs = [s for s in all_imgs if s['imageset'] == 'train'] #训练集
val_imgs = [s for s in all_imgs if s['imageset'] == 'val'] #验证集
test_imgs = [s for s in all_imgs if s['imageset'] == 'test'] #测试集
print('训练样本个数 {}'.format(len(train_imgs)))
print('验证样本个数 {}'.format(len(val_imgs)))
print('测试样本个数 {}'.format(len(test_imgs)))
C = config.Config() #相关配置信息
C.class_mapping = class_mapping
config_output_filename = "config/config.pickle"
with open(config_output_filename, "wb") as config_f:
pickle.dump(C, config_f)
print('Config has been written to {}, and can be loaded when testing to ensure correct results'.format(
config_output_filename))
#生成用于RPN网络训练数据集的迭代器
data_gen_train = data_generators.get_anchor_data_gt(train_imgs, classes_count, C, mode='train')
data_gen_val = data_generators.get_anchor_data_gt(val_imgs, classes_count, C, mode='val')
data_gen_test = data_generators.get_anchor_data_gt(test_imgs, classes_count, C, mode='val')
img_input = Input(shape=(None, None, 3)) #网络模型最开始的输入
roi_input = Input(shape=(None, 4)) #roi模块的输入
'''
model_rpn : 输入:图片数据; 输出:对应RPN网络中分类层和回归层的两个输出
model_classifier: 输入: 图片数据和选取出来的ROI数据; 输出: 最终分类层输出和回归层输出
'''
# 用来进行特征提取的基础网络 VGG16
shared_layers = net_model.base_net_vgg(img_input)
# RPN网络
num_anchors = len(C.anchor_box_scales) * len(C.anchor_box_ratios)
rpn = net_model.rpn_net(shared_layers, num_anchors)
# 最后的检测网络(包含ROI池化层 和 全连接层)
classifier = net_model.roi_classifier(shared_layers, roi_input, C.num_rois, nb_classes=len(classes_count))
model_rpn = Model(img_input, rpn[:2])
model_classifier = Model([img_input, roi_input], classifier)
#这是一个同时包含RPN和分类器的模型,用于为模型加载/保存权重
model_all = Model([img_input, roi_input], rpn[:2] + classifier)
try:
print('loading weights from {}'.format(C.model_path))
model_rpn.load_weights(C.model_path, by_name=True)
model_classifier.load_weights(C.model_path, by_name=True)
except: