如何查询mysql的next_val_您好,请问训练过后如何在benchmark数据库上检测训练结果的准确率?...

该博客主要展示了如何使用深度学习模型Fast R-CNN进行目标检测,并对VOC数据集进行评估。作者提供了代码片段,用于加载模型、进行预测并可视化检测结果。在不同IoU阈值下,模型的表现有所差异,可能存在优化空间。
摘要由CSDN通过智能技术生成

谢谢作者指导,感谢作者提供这么流畅的代码。@yangxue0827

我的训练结果有点雷人。可能是我的训练方式不对。

我只针对airplane进行训练,但是这里的结果不是很优美。如下:

我在voc_eval.py添加了一个参数FAST_RCNN_IOU_MAP。

bba04e7b392dd77d8d55a24745e42e08.png

当FAST_RCNN_IOU_MAP=.01,会给一个安慰结果。

75ae4c555f5028ddedf09419868cc888.png

当FAST_RCNN_IOU_MAP=.5,无语了

bbd3b01ac373c88fbd9725ea00d1dd8d.png

预测图:

131e0b24f1d0d4c3497a6ef8fc396bb9.png

我的config:

434ae6b0baa0da98194f92eacc804475.png

这里会不会还有其他问题?代码如下:

#!/usr/bin/python

# -*- coding: utf-8 -*-

from __future__ import absolute_import, division, print_function

import argparse

import os

import pickle

import sys

import time

import io

import numpy as np

import tensorflow as tf

# Put the parent directory of this file first on the import path.
# Presumably this lets the project packages imported below (data, help_utils,
# libs, tools) resolve when the script is run directly -- verify.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

# Replace stdout with a text wrapper that encodes everything as ISO-8859-1.
# NOTE(review): no `errors=` argument is given, so printing a character outside
# Latin-1 should raise UnicodeEncodeError -- confirm this is intended.
sys.stdout = io.TextIOWrapper(sys.stdout.buffer,encoding='iso-8859-1')

from data.io.image_preprocess import short_side_resize_for_inference_data

from data.io.read_tfrecord import next_batch

from data.io.divide_data import mkdir

from help_utils.tools import *

from libs.fast_rcnn import build_fast_rcnn

from libs.label_name_dict.label_dict import *

from libs.box_utils import draw_box_in_img

from libs.networks.network_factory import get_flags_byname, get_network_byname

from libs.rpn import build_rpn

from libs.val_libs import voc_eval

from tools import restore_model

import numpy as np

import cv2

# Flags looked up by the configured network name. `cfgs` is not imported
# explicitly in this file; presumably it arrives via one of the star imports
# above -- verify. FLAGS is not referenced again in the visible code.
FLAGS = get_flags_byname(cfgs.NET_NAME)

def _save_detection_images(img_name, img, decode_boxes, scores, categories,
                           gtboxes_and_label):
    """Write the prediction overlay and the ground-truth overlay for one image.

    Args:
        img_name: Decoded file name of the image (used for the output names).
        img: The network input image with the batch axis already removed.
        decode_boxes: Predicted boxes, columns ordered
            [ymin, xmin, ymax, xmax] (the order the caller unpacks them in).
        scores: Per-box detection scores, aligned with ``decode_boxes``.
        categories: Per-box category ids, aligned with ``decode_boxes``.
        gtboxes_and_label: Ground-truth array for this image; the first four
            columns are read as box coordinates and the last one as the label.
    """
    output_dir = './eval/images/'
    # cv2.imwrite does not create missing directories, so make sure it exists.
    os.makedirs(output_dir, exist_ok=True)

    # Only draw detections that reach the configured score threshold. Boxes,
    # labels and scores must all be filtered with the same mask so that they
    # stay aligned row by row.
    show_indices = scores >= cfgs.FINAL_SCORE_THRESHOLD
    show_scores = scores[show_indices]
    show_boxes = decode_boxes[show_indices]
    show_categories = categories[show_indices]

    # Reorder the columns to [xmin, ymin, xmax, ymax] before drawing.
    # NOTE(review): assumes draw_boxes_with_label_and_scores expects this
    # order -- confirm against libs.box_utils.draw_box_in_img.
    pre_box = np.stack([show_boxes[:, 1], show_boxes[:, 0],
                        show_boxes[:, 3], show_boxes[:, 2]], axis=1)
    final_detections = draw_box_in_img.draw_boxes_with_label_and_scores(
        img,
        boxes=pre_box,
        labels=show_categories,
        scores=show_scores)

    gt_box = np.stack([gtboxes_and_label[:, 1], gtboxes_and_label[:, 0],
                       gtboxes_and_label[:, 3], gtboxes_and_label[:, 2]], axis=1)
    final_detectionsgt = draw_box_in_img.draw_boxes_with_label_and_scores(
        img,
        boxes=gt_box,
        labels=gtboxes_and_label[:, -1],
        # One dummy score per ground-truth box (not per batch element).
        scores=[1] * len(gtboxes_and_label))

    # The channel axis is reversed before writing; presumably RGB -> BGR for
    # OpenCV -- verify against the drawing helper's output format.
    cv2.imwrite(output_dir + img_name, final_detections[:, :, ::-1])
    cv2.imwrite(output_dir + img_name.split('.jpg')[0] + 'gt.jpg',
                final_detectionsgt[:, :, ::-1])


