- 从数据集中选出自己需要的类别
import os
import cv2
import shutil
catogary = ['bridge'] # list of DOTA class names to select images for
def customname(fullname):
    """Return the file name of *fullname* without directory and without its last suffix."""
    stem, _suffix = os.path.splitext(fullname)
    return os.path.basename(stem)
def GetFileFromRoot(dir):
    """Return the full path (suffix included) of every file below *dir*, recursively.

    Paths are listed in os.walk order; a missing *dir* yields an empty list.
    """
    return [
        os.path.join(folder, fname)
        for folder, _subdirs, fnames in os.walk(dir)
        for fname in fnames
    ]
if __name__ == '__main__':
root = 'E:/Aerial Images/Aerial Images/DOTA/train'
raw_pic_path = os.path.join(root, 'images/images')
raw_lab_path = os.path.join(root, 'labelTxt-v1.0/labelTxt')
bridge_pic = os.path.join(root, 'bridge/images')
bridge_lab = os.path.join(root, 'bridge/labelTxt')
label_list = GetFileFromRoot(raw_lab_path)
for label_path in label_list:
n = 0
f = open(label_path, 'r')
lines = f.readlines()
split_lines = (line.strip().split(' ') for line in lines) #strip 移除字符串头尾指定字符,默认空格,换行符或字符序列;根据空格来分割
for i, split_line in enumerate(split_lines):
if i in [0, 1]: #标签文本前两行为格式及高度,无用
continue
catogary_name = split_line[8] #类别
if catogary_name in catogary:
n = n + 1
if n > 1: #所要求类别目标数量达到两个就可以将该图像挑选出来
name = customname(label_path) #不带后缀的标签文件名
old_label_path = label_path
old_img_path = os.path.join(raw_pic_path, name + '.png')
img = cv2.imread(old_img_path)
new_lab_path = os.path.join(bridge_lab, name + 'txt')
new_pic_path = os.path.join(bridge_pic, name + '.png')
cv2.imwrite(new_pic_path, img)
shutil.copyfile(old_label_path, new_lab_path)
- 删除数据集中的空白样本
import os
import shutil
import xml.dom.minidom
def custombasename(fullname):
    """Return *fullname*'s file name with the directory and last suffix removed."""
    without_suffix = os.path.splitext(fullname)[0]
    return os.path.basename(without_suffix)
def GetFileFromThisRootDir(dir, ext=None):
    """Recursively list the files under *dir*, optionally filtered by extension.

    Args:
        dir: root directory to walk.
        ext: None to return every file, or one extension / an iterable of
            extensions such as 'png', '.png' or ['png', 'jpg']. A leading dot
            is ignored; matching is case-sensitive.

    Returns:
        list of full file paths, in os.walk order.
    """
    if ext is None:
        wanted = None
    else:
        if isinstance(ext, str):
            ext = [ext]
        # Compare whole extensions. The former `extension in ext` was a
        # substring test for a string ext, so 'tx' and '' (files without an
        # extension) both matched ext='txt'.
        wanted = {e.lstrip('.') for e in ext}
    allfiles = []
    for root, dirs, files in os.walk(dir):
        for filename in files:
            filepath = os.path.join(root, filename)
            extension = os.path.splitext(filepath)[1][1:]
            if wanted is None or extension in wanted:
                allfiles.append(filepath)
    return allfiles
def cleandata(path, img_path, blank_label_path, blank_img_path, ext, label_ext):
    """Move an unannotated sample (label file and its image) into the blank_* folders.

    Args:
        path: path of one label file.
        img_path: folder holding the images; the image is '<label stem><ext>'.
        blank_label_path: destination folder for the blank label file.
        blank_img_path: destination folder for the blank image.
        ext: image suffix including the dot, e.g. '.png'.
        label_ext: 'xml' (VOC, blank = no <object>) or anything else for a
            text label (blank = no non-whitespace line). A leading dot is
            accepted, so '.xml' works as well.
    """
    name = custombasename(path)
    # The former `label_ext == 'xml'` silently sent '.xml' to the txt branch.
    if label_ext.lstrip('.') == 'xml':
        annotation = xml.dom.minidom.parse(path).documentElement
        is_blank = len(annotation.getElementsByTagName('object')) == 0
    else:
        # Read and close the file before moving it: on Windows an open file
        # cannot be moved, and the non-blank path used to leak the handle.
        with open(path, 'r') as f_in:
            lines = f_in.readlines()
        # A file holding only empty/whitespace lines has no objects either.
        is_blank = not any(line.strip() for line in lines)
    if is_blank:
        image_path = os.path.join(img_path, name + ext)
        shutil.move(image_path, blank_img_path)
        shutil.move(path, blank_label_path)
    print('正在处理 %s'%path)
