#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Apr 7 10:10:12 2019
@author: fanzy
"""
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
from utils import shapeData as dataSet
from config import Config
# Project settings object; it is handed to both the dataset and the batch
# generator below, which reads `max_gt_obj` and `image_size` from it.
config = Config()
# Sample source for the generator (its `load_data()` method is called once per
# sample). NOTE(review): [64,64] is presumably the image size -- confirm
# against utils.shapeData.
dataset = dataSet([64,64], config=config)
def data_Gen(dataset, num_batch, batch_size, config):
    """Yield ``num_batch`` padded batches assembled from ``dataset``.

    Every sample is obtained from ``dataset.load_data()``, which must return
    the 6-tuple ``(image, bbox, class_id, rpn_match, rpn_bbox, anchors)``.
    Ground-truth boxes and class ids are zero-padded along axis 0 up to
    ``config.max_gt_obj`` rows so that all samples in a batch share one shape.

    Args:
        dataset: object exposing ``load_data()`` as described above.
        num_batch: number of batches produced before the generator ends.
        batch_size: number of samples drawn per batch.
        config: object providing ``max_gt_obj`` (row count after padding) and
            ``image_size`` (``image_size[0]`` and ``image_size[1]`` are the
            two spatial dimensions used when reshaping the images).

    Yields:
        A 2-tuple ``(inputs, [])`` where ``inputs`` is the list
        ``[images, bboxes, class_ids, target_matches, rpn_bboxes, anchors]``:

        * ``images``         -- ``(batch_size, image_size[0], image_size[1], 3)``
        * ``bboxes``         -- ``(batch_size, -1, 4)``, zero-padded
        * ``class_ids``      -- ``(batch_size, -1)``, zero-padded
        * ``target_matches`` -- ``(batch_size, -1, 1)``
        * ``rpn_bboxes``     -- ``(batch_size, -1, 4)``
        * ``anchors``        -- the anchors returned with the LAST sample of
          the batch, passed through unbatched.

        The second element is always an empty list (presumably a placeholder
        for an ``(inputs, targets)`` style consumer -- confirm with the
        training code).

    Raises:
        ValueError: if a sample carries more ground-truth boxes than
            ``config.max_gt_obj``. Because this is a generator, the error
            surfaces when the offending batch is requested.
    """
    def _batched(samples, *tail_shape):
        # Join the per-sample arrays along axis 0, then restore the batch axis.
        return np.concatenate(samples, 0).reshape(batch_size, *tail_shape)

    for _ in range(num_batch):
        images = []
        bboxes = []
        class_ids = []
        target_matches = []
        rpn_bboxes = []
        for _ in range(batch_size):
            (image, bbox, class_id,
             rpn_match, rpn_bbox, anchors) = dataset.load_data()

            pad_num = config.max_gt_obj - bbox.shape[0]
            if pad_num < 0:
                # Fail with an actionable message instead of numpy's
                # "negative dimensions are not allowed".
                raise ValueError(
                    "sample has {} ground-truth boxes but config.max_gt_obj "
                    "is {}".format(bbox.shape[0], config.max_gt_obj)
                )

            # Zero rows mark the unused ground-truth slots.
            bbox = np.concatenate([bbox, np.zeros((pad_num, 4))], axis=0)
            class_id = np.concatenate(
                [class_id, np.zeros((pad_num, 1))], axis=0
            )

            images.append(image)
            bboxes.append(bbox)
            class_ids.append(class_id)
            target_matches.append(rpn_match)
            rpn_bboxes.append(rpn_bbox)

        yield [
            _batched(images, config.image_size[0], config.image_size[1], 3),
            _batched(bboxes, -1, 4),
            _batched(class_ids, -1),
            _batched(target_matches, -1, 1),
            _batched(rpn_bboxes, -1, 4),
            anchors,
        ], []
# Generator of up to 35000 batches with one sample per batch.
dataGen = data_Gen(dataset, 35000, 1, config)

# Pull a single batch. Its first element is the list of inputs (notes below
# are translated from the original author's comments):
#   images           -- the image(s)
#   bboxes           -- the ground-truth boxes
#   class_ids        -- ground-truth class of each box (which object: 0, 1, 2, ...)
#   input_rpn_matchs -- per-anchor label in {-1, 0, 1}:
#                         IoU > 0.7        -> label  1 (positive)
#                         IoU < 0.1        -> label -1 (negative)
#                         0.1 < IoU < 0.7  -> label  0 (uninformative, not trained on)
#                       Only anchors labelled 1 or -1 take part in RPN
#                       classification. Positives are scarce and negatives
#                       plentiful, so their combined count is capped at N = 100.
#   input_rpn_bbox   -- targets for the label-1 anchors; only those anchors
#                       take part in bounding-box regression
#   anchors          -- the anchors themselves (576 per the original note)
data_test = next(dataGen)
(images, bboxes, class_ids,
 input_rpn_matchs, input_rpn_bbox, anchors) = data_test[0]
def _draw_anchor_boxes(ax, anchor_boxes, indices, color):
    """Outline ``anchor_boxes[k]`` on ``ax`` for every ``k`` in ``indices``.

    Each box is drawn with ``(box[0], box[1])`` as the rectangle's anchor
    point, ``box[2] - box[0]`` as its width and ``box[3] - box[1]`` as its
    height; rectangles are unfilled with edges in ``color``.
    """
    for k in indices:
        box = anchor_boxes[k]
        ax.add_patch(
            patches.Rectangle(
                (box[0], box[1]),
                box[2] - box[0],
                box[3] - box[1],
                facecolor='none',
                edgecolor=color,
            )
        )


# Show the first image of the batch and grab its axes for the overlays.
# NOTE(review): plt.show() is never called; outside an inline/interactive
# backend the figure will not appear -- confirm how this script is run.
plt.imshow(images[0])
axs = plt.gca()

# Only sample 0 is displayed, so only its anchor labels are consulted.
# input_rpn_matchs[0] has one row per anchor, hence the row indices returned
# by np.where are anchor indices. (Indexing the whole batch instead would mix
# in anchors matched against other samples once batch_size > 1.)
sample_matches = input_rpn_matchs[0]

# Anchors labelled 1 (positives, drawn in red); per the original note their
# number is the count of anchors with IoU > 0.7 (e.g. 6).
idx_posi = np.where(sample_matches == 1)[0]
_draw_anchor_boxes(axs, anchors, idx_posi, 'r')

# Anchors labelled -1 (negatives, drawn in blue); per the original note their
# number is N minus the number of positives (e.g. N - 6).
idex_nega = np.where(sample_matches == -1)[0]
_draw_anchor_boxes(axs, anchors, idex_nega, 'b')