def eval_dict_convert(args, draw_imgs=True):
    """Run the detector over the evaluation set and dump its detections.

    Builds the share net, RPN and Fast R-CNN inference graph, restores the
    checkpoint given by ``args.weights`` and runs ``args.img_num`` batches.
    For every image the predicted boxes are rescaled from the resized network
    input back to the size of the raw image read from ``args.eval_imgs``.

    Side effects: writes ``detections.pkl`` in the current directory and, when
    ``draw_imgs`` is true, overlay images under ``./eval/images/``.

    Args:
        args: Namespace providing ``weights``, ``img_num`` and ``eval_imgs``.
        draw_imgs: Whether to save prediction / ground-truth overlays.

    Returns:
        The list of image names, one per evaluated image, exactly as returned
        by the input pipeline (the caller decodes them from bytes).

    Raises:
        IOError: If a raw image cannot be read from ``args.eval_imgs``.
    """
    with tf.Graph().as_default():
        # This is where the images are read.
        img_name_batch, img_batch, gtboxes_and_label_batch, num_objects_batch = \
            next_batch(dataset_name=cfgs.DATASET_NAME,
                       batch_size=cfgs.BATCH_SIZE,
                       shortside_len=cfgs.SHORT_SIDE_LEN,
                       is_training=False)

        # ***********************************************************************************************
        # *                                         share net                                           *
        # ***********************************************************************************************
        # NOTE(review): is_training=True here while RPN / Fast R-CNN below use
        # is_training=False. Left unchanged; confirm this is intended for
        # evaluation (it may affect batch-norm behaviour in the backbone).
        _, share_net = get_network_byname(net_name=cfgs.NET_NAME,
                                          inputs=img_batch,
                                          num_classes=None,
                                          is_training=True,
                                          output_stride=None,
                                          global_pool=False,
                                          spatial_squeeze=False)

        # ***********************************************************************************************
        # *                                            RPN                                              *
        # ***********************************************************************************************
        rpn = build_rpn.RPN(net_name=cfgs.NET_NAME,
                            inputs=img_batch,
                            gtboxes_and_label=None,
                            is_training=False,
                            share_head=True,
                            share_net=share_net,
                            stride=cfgs.STRIDE,
                            anchor_ratios=cfgs.ANCHOR_RATIOS,
                            anchor_scales=cfgs.ANCHOR_SCALES,
                            scale_factors=cfgs.SCALE_FACTORS,
                            base_anchor_size_list=cfgs.BASE_ANCHOR_SIZE_LIST,  # P2, P3, P4, P5, P6
                            level=cfgs.LEVEL,
                            top_k_nms=cfgs.RPN_TOP_K_NMS,
                            rpn_nms_iou_threshold=cfgs.RPN_NMS_IOU_THRESHOLD,
                            max_proposals_num=cfgs.MAX_PROPOSAL_NUM,
                            rpn_iou_positive_threshold=cfgs.RPN_IOU_POSITIVE_THRESHOLD,
                            rpn_iou_negative_threshold=cfgs.RPN_IOU_NEGATIVE_THRESHOLD,
                            rpn_mini_batch_size=cfgs.RPN_MINIBATCH_SIZE,
                            rpn_positives_ratio=cfgs.RPN_POSITIVE_RATE,
                            remove_outside_anchors=False,  # whether remove anchors outside
                            rpn_weight_decay=cfgs.WEIGHT_DECAY[cfgs.NET_NAME])

        # rpn predict proposals
        rpn_proposals_boxes, rpn_proposals_scores = rpn.rpn_proposals()  # rpn_score shape: [300, ]

        # ***********************************************************************************************
        # *                                         Fast RCNN                                           *
        # ***********************************************************************************************
        fast_rcnn = build_fast_rcnn.FastRCNN(img_batch=img_batch,
                                             feature_pyramid=rpn.feature_pyramid,
                                             rpn_proposals_boxes=rpn_proposals_boxes,
                                             rpn_proposals_scores=rpn_proposals_scores,
                                             img_shape=tf.shape(img_batch),
                                             roi_size=cfgs.ROI_SIZE,
                                             scale_factors=cfgs.SCALE_FACTORS,
                                             roi_pool_kernel_size=cfgs.ROI_POOL_KERNEL_SIZE,
                                             gtboxes_and_label=None,
                                             fast_rcnn_nms_iou_threshold=cfgs.FAST_RCNN_NMS_IOU_THRESHOLD,
                                             fast_rcnn_maximum_boxes_per_img=100,
                                             fast_rcnn_nms_max_boxes_per_class=cfgs.FAST_RCNN_NMS_MAX_BOXES_PER_CLASS,
                                             show_detections_score_threshold=cfgs.FINAL_SCORE_THRESHOLD,  # show detections which score >= 0.6
                                             num_classes=cfgs.CLASS_NUM,
                                             fast_rcnn_minibatch_size=cfgs.FAST_RCNN_MINIBATCH_SIZE,
                                             fast_rcnn_positives_ratio=cfgs.FAST_RCNN_POSITIVE_RATE,
                                             fast_rcnn_positives_iou_threshold=cfgs.FAST_RCNN_IOU_POSITIVE_THRESHOLD,
                                             use_dropout=False,
                                             weight_decay=cfgs.WEIGHT_DECAY[cfgs.NET_NAME],
                                             is_training=False,
                                             level=cfgs.LEVEL)

        fast_rcnn_decode_boxes, fast_rcnn_score, num_of_objects, detection_category = \
            fast_rcnn.fast_rcnn_predict()

        init_op = tf.group(
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        )

        restorer, restore_ckpt = restore_model.get_restorer(checkpoint_path=args.weights)

        config = tf.ConfigProto()
        # config.gpu_options.per_process_gpu_memory_fraction = 0.5
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:
            sess.run(init_op)
            if restorer is not None:
                restorer.restore(sess, restore_ckpt)
                print('restore model')

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess, coord)

            img_name_batchs = []
            all_boxes = []

            try:
                for i in range(args.img_num):
                    start = time.time()
                    (_img_name_batch, _img_batch, _gtboxes_and_label_batch,
                     _fast_rcnn_decode_boxes, _fast_rcnn_score, _detection_category) = sess.run(
                        [img_name_batch, img_batch, gtboxes_and_label_batch,
                         fast_rcnn_decode_boxes, fast_rcnn_score, detection_category])
                    end = time.time()

                    img_name = _img_name_batch[0].decode('utf-8', 'ignore')
                    raw_img_path = os.path.join(args.eval_imgs, img_name)
                    raw_img = cv2.imread(raw_img_path)
                    if raw_img is None:
                        # cv2.imread signals failure by returning None.
                        raise IOError('cannot read image: {}'.format(raw_img_path))
                    raw_h, raw_w = raw_img.shape[0], raw_img.shape[1]

                    # Map the boxes from the resized network input back onto
                    # the raw image so they are comparable with annotations.
                    ymin, xmin, ymax, xmax = _fast_rcnn_decode_boxes[:, 0], _fast_rcnn_decode_boxes[:, 1], \
                        _fast_rcnn_decode_boxes[:, 2], _fast_rcnn_decode_boxes[:, 3]
                    resized_h, resized_w = _img_batch.shape[1], _img_batch.shape[2]

                    xmin = xmin * raw_w / resized_w
                    xmax = xmax * raw_w / resized_w
                    ymin = ymin * raw_h / resized_h
                    ymax = ymax * raw_h / resized_h

                    if len(_detection_category) != 0 and draw_imgs:
                        _save_detection_images(img_name,
                                               np.squeeze(_img_batch, 0),
                                               _fast_rcnn_decode_boxes,
                                               _fast_rcnn_score,
                                               _detection_category,
                                               _gtboxes_and_label_batch[0])

                    # NOTE(review): the original indentation was lost, so it is
                    # unclear whether the lines below sat inside the
                    # "has detections" guard. They are kept outside so that
                    # images with no detections are still handed to the
                    # evaluator -- confirm.
                    # Each row: [category, score, xmin, ymin, xmax, ymax].
                    boxes = np.transpose(np.stack([xmin, ymin, xmax, ymax]))
                    dets = np.hstack((_detection_category.reshape(-1, 1),
                                      _fast_rcnn_score.reshape(-1, 1),
                                      boxes))

                    img_name_batchs.append(_img_name_batch[0])
                    all_boxes.append(dets)

                    view_bar('{} image cost {}s'.format(
                        str(_img_name_batch[0]), (end - start)), i + 1, args.img_num)

                with open('detections.pkl', 'wb') as fw1:
                    pickle.dump(all_boxes, fw1)
            finally:
                # Always stop the input queue threads, even if a batch failed.
                coord.request_stop()
                coord.join(threads)

        return img_name_batchs