if __name__ == '__main__':
    root = 'E:/Aerial Images/Aerial Images/trainsplit'
    # Split image patches and their labels
    img_path = os.path.join(root, 'images')
    label_path = os.path.join(root, 'labelTxt')
    ext = '.png'        # image suffix
    label_ext = '.txt'  # label format passed to cleandata
    # Destination folders for blank samples (image + label)
    blank_img_path = os.path.join(root, 'blank_images')
    blank_label_path = os.path.join(root, 'blank_labelTxt')
    for out_dir in (blank_img_path, blank_label_path):
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
    label_list = GetFileFromThisRootDir(label_path)
    for path in label_list:
        cleandata(path, img_path, blank_label_path, blank_img_path, ext, label_ext)
- 删除数据中的非目标样本(提取出含所需目标的样本)
import os
import shutil
import xml.dom.minidom
#n = 0
def custombasename(fullname):
    """Strip the directory and the last suffix from *fullname*."""
    root_part, _ext = os.path.splitext(fullname)
    return os.path.basename(root_part)
def GetFileFromThisRootDir(dir, ext=None):
    """Recursively list the files under *dir*, optionally filtered by extension.

    Args:
        dir: root directory to walk.
        ext: None to return every file, or one extension / an iterable of
            extensions such as 'png', '.png' or ['png', 'jpg']. A leading dot
            is ignored; matching is case-sensitive.

    Returns:
        list of full file paths, in os.walk order.
    """
    if ext is None:
        wanted = None
    else:
        if isinstance(ext, str):
            ext = [ext]
        # Compare whole extensions. The former `extension in ext` was a
        # substring test for a string ext, so 'tx' and '' (files without an
        # extension) both matched ext='txt'.
        wanted = {e.lstrip('.') for e in ext}
    allfiles = []
    for root, dirs, files in os.walk(dir):
        for filename in files:
            filepath = os.path.join(root, filename)
            extension = os.path.splitext(filepath)[1][1:]
            if wanted is None or extension in wanted:
                allfiles.append(filepath)
    return allfiles
def cleandata(path, img_path, nontarget_label_path, nontarget_img_path, ext, label_ext, categories=None):
    """Move a sample whose DOTA txt label contains a wanted class.

    NOTE(review): despite the 'nontarget_*' names, the sample is moved when it
    DOES contain an object of `categories`. That extracts the target samples
    and leaves the others in img_path — confirm the folder names are what you want.

    Args:
        path: path of one DOTA txt label file (class name in the 9th column).
        img_path: folder holding the images; the image is '<label stem><ext>'.
        nontarget_label_path: destination folder for the label file.
        nontarget_img_path: destination folder for the image.
        ext: image suffix including the dot, e.g. '.png'.
        label_ext: unused; kept for signature compatibility.
        categories: wanted class names; defaults to the global `catogory`
            list (defined in the `__main__` block of this script).
    """
    name = custombasename(path)
    # Read and close before moving: an open file cannot be moved on Windows,
    # and the no-target path used to leak the handle.
    with open(path, 'r') as f_in:
        lines = f_in.readlines()
    if categories is None:
        categories = catogory
    # Lines with fewer than 9 fields (blank lines, 'imagesource:'/'gsd:'
    # headers) have no class column and are ignored instead of raising IndexError.
    has_target = any(
        len(fields) >= 9 and fields[8] in categories
        for fields in (line.split() for line in lines)
    )
    if has_target:
        image_path = os.path.join(img_path, name + ext)
        shutil.move(image_path, nontarget_img_path)
        shutil.move(path, nontarget_label_path)
    print('正在处理 %s' % path)
