【目标检测算法实现系列】Keras实现Faster R-CNN算法(四)

【目标检测算法实现系列】Keras实现Faster R-CNN算法(一)

【目标检测算法实现系列】Keras实现Faster R-CNN算法(二)

【目标检测算法实现系列】Keras实现Faster R-CNN算法(三)

讲过上面几篇文章,实现了Fater RCNN中的所有模块,这次来具体看下训练和测试过程

一、模型训练

from keras_faster_rcnn import config, data_generators, data_augment, losses
from keras_faster_rcnn import  net_model, roi_helper, RoiPoolingConv, voc_data_parser
from keras.optimizers import Adam, SGD, RMSprop
from keras.utils import generic_utils
from keras.layers import Input
from keras.models import Model
from keras import backend as K
import numpy as np
import time
import pprint
import pickle
#获取原始数据集
all_imgs, classes_count, class_mapping = voc_data_parser.get_data("data")
if 'bg' not in classes_count:
    classes_count['bg'] = 0
    class_mapping['bg'] = len(class_mapping)

pprint.pprint(classes_count)
print('类别数 (包含背景) = {}'.format(len(classes_count)))

num_imgs = len(all_imgs)

train_imgs = [s for s in all_imgs if s['imageset'] == 'train']  #训练集
val_imgs = [s for s in all_imgs if s['imageset'] == 'val']  #验证集
test_imgs = [s for s in all_imgs if s['imageset'] == 'test']  #测试集
print('训练样本个数 {}'.format(len(train_imgs)))
print('验证样本个数 {}'.format(len(val_imgs)))
print('测试样本个数 {}'.format(len(test_imgs)))

C = config.Config()  #相关配置信息
C.class_mapping = class_mapping
config_output_filename = "config/config.pickle"
with open(config_output_filename, "wb") as config_f:
    pickle.dump(C, config_f)
    print('Config has been written to {}, and can be loaded when testing to ensure correct results'.format(
        config_output_filename))


#生成用于RPN网络训练数据集的迭代器
data_gen_train = data_generators.get_anchor_data_gt(train_imgs, classes_count, C, mode='train')
data_gen_val = data_generators.get_anchor_data_gt(val_imgs, classes_count, C, mode='val')
data_gen_test = data_generators.get_anchor_data_gt(test_imgs, classes_count, C, mode='val')

img_input = Input(shape=(None, None, 3))  #网络模型最开始的输入
roi_input = Input(shape=(None, 4))   #roi模块的输入

'''
model_rpn : 输入:图片数据;  输出:对应RPN网络中分类层和回归层的两个输出
model_classifier:  输入: 图片数据和选取出来的ROI数据;   输出: 最终分类层输出和回归层输出
'''
# 用来进行特征提取的基础网络 VGG16
shared_layers = net_model.base_net_vgg(img_input)
# RPN网络
num_anchors = len(C.anchor_box_scales) * len(C.anchor_box_ratios)
rpn = net_model.rpn_net(shared_layers, num_anchors)
# 最后的检测网络(包含ROI池化层 和 全连接层)
classifier = net_model.roi_classifier(shared_layers, roi_input, C.num_rois, nb_classes=len(classes_count))

model_rpn = Model(img_input, rpn[:2])
model_classifier = Model([img_input, roi_input], classifier)

#这是一个同时包含RPN和分类器的模型,用于为模型加载/保存权重
model_all = Model([img_input, roi_input], rpn[:2] + classifier)

try:
    print('loading weights from {}'.format(C.model_path))
    model_rpn.load_weights(C.model_path, by_name=True)
    model_classifier.load_weights(C.model_path, by_name=True)

except:
  • 4
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值