谢谢作者指导,感谢作者提供这么流畅的代码。@yangxue0827
我的训练结果有点雷人,可能是我的训练方式不对。
我只针对airplane进行训练,但是这里的结果不是很优美。如下:
我在voc_eval.py添加了一个参数FAST_RCNN_IOU_MAP。
当FAST_RCNN_IOU_MAP=.01,会给一个安慰结果。
当FAST_RCNN_IOU_MAP=.5,无语了
预测图:
我的config:
这里会不会还有其他问题?代码如下:
#!/usr/bin/python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import argparse
import os
import pickle
import sys
import time
import io
import numpy as np
import tensorflow as tf
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
sys.stdout = io.TextIOWrapper(sys.stdout.buffer,encoding='iso-8859-1')
from data.io.image_preprocess import short_side_resize_for_inference_data
from data.io.read_tfrecord import next_batch
from data.io.divide_data import mkdir
from help_utils.tools import *
from libs.fast_rcnn import build_fast_rcnn
from libs.label_name_dict.label_dict import *
from libs.box_utils import draw_box_in_img
from libs.networks.network_factory import get_flags_byname, get_network_byname
from libs.rpn import build_rpn
from libs.val_libs import voc_eval
from tools import restore_model
import numpy as np
import cv2
# Flags looked up by the configured network name (cfgs.NET_NAME).
# NOTE(review): `cfgs` is not imported explicitly in this file; presumably it
# comes in through one of the star imports above -- confirm.
FLAGS = get_flags_byname(cfgs.NET_NAME)
def eval_dict_convert(args, draw_imgs=True):
    """Run the detector over the evaluation set and dump the detections.

    Builds the inference graph (shared backbone, RPN, Fast R-CNN head),
    restores the checkpoint given by ``args.weights`` and runs
    ``args.img_num`` batches through it. For every image the predicted boxes
    are rescaled from the resized network input back to the size of the raw
    image read from ``args.eval_imgs``, and one
    ``[category, score, xmin, ymin, xmax, ymax]`` row per detection is
    collected. The collected arrays are pickled to ``detections.pkl`` in the
    current working directory.

    The loop reads element ``[0]`` of every batch and squeezes axis 0 of the
    image batch, so it only works with a batch size of 1.

    Args:
        args: Parsed command-line namespace; ``weights``, ``img_num`` and
            ``eval_imgs`` are read here.
        draw_imgs: When true, images with at least one detection are written
            to ``./eval/images/`` twice: once with the predicted boxes and
            once (suffix ``gt.jpg``) with the ground-truth boxes.

    Returns:
        The image names exactly as produced by the input pipeline (not
        decoded), in the same order as the arrays stored in
        ``detections.pkl``.

    Raises:
        IOError: If the raw image for a batch cannot be read from
            ``args.eval_imgs``.
    """
    with tf.Graph().as_default():
        # Input pipeline: image names, resized images and ground-truth boxes.
        img_name_batch, img_batch, gtboxes_and_label_batch, num_objects_batch = \
            next_batch(dataset_name=cfgs.DATASET_NAME,
                       batch_size=cfgs.BATCH_SIZE,
                       shortside_len=cfgs.SHORT_SIDE_LEN,
                       is_training=False)

        # ***********************************************************************************************
        # *                                         share net                                           *
        # ***********************************************************************************************
        # NOTE(review): the backbone is built with is_training=True although
        # this is an evaluation script (RPN and Fast R-CNN use False). Left
        # unchanged because it may have to match how the checkpoint was
        # trained -- confirm against the training script.
        _, share_net = get_network_byname(net_name=cfgs.NET_NAME,
                                          inputs=img_batch,
                                          num_classes=None,
                                          is_training=True,
                                          output_stride=None,
                                          global_pool=False,
                                          spatial_squeeze=False)

        # ***********************************************************************************************
        # *                                            RPN                                              *
        # ***********************************************************************************************
        rpn = build_rpn.RPN(net_name=cfgs.NET_NAME,
                            inputs=img_batch,
                            gtboxes_and_label=None,
                            is_training=False,
                            share_head=True,
                            share_net=share_net,
                            stride=cfgs.STRIDE,
                            anchor_ratios=cfgs.ANCHOR_RATIOS,
                            anchor_scales=cfgs.ANCHOR_SCALES,
                            scale_factors=cfgs.SCALE_FACTORS,
                            base_anchor_size_list=cfgs.BASE_ANCHOR_SIZE_LIST,  # P2, P3, P4, P5, P6
                            level=cfgs.LEVEL,
                            top_k_nms=cfgs.RPN_TOP_K_NMS,
                            rpn_nms_iou_threshold=cfgs.RPN_NMS_IOU_THRESHOLD,
                            max_proposals_num=cfgs.MAX_PROPOSAL_NUM,
                            rpn_iou_positive_threshold=cfgs.RPN_IOU_POSITIVE_THRESHOLD,
                            rpn_iou_negative_threshold=cfgs.RPN_IOU_NEGATIVE_THRESHOLD,
                            rpn_mini_batch_size=cfgs.RPN_MINIBATCH_SIZE,
                            rpn_positives_ratio=cfgs.RPN_POSITIVE_RATE,
                            remove_outside_anchors=False,  # whether remove anchors outside
                            rpn_weight_decay=cfgs.WEIGHT_DECAY[cfgs.NET_NAME])

        # rpn predict proposals
        rpn_proposals_boxes, rpn_proposals_scores = rpn.rpn_proposals()  # rpn_score shape: [300, ]

        # ***********************************************************************************************
        # *                                         Fast RCNN                                           *
        # ***********************************************************************************************
        fast_rcnn = build_fast_rcnn.FastRCNN(img_batch=img_batch,
                                             feature_pyramid=rpn.feature_pyramid,
                                             rpn_proposals_boxes=rpn_proposals_boxes,
                                             rpn_proposals_scores=rpn_proposals_scores,
                                             img_shape=tf.shape(img_batch),
                                             roi_size=cfgs.ROI_SIZE,
                                             scale_factors=cfgs.SCALE_FACTORS,
                                             roi_pool_kernel_size=cfgs.ROI_POOL_KERNEL_SIZE,
                                             gtboxes_and_label=None,
                                             fast_rcnn_nms_iou_threshold=cfgs.FAST_RCNN_NMS_IOU_THRESHOLD,
                                             fast_rcnn_maximum_boxes_per_img=100,
                                             fast_rcnn_nms_max_boxes_per_class=cfgs.FAST_RCNN_NMS_MAX_BOXES_PER_CLASS,
                                             show_detections_score_threshold=cfgs.FINAL_SCORE_THRESHOLD,  # show detections which score >= 0.6
                                             num_classes=cfgs.CLASS_NUM,
                                             fast_rcnn_minibatch_size=cfgs.FAST_RCNN_MINIBATCH_SIZE,
                                             fast_rcnn_positives_ratio=cfgs.FAST_RCNN_POSITIVE_RATE,
                                             fast_rcnn_positives_iou_threshold=cfgs.FAST_RCNN_IOU_POSITIVE_THRESHOLD,
                                             use_dropout=False,
                                             weight_decay=cfgs.WEIGHT_DECAY[cfgs.NET_NAME],
                                             is_training=False,
                                             level=cfgs.LEVEL)

        fast_rcnn_decode_boxes, fast_rcnn_score, num_of_objects, detection_category = \
            fast_rcnn.fast_rcnn_predict()

        # Variables are initialised first and then overwritten by the
        # checkpoint, when a restorer is available.
        init_op = tf.group(
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        )

        restorer, restore_ckpt = restore_model.get_restorer(checkpoint_path=args.weights)

        config = tf.ConfigProto()
        # config.gpu_options.per_process_gpu_memory_fraction = 0.5
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:
            sess.run(init_op)
            if restorer is not None:
                restorer.restore(sess, restore_ckpt)
                print('restore model')

            if draw_imgs:
                # cv2.imwrite does not create missing directories; without
                # this the visualisations can be dropped silently.
                os.makedirs('./eval/images/', exist_ok=True)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess, coord)

            img_name_batchs = []
            all_boxes = []
            try:
                for i in range(args.img_num):
                    start = time.time()
                    (_img_name_batch, _img_batch, _gtboxes_and_label_batch,
                     _fast_rcnn_decode_boxes, _fast_rcnn_score, _detection_category) = sess.run(
                        [img_name_batch, img_batch, gtboxes_and_label_batch,
                         fast_rcnn_decode_boxes, fast_rcnn_score, detection_category])
                    end = time.time()

                    img_name = _img_name_batch[0].decode('utf-8', 'ignore')

                    # The raw image is only needed for its original size.
                    raw_img_path = os.path.join(args.eval_imgs, img_name)
                    raw_img = cv2.imread(raw_img_path)
                    if raw_img is None:
                        # cv2.imread signals failure by returning None.
                        raise IOError('failed to read image: {}'.format(raw_img_path))
                    raw_h, raw_w = raw_img.shape[0], raw_img.shape[1]
                    resized_h, resized_w = _img_batch.shape[1], _img_batch.shape[2]

                    # Box columns are unpacked as ymin, xmin, ymax, xmax and
                    # rescaled from the resized input to raw-image coordinates.
                    ymin, xmin, ymax, xmax = _fast_rcnn_decode_boxes[:, 0], _fast_rcnn_decode_boxes[:, 1], \
                        _fast_rcnn_decode_boxes[:, 2], _fast_rcnn_decode_boxes[:, 3]
                    xmin = xmin * raw_w / resized_w
                    xmax = xmax * raw_w / resized_w
                    ymin = ymin * raw_h / resized_h
                    ymax = ymax * raw_h / resized_h

                    if draw_imgs and len(_detection_category) != 0:
                        show_indices = _fast_rcnn_score >= cfgs.FINAL_SCORE_THRESHOLD
                        show_scores = _fast_rcnn_score[show_indices]
                        show_boxes = _fast_rcnn_decode_boxes[show_indices]
                        show_categories = _detection_category[show_indices]

                        # Boxes are taken from the same filtered set as the
                        # labels and scores so that the three stay aligned,
                        # then reordered to [xmin, ymin, xmax, ymax].
                        pre_box = show_boxes[:, [1, 0, 3, 2]]
                        final_detections = draw_box_in_img.draw_boxes_with_label_and_scores(
                            np.squeeze(_img_batch, 0),
                            boxes=pre_box,
                            labels=show_categories,
                            scores=show_scores)

                        # Ground truth of the single image in the batch; the
                        # last column is used as the label.
                        gt_boxes_and_label = _gtboxes_and_label_batch[0]
                        gt_box = gt_boxes_and_label[:, [1, 0, 3, 2]]
                        final_detectionsgt = draw_box_in_img.draw_boxes_with_label_and_scores(
                            np.squeeze(_img_batch, 0),
                            boxes=gt_box,
                            labels=gt_boxes_and_label[:, -1],
                            # One dummy score per ground-truth box.
                            scores=[1] * len(gt_boxes_and_label))

                        # The channel order is reversed before writing.
                        cv2.imwrite('./eval/images/' + img_name,
                                    final_detections[:, :, ::-1])
                        cv2.imwrite('./eval/images/' + img_name.split('.jpg')[0] + 'gt.jpg',
                                    final_detectionsgt[:, :, ::-1])

                    # Recorded for every image, including those without any
                    # detection, so that all_boxes and img_name_batchs stay
                    # index-aligned.
                    boxes = np.transpose(np.stack([xmin, ymin, xmax, ymax]))
                    dets = np.hstack((_detection_category.reshape(-1, 1),
                                      _fast_rcnn_score.reshape(-1, 1),
                                      boxes))
                    img_name_batchs.append(_img_name_batch[0])
                    all_boxes.append(dets)

                    view_bar('{} image cost {}s'.format(
                        str(_img_name_batch[0]), (end - start)), i + 1, args.img_num)
            finally:
                # Always stop the input queue threads, even if a batch failed.
                coord.request_stop()
                coord.join(threads)

            with open('detections.pkl', 'wb') as fw1:
                pickle.dump(all_boxes, fw1)

    return img_name_batchs