if __name__ == '__main__':
    # Wanted classes; cleandata reads this module-level name.
    catogory = ['bridge']
    root = r'H:\DOTA\dota\trainsplit'
    # Split image patches and their labels
    img_path = os.path.join(root, 'images')
    label_path = os.path.join(root, 'labelTxt')
    ext = '.png'        # image suffix
    label_ext = '.txt'
    # Destination folders for the moved samples (image + label)
    nontarget_img_path = os.path.join(root, 'nontarget_images')
    nontarget_label_path = os.path.join(root, 'nontarget_labelTxt')
    for out_dir in (nontarget_img_path, nontarget_label_path):
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
    label_list = GetFileFromThisRootDir(label_path)
    for path in label_list:
        cleandata(path, img_path, nontarget_label_path, nontarget_img_path, ext, label_ext)
- 将dota数据集标签格式从txt转换成xml
import os
import cv2
from xml.dom.minidom import Document
# Only objects whose class name is in this list are written to the XML
# (readlabeltxt skips every other class).
category_set = ['bridge']
def custombasename(fullname):
    """Give back *fullname*'s base name minus its final extension."""
    no_ext = os.path.splitext(fullname)[0]
    return os.path.basename(no_ext)
def limit_value(a, b):
    """Clamp coordinate *a* against the image size *b* (width or height).

    Values below 1 are raised to 1; afterwards any value >= b becomes b - 1.
    """
    clamped = 1 if a < 1 else a
    return b - 1 if clamped >= b else clamped
def readlabeltxt(txtpath, height, width, hbb=True):
    """Read a DOTA txt label and return the boxes of the classes in `category_set`.

    Each object line is 'x1 y1 x2 y2 x3 y3 x4 y4 class [difficult]'. Coordinates
    are truncated to int and clamped with limit_value to the image size.

    Args:
        txtpath: path of the DOTA label file.
        height, width: image size used for clamping.
        hbb: True -> [xmin, ymin, xmax, ymax, label] (enclosing horizontal box);
            False -> [x1, y1, x2, y2, x3, y3, x4, y4, label] (the four corners).

    Returns:
        list of boxes in the layout described above.
    """
    print(txtpath)
    with open(txtpath, 'r') as f_in:
        lines = f_in.readlines()
    boxes = []
    for line in lines:
        fields = line.split()
        # Skip blank lines and metadata headers such as 'imagesource:...' /
        # 'gsd:...', which have fewer than 9 fields and used to raise IndexError.
        if len(fields) < 9:
            continue
        label = fields[8]
        if label not in category_set:  # only write the chosen classes
            continue
        coords = [int(float(v)) for v in fields[:8]]
        xs = coords[0::2]
        ys = coords[1::2]
        if hbb:
            box = [limit_value(min(xs), width), limit_value(min(ys), height),
                   limit_value(max(xs), width), limit_value(max(ys), height),
                   label]
        else:  # obb: keep the four corners in their original order
            box = []
            for x, y in zip(xs, ys):
                box.extend([limit_value(x, width), limit_value(y, height)])
            box.append(label)
        boxes.append(box)
    return boxes
def _append_node(doc, parent, tag, text=None):
    """Create <tag> under *parent*, add str(text) as its text when given, and return it."""
    node = doc.createElement(tag)
    parent.appendChild(node)
    if text is not None:
        node.appendChild(doc.createTextNode(str(text)))
    return node