def eval(args):
    """Run detection over the evaluation set, then score it with voc_eval.

    Calls ``eval_dict_convert`` to produce ``detections.pkl`` and the list of
    evaluated image names, reloads the pickled detections and passes both to
    ``voc_eval.voc_evaluate_detections`` together with ``args.annotation_dir``.

    Args:
        args: Parsed command-line namespace (see ``parse_args``).
    """
    print('Called with args:')

    raw_names = eval_dict_convert(args)

    # The names come back as bytes; decode them, dropping undecodable bytes.
    image_ids = []
    for raw_name in raw_names:
        image_ids.append(raw_name.decode('utf-8', 'ignore'))

    with open('detections.pkl', 'rb') as pkl_file:
        detections = pickle.load(pkl_file, encoding='iso-8859-1')

    voc_eval.voc_evaluate_detections(
        all_boxes=detections,
        test_annotation_path=args.annotation_dir,
        test_imgid_list=image_ids,
    )

def parse_args(argv=None):
    """Parse the command-line arguments of the evaluation script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, so existing callers are
            unaffected; passing a list makes the function usable from tests
            or other code.

    Returns:
        An ``argparse.Namespace`` with ``eval_imgs``, ``annotation_dir``,
        ``weights`` and ``img_num``.
    """
    parser = argparse.ArgumentParser(description='Evaluate a trained FPN model')
    parser.add_argument('--eval_imgs', dest='eval_imgs',
                        help='evaluate imgs dir ',
                        default='../data/io/VOCdevkit_test/JPEGImages', type=str)
    parser.add_argument('--annotation_dir', dest='annotation_dir',
                        help='the dir save annotations',
                        default='../data/io/VOCdevkit_test/Annotations', type=str)
    parser.add_argument('--weights', dest='weights',
                        help='model path',
                        default='../output/airplane/res101_trained_weights/v3_airplane/airplane_115500model.ckpt',
                        type=str)
    parser.add_argument('--img_num', dest='img_num',
                        help='image numbers',
                        default=300, type=int)
    return parser.parse_args(argv)

if __name__ == "__main__":
    # Script entry point: parse the command line and run the evaluation.
    eval(parse_args())

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值