def eval(args):
    """Run detection over the evaluation set and score it against annotations.

    Calls ``eval_dict_convert`` to produce ``detections.pkl`` and the list of
    evaluated image names, reloads the pickled detections and passes both to
    ``voc_eval.voc_evaluate_detections`` together with
    ``args.annotation_dir``.

    Note that this function shadows the ``eval`` builtin within this module.

    Args:
        args: Parsed command-line namespace (see ``parse_args``).
    """
    print('Called with args:')

    raw_image_names = eval_dict_convert(args)
    # The names come back undecoded; undecodable bytes are dropped.
    image_ids = [raw_name.decode('utf-8', 'ignore') for raw_name in raw_image_names]

    # detections.pkl was written by eval_dict_convert in this same run.
    with open('detections.pkl', 'rb') as pkl_file:
        detections = pickle.load(pkl_file, encoding='iso-8859-1')

    voc_eval.voc_evaluate_detections(
        all_boxes=detections,
        test_annotation_path=args.annotation_dir,
        test_imgid_list=image_ids,
    )
def parse_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        The ``argparse`` namespace with ``eval_imgs``, ``annotation_dir``,
        ``weights`` (all strings) and ``img_num`` (int).
    """
    # (destination, help text, default value, value type), in CLI order.
    option_table = (
        ('eval_imgs', 'evaluate imgs dir ',
         '../data/io/VOCdevkit_test/JPEGImages', str),
        ('annotation_dir', 'the dir save annotations',
         '../data/io/VOCdevkit_test/Annotations', str),
        ('weights', 'model path',
         '../output/airplane/res101_trained_weights/v3_airplane/airplane_115500model.ckpt', str),
        ('img_num', 'image numbers', 300, int),
    )

    cli = argparse.ArgumentParser(description='Evaluate a trained FPN model')
    for dest, help_text, default_value, value_type in option_table:
        cli.add_argument('--' + dest, dest=dest, help=help_text,
                         default=default_value, type=value_type)
    return cli.parse_args()
if __name__ == "__main__":
    # Script entry point: parse the command line, then run the evaluation.
    eval(parse_args())