def writeXml(tmp, imgname, w, h, d, bboxes, hbb=True):
    """Write a Pascal-VOC style XML annotation for one image.

    The file is '<tmp>/<imgname without suffix>.xml'; *tmp* is created if missing.

    Args:
        tmp: output folder.
        imgname: image file name, stored in <filename>.
        w, h, d: image width, height and depth, stored in <size>.
        bboxes: boxes as produced by readlabeltxt; the last element is the class
            name. With hbb=True the first four values are xmin, ymin, xmax, ymax;
            otherwise the first eight are the corners x1, y1, ..., x4, y4.
        hbb: write <xmin>/<ymin>/<xmax>/<ymax> (True) or <x0>,<y0>..<x3>,<y3> (False).
    """
    doc = Document()
    annotation = _append_node(doc, doc, 'annotation')
    _append_node(doc, annotation, 'folder', 'VOC2007')
    _append_node(doc, annotation, 'filename', imgname)

    source = _append_node(doc, annotation, 'source')
    _append_node(doc, source, 'database', 'My Database')
    _append_node(doc, source, 'annotation', 'VOC2007')
    _append_node(doc, source, 'image', 'flickr')

    owner = _append_node(doc, annotation, 'owner')
    _append_node(doc, owner, 'flickrid', 'NULL')
    _append_node(doc, owner, 'name', 'idannel')

    size = _append_node(doc, annotation, 'size')
    _append_node(doc, size, 'width', w)
    _append_node(doc, size, 'height', h)
    _append_node(doc, size, 'depth', d)

    _append_node(doc, annotation, 'segmented', '0')

    if hbb:
        coord_tags = ('xmin', 'ymin', 'xmax', 'ymax')
    else:
        coord_tags = ('x0', 'y0', 'x1', 'y1', 'x2', 'y2', 'x3', 'y3')
    for bbox in bboxes:
        obj = _append_node(doc, annotation, 'object')
        _append_node(doc, obj, 'name', bbox[-1])
        _append_node(doc, obj, 'pose', 'Unspecified')
        _append_node(doc, obj, 'truncated', '0')
        _append_node(doc, obj, 'difficult', '0')
        bndbox = _append_node(doc, obj, 'bndbox')
        # zip stops at the last coordinate tag, so the trailing label is not written here.
        for tag, value in zip(coord_tags, bbox):
            _append_node(doc, bndbox, tag, value)

    # The output folder (e.g. 'hbbxml') is not created anywhere else, and
    # open() would fail on the first image without this.
    if tmp:
        os.makedirs(tmp, exist_ok=True)
    xmlname = os.path.splitext(imgname)[0]
    tempfile = os.path.join(tmp, xmlname + '.xml')
    with open(tempfile, 'wb') as f:
        f.write(doc.toprettyxml(indent='\t', encoding='utf-8'))
if __name__ == '__main__':
    data_path = r'E:\Aerial Images\Aerial Images\DOTA\val\bridge\valsplit'
    images_path = os.path.join(data_path, 'images')       # sample images
    labeltxt_path = os.path.join(data_path, 'labelTxt')   # DOTA txt labels
    anno_new_path = os.path.join(data_path, 'hbbxml')     # output VOC XML (hbb)
    ext = '.png'  # image suffix
    # writeXml opens '<anno_new_path>/<name>.xml' directly, so the folder must exist.
    os.makedirs(anno_new_path, exist_ok=True)
    for filename in os.listdir(labeltxt_path):
        filepath = os.path.join(labeltxt_path, filename)  # one DOTA label file
        picname = os.path.splitext(filename)[0] + ext
        pic_path = os.path.join(images_path, picname)
        # Read the matching image to get H, W, D for the XML.
        im = cv2.imread(pic_path)
        if im is None:
            # cv2.imread returns None (no exception) for a missing/unreadable
            # file; im.shape would then raise AttributeError.
            print('图片不存在或无法读取,跳过', pic_path)
            continue
        (H, W, D) = im.shape
        boxes = readlabeltxt(filepath, H, W, hbb=True)  # horizontal boxes by default
        if len(boxes) == 0:
            print('文件为空', filepath)
        writeXml(anno_new_path, picname, W, H, D, boxes, hbb=True)
        print('正在处理%s' % filename)
需要注意文件夹路径、目标类别、图像格式、注释框格式(hbb还是obb)
- xml到csv格式
import os
import glob #文件操作相关模块,用它可以查找符合自己目的的文件
import pandas as pd
import xml.etree.ElementTree as ET
# Folder holding the VOC XML files. The process also changes into it, so
# main() writes its relative 'label.csv' next to the XML files.
path = r'E:\Aerial Images\Aerial Images\DOTA\val\bridge\valsplit\hbbxml'
os.chdir(path)
def xml_to_csv(path):
    """Flatten every VOC-style '*.xml' annotation in *path* into a DataFrame.

    One row per <object> with the columns filename, width, height, class,
    xmin, ymin, xmax, ymax. Child elements are looked up by tag name rather
    than by position, so the element order inside <object>/<size> does not
    matter. Numeric values are parsed with int(float(...)) so '30.0' is accepted.

    Args:
        path: folder containing the XML files (not searched recursively).

    Returns:
        pandas.DataFrame; empty (with these columns) when no objects are found.
    """
    xml_list = []
    # glob.escape keeps '[' / '*' in folder names from acting as wildcards;
    # sorting makes the CSV row order reproducible across runs.
    for xml_file in sorted(glob.glob(os.path.join(glob.escape(path), '*.xml'))):
        root = ET.parse(xml_file).getroot()
        members = root.findall('object')
        if not members:
            continue
        filename = root.find('filename').text  # image file name
        size = root.find('size')
        width = int(float(size.find('width').text))
        height = int(float(size.find('height').text))
        for member in members:
            bndbox = member.find('bndbox')
            xml_list.append((
                filename,
                width,
                height,
                member.find('name').text,  # class
                int(float(bndbox.find('xmin').text)),
                int(float(bndbox.find('ymin').text)),
                int(float(bndbox.find('xmax').text)),
                int(float(bndbox.find('ymax').text)),
            ))
    column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
    return pd.DataFrame(xml_list, columns=column_name)
def main():
    """Convert every XML file in `path` into 'label.csv' in the current directory."""
    xml_dir = path
    xml_df = xml_to_csv(xml_dir)
    xml_df.to_csv('label.csv', index=False)
    print('Successfully converted xml to csv.')


# Run only when executed as a script, consistent with the other scripts here.
if __name__ == '__main__':
    main()
需要注意的是:
(1)、column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax'] 与 member 中各元素的对应关系。
- csv到tfrecord(用于tensorflow训练的格式)
"""
Usage:
python csv_to_tfrecord.py --csv_input=data/train_labels.csv --output_path=train_label.record
python csv_to_tfrecord.py --csv_input=data/val_labels.csv --output_path=val_labels.record
"""
import os
import io
import pandas as pd
import tensorflow as tf
from PIL import Image
from collections import namedtuple, OrderedDict
from object_detection.utils import dataset_util
# Command-line flags: --csv_input (CSV to read) and --output_path (TFRecord to write).
# NOTE: `tf.app` is the TensorFlow 1.x API; under TF 2.x use `tf.compat.v1.app`.
flags = tf.app.flags
flags.DEFINE_string('csv_input', '', 'Path to the CSV input')
flags.DEFINE_string('output_path', '', 'Path to the tfrecord output')
FLAGS = flags.FLAGS
# main() resolves the image folder and the output path against os.getcwd(),
# so everything is relative to this directory. Adjust it for your machine.
os.chdir('C:\\Users\\DL-1\\models\\research\\object_detection\\')
# Class name -> integer id stored in the TFRecord. Add one entry per class you
# train on (e.g. 'vehicle': 2); the ids should agree with the training label map.
_CLASS_IDS = {
    'bridge': 1,
}


def class_text_to_int(row_label):
    """Return the integer id for *row_label*, or None when the class is not listed."""
    return _CLASS_IDS.get(row_label)
def split(df, group):
    """Group *df* by the column *group*.

    Returns a list with one data(filename, object) namedtuple per distinct key:
    `filename` is the key and `object` is the sub-DataFrame of its rows.
    """
    data = namedtuple('data', ['filename', 'object'])
    grouped = df.groupby(group)
    return [data(key, grouped.get_group(key)) for key in grouped.groups]
def create_tf_example(group, path):
    """Build one tf.train.Example holding an image and all of its boxes.

    Args:
        group: a data(filename, object) namedtuple from split(); `object` holds
            this image's CSV rows (columns class, xmin, ymin, xmax, ymax).
        path: folder that contains the image files.

    Returns:
        tf.train.Example with the encoded image bytes, its size, and box
        coordinates divided by the image width/height.

    Raises:
        ValueError: if a row's class has no id in class_text_to_int.
    """
    image_name = str(group.filename)
    with tf.io.gfile.GFile(os.path.join(path, image_name), 'rb') as fid:
        encoded_image = fid.read()
    image = Image.open(io.BytesIO(encoded_image))
    width, height = image.size
    filename = image_name.encode('utf8')
    # Record the real encoding reported by PIL (e.g. 'PNG' -> b'png',
    # 'JPEG' -> b'jpeg') instead of hard-coding b'png'; default to png
    # if PIL cannot tell.
    image_format = (image.format or 'png').lower().encode('utf8')

    xmins, xmaxs, ymins, ymaxs = [], [], [], []
    classes_text, classes = [], []
    for _, row in group.object.iterrows():
        label = str(row['class'])
        class_id = class_text_to_int(label)
        if class_id is None:
            # Fail here with the offending label instead of passing a None id
            # on to int64_list_feature.
            raise ValueError('No class id for label %r in %s; add it to class_text_to_int'
                             % (label, image_name))
        xmins.append(row['xmin'] / width)
        xmaxs.append(row['xmax'] / width)
        ymins.append(row['ymin'] / height)
        ymaxs.append(row['ymax'] / height)
        classes_text.append(label.encode('utf8'))
        classes.append(class_id)

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_image),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_example
def main(_):
    """Convert the CSV given by --csv_input into a TFRecord file at --output_path."""
    # Image folder, resolved against the current working directory.
    path = os.path.join(os.getcwd(), 'images/bridge_val')
    examples = pd.read_csv(FLAGS.csv_input)
    grouped = split(examples, 'filename')
    writer = tf.io.TFRecordWriter(FLAGS.output_path)
    try:
        for group in grouped:
            tf_example = create_tf_example(group, path)
            writer.write(tf_example.SerializeToString())
    finally:
        # Close (and flush) the record file even when an image fails to convert.
        writer.close()
    output_path = os.path.join(os.getcwd(), FLAGS.output_path)
    print('Successfully created the TFRecords: {}'.format(output_path))
if __name__ == '__main__':
    # tf.app.run parses the command-line flags defined above, then calls main(argv).
    tf.app.run()
需要注意修改的地方是:
(1)、图像的目录:path = os.path.join(os.getcwd(), 'images/bridge_val')  # 获取当前工作目录
(2)、图像后缀名(格式):image_format = b'png'
(3)、对应的目标类别:if row_label == 'bridge':
7、devkit/dota_evaluation_task2.py
"""
To use the code, users should to config detpath, annopath and imagesetfile
detpath is the path for 15 result files, for the format, you can refer to "http://captain.whu.edu.cn/DOTAweb/tasks.html"
search for PATH_TO_BE_CONFIGURED to config the paths
Note, the evaluation is on the large scale images
"""
import xml.etree.ElementTree as ET
import os
#import cPickle
import numpy as np
import matplotlib.pyplot as plt
def parse_gt(filename):
    """Parse a DOTA ground-truth txt file into a list of object dicts.

    Object lines are 'x1 y1 x2 y2 x3 y3 x4 y4 name [difficult]'. Lines with
    fewer than 9 fields (blank lines, 'imagesource:'/'gsd:' headers of raw
    labelTxt files) are skipped instead of raising IndexError.

    Returns:
        list of dicts with keys:
          name: class name;
          difficult: int from the 10th field, 0 when it is absent;
          bbox: [xmin, ymin, xmax, ymax], the enclosing box of the four corners
                (for the usual rectangle ordering this equals corners 1 and 3);
          area: (xmax - xmin) * (ymax - ymin).
    """
    objects = []
    with open(filename, 'r') as f:
        lines = f.readlines()
    for line in lines:
        fields = line.split()
        if len(fields) < 9:
            continue
        coords = [int(float(v)) for v in fields[:8]]
        xs = coords[0::2]
        ys = coords[1::2]
        # min/max instead of corners 1 and 3: also correct when the polygon
        # is rotated or its corners are listed in another order.
        xmin, ymin, xmax, ymax = min(xs), min(ys), max(xs), max(ys)
        objects.append({
            'name': fields[8],
            'difficult': int(fields[9]) if len(fields) > 9 else 0,
            'bbox': [xmin, ymin, xmax, ymax],
            'area': (xmax - xmin) * (ymax - ymin),
        })
    return objects
def voc_ap(rec, prec, use_07_metric=False):
    """ap = voc_ap(rec, prec, [use_07_metric])

    Average precision from NumPy recall / precision arrays.

    use_07_metric=True: VOC2007 11-point interpolation, the mean over the
    recall thresholds 0, 0.1, ..., 1.0 of the best precision at recall >= t
    (0 when no point reaches t).
    use_07_metric=False (default): area under the precision envelope, the
    curve made non-increasing by taking the running maximum from the right.
    """
    if not use_07_metric:
        # Sentinels so the curve spans recall 0..1.
        mrec = np.concatenate(([0.], rec, [1.]))
        mpre = np.concatenate(([0.], prec, [0.]))
        # Envelope: each precision becomes the max of itself and everything to its right.
        mpre = np.maximum.accumulate(mpre[::-1])[::-1]
        # Integrate only at the indices where recall changes.
        steps = np.where(mrec[1:] != mrec[:-1])[0]
        return np.sum((mrec[steps + 1] - mrec[steps]) * mpre[steps + 1])

    ap = 0.
    for threshold in np.arange(0., 1.1, 0.1):
        reached = rec >= threshold
        best = np.max(prec[reached]) if np.any(reached) else 0
        ap = ap + best / 11.
    return ap
def voc_eval(detpath,
             annopath,
             imagesetfile,
             classname,
             ovthresh=0.5,
             use_07_metric=False):
    """rec, prec, ap = voc_eval(detpath, annopath, imagesetfile, classname,
                                [ovthresh], [use_07_metric])

    PASCAL-VOC style evaluation of one class for DOTA Task2 (horizontal boxes).

    detpath: path template; detpath.format(classname) is the detection file,
        one detection per line: 'image_id score xmin ymin xmax ymax'.
    annopath: path template; annopath.format(imagename) is that image's
        ground-truth txt file (parsed by parse_gt).
    imagesetfile: text file with one image name per line; blank lines are ignored.
    classname: category name to evaluate.
    [ovthresh]: IoU that must be exceeded for a true positive (default 0.5).
    [use_07_metric]: use the VOC07 11-point AP (default False).

    Returns rec, prec (cumulative recall / precision over the detections sorted
    by descending score) and ap. rec is tp / npos, where npos counts the
    non-difficult ground-truth objects of the class; when npos is 0 every
    entry of rec is 0/0 = NaN (NumPy emits a RuntimeWarning).
    """
    # read list of images
    with open(imagesetfile, 'r') as f:
        lines = f.readlines()
    # Skip blank lines (e.g. a trailing newline); they would otherwise be
    # looked up as an image named '' and fail to open.
    imagenames = [x.strip() for x in lines if x.strip()]

    # load the ground truth of every image
    recs = {}
    for imagename in imagenames:
        recs[imagename] = parse_gt(annopath.format(imagename))

    # extract gt objects for this class
    class_recs = {}
    npos = 0
    for imagename in imagenames:
        R = [obj for obj in recs[imagename] if obj['name'] == classname]
        bbox = np.array([x['bbox'] for x in R])
        # Builtin bool: the np.bool alias was removed in NumPy 1.24.
        difficult = np.array([x['difficult'] for x in R]).astype(bool)
        det = [False] * len(R)
        npos = npos + sum(~difficult)
        class_recs[imagename] = {'bbox': bbox,
                                 'difficult': difficult,
                                 'det': det}

    # read dets: image_id, score, then the box coordinates
    detfile = detpath.format(classname)
    with open(detfile, 'r') as f:
        lines = f.readlines()
    # split() tolerates repeated spaces; blank lines are skipped.
    splitlines = [x.split() for x in lines if x.strip()]
    image_ids = [x[0] for x in splitlines]
    confidence = np.array([float(x[1]) for x in splitlines])
    if splitlines:
        BB = np.array([[float(z) for z in x[2:]] for x in splitlines])
    else:
        # An empty detection file would give a 1-D empty array, and
        # BB[sorted_ind, :] below would raise "IndexError: too many indices
        # for array"; use an explicit (0, 4) array instead.
        BB = np.zeros((0, 4))

    # sort by confidence, highest first
    sorted_ind = np.argsort(-confidence)
    BB = BB[sorted_ind, :]
    image_ids = [image_ids[x] for x in sorted_ind]

    # go down dets and mark TPs and FPs
    nd = len(image_ids)
    tp = np.zeros(nd)
    fp = np.zeros(nd)
    for d in range(nd):
        R = class_recs[image_ids[d]]
        bb = BB[d, :].astype(float)
        ovmax = -np.inf
        BBGT = R['bbox'].astype(float)
        if BBGT.size > 0:
            # compute overlaps (the +1 treats coordinates as inclusive pixels)
            # intersection
            ixmin = np.maximum(BBGT[:, 0], bb[0])
            iymin = np.maximum(BBGT[:, 1], bb[1])
            ixmax = np.minimum(BBGT[:, 2], bb[2])
            iymax = np.minimum(BBGT[:, 3], bb[3])
            iw = np.maximum(ixmax - ixmin + 1., 0.)
            ih = np.maximum(iymax - iymin + 1., 0.)
            inters = iw * ih
            # union
            uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) +
                   (BBGT[:, 2] - BBGT[:, 0] + 1.) *
                   (BBGT[:, 3] - BBGT[:, 1] + 1.) - inters)
            overlaps = inters / uni
            ovmax = np.max(overlaps)
            jmax = np.argmax(overlaps)  # best-matching ground truth
        if ovmax > ovthresh:
            if not R['difficult'][jmax]:
                if not R['det'][jmax]:
                    tp[d] = 1.
                    R['det'][jmax] = 1
                else:
                    # ground truth already claimed by a higher-scoring detection
                    fp[d] = 1.
            # a match on a difficult object counts as neither TP nor FP
        else:
            fp[d] = 1.

    # compute precision recall
    print('check fp:', fp)
    print('check tp', tp)
    print('npos num:', npos)
    fp = np.cumsum(fp)
    tp = np.cumsum(tp)
    rec = tp / float(npos)
    # avoid divide by zero in case the first detection matches a difficult
    # ground truth
    prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
    ap = voc_ap(rec, prec, use_07_metric)
    return rec, prec, ap
def main():
    """Evaluate Task2 AP for every DOTA class, then print the per-class APs and the mAP."""
    # PATH_TO_BE_CONFIGURED
    # detpath: one detection file per class; '{:s}' is replaced by the class name.
    detpath = r'H:\DOTA\Raw_DOTA\evaluate_val_with_val\Task2_val_dt\Task2_{:s}.txt'
    # annopath: ground-truth txt per image; '{:s}' is replaced by the image name.
    # Change the directory to the path of val/labelTxt to evaluate on the val set.
    annopath = r'H:\DOTA\Raw_DOTA\evaluate_val_with_val\Task2_val_gt\{:s}.txt'
    # imagesetfile: one image name (without suffix) per line.
    imagesetfile = r'H:\DOTA\Raw_DOTA\evaluate_val_with_val\Task2_val_images\val_bridge_image.txt'
    classnames = ['plane', 'baseball-diamond', 'bridge', 'ground-track-field', 'small-vehicle', 'large-vehicle', 'ship', 'tennis-court',
                  'basketball-court', 'storage-tank', 'soccer-ball-field', 'roundabout', 'harbor', 'swimming-pool', 'helicopter']
    classaps = []
    for classname in classnames:
        print('classname:', classname)
        rec, prec, ap = voc_eval(detpath, annopath, imagesetfile, classname,
                                 ovthresh=0.5, use_07_metric=True)
        print('ap: ', ap)
        classaps.append(ap)
        # To inspect this class's P-R curve, plot `rec` against `prec` here
        # (matplotlib.pyplot is imported as plt).
    # Mean of the per-class APs, summed in class order.
    mean_ap = sum(classaps) / len(classnames)
    print('map:', mean_ap)
    classaps = 100 * np.array(classaps)
    print('classaps: ', classaps)
if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not on import.
    main()
需要改动的地方主要有这三个:详情参考代码中的注释。
- detpath = r'H:\DOTA\Raw_DOTA\evaluate_val_with_val\Task2_val_dt\Task2_{:s}.txt'
  检测结果(detection_result)的路径,格式参考官网上给出的说明。
- annopath = r'H:\DOTA\Raw_DOTA\evaluate_val_with_val\Task2_val_gt\{:s}.txt'
  真实标签(ground_truth)的路径。
- imagesetfile = r'H:\DOTA\Raw_DOTA\evaluate_val_with_val\Task2_val_images\val_bridge_image.txt'
  待检测图像名(不含后缀)组成的.txt文件,每行一个图像名。
注意:
代码中的classnames中类别要与检测结果中的类别数相符。如果只要检测其中一类(如plane),则要么在检测结果文件夹中添加其他类别的检测结果(比如Task2_bridge.txt),而且其中一定要有内容,哪怕是错的也无所谓,反正你的关注点在plane上,否则可能会报错:IndexError: too many indices for array;要么就将classnames中的其他类删除,只留你想验证的类别。
8小时Python零基础轻松入门