之前写了一篇文章介绍了如何在旋转目标检测框架MMRotate下训练S2anet模型,文章相关内容在网上也是有挺多资料的,但是关于如何利用训练好的模型进行预测,并且导出标注文件,网上相关资料偏少。由于项目需要,作者需要利用训练好的模型对多张图片进行预测,并导出标注文件(主要是需要标注框的位置信息),从而对图片进行斜框标注(当然后期还需要人工更正存在错误的标注)。本文记录了作者利用S2anet模型对图片进行预测并且导出相关的标注文件的全过程。
一、对多张图片进行批量预测
在MMRotate框架下给出了对单张图片进行预测的代码,具体位置在mmrotate/demo下的image_demo.py。我们需要对其代码进行修改从而使其能够进行批量预测,为了不影响其他代码,我们在mmrotate/demo下新建一个名为images_demo的py程序,在image_demo.py的基础上进行修改,具体代码如下:
from argparse import ArgumentParser
from mmdet.apis import inference_detector, init_detector, show_result_pyplot
import mmrotate # noqa: F401
import os
import cv2
import mmcv
def parse_args():
parser = ArgumentParser()
parser.add_argument('--img', help='Image file')
parser.add_argument('config', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument('--out-file', default=None, help='Path to output file')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
'--palette',
default='dota',
choices=['dota', 'sar', 'hrsc', 'hrsc_classwise', 'random'],
help='Color palette used for visualization')
parser.add_argument(
'--score-thr', type=float, default=0.3, help='bbox score threshold')
parser.add_argument('--in-folder', help='in-images')
parser.add_argument('--out-folder', help='out-images')
args = parser.parse_args()
return args
def main(args):
# build the model from a config file and a checkpoint file
model = init_detector(args.config, args.checkpoint, device=args.device)
results = [] # result file
# test images
if not os.path.exists(args.out_folder):
os.makedirs(args.out_folder)
for file_name in os.listdir(args.in_folder):
img_path = os.path.join(args.in_folder, file_name)
img = cv2.imread(img_path)
result = inference_detector(model, img)
results.append(file_name)
results.append(result)
out_file = os.path.join(args.out_folder, file_name)
# show the results
show_result_pyplot(
model,
img,
result,
palette=args.palette,
score_thr=args.score_thr,
out_file=out_file)
mmcv.dump(results, args.out_folder + '.pkl') #输出pkl文件
if __name__ == '__main__':
args = parse_args()
main(args)
在这段代码中,我主要修改了以下内容:
- 添加了以下两个参数,in-folder表示图片所在的路径。out-folder表示输出带有标注框的图片的路径。
- 新建一个results数组,用于将预测结果保存在其中,对图片所在的文件夹进行遍历,依次预测,并将预测结果保存在results中。但是这里我首先是将当前正在预测的图片名称先保存在results中,然后再保存当前图片的预测结果,这是为了再后续流程中清晰每个标注信息所对应的图片名称。最后将预测结果保存为一个pkl文件。
在终端中运行上述代码,文件路径需要自行修改:
python /home/hfut-dl/xzh/mmrotate/demo/images_demo.py --in-folder /home/hfut-dl/xzh/SAR-SHIP-2.0/images --out-folder /home/hfut-dl/xzh/SAR-SHIP-2.0/out /home/hfut-dl/xzh/mmrotate/configs/s2anet/s2anet_r50_fpn_1x_dota_le135.py /home/hfut-dl/xzh/model/output/latest.pth
运行之后并等待预测完毕,你会得到out.pkl文件,即存放标注信息的文件。
二、解析pkl文件
得到pkl文件后,我们可以先输出以下看看里面有哪些内容。
首先利用下面的代码输出pkl文件中的内容:
path = 'D:/Alinshi/SAR-Ship-Dataset/out.pkl' # pkl文件所在路径---修改1
f = open(path, 'rb')
datas = pickle.load(f)
print(datas)
我这边输出的结果是下图。输出结果中有很多数组还有图片的路径。
输出结果看起来非常杂乱,我们可以继续修改代码,让输出结果看起来更清晰:
for i in range(0, len(datas), 2):
objectList = []
# print(datas[i]) # 图片名称
print("{}中的标注信息为:".format(datas[i]))
for x in range(datas[i + 1][0].shape[0]):
arrys = []
for y in range(datas[i + 1][0].shape[1]):
print(datas[i + 1][0][x, y], end=" ")
arrys.append(datas[i + 1][0][x, y])
print() # 标注信息 cx cy h w angle 置信度
此时,输出的结果如下图,感觉清晰了很多。(图中是两个图片的预测结果)
可以看到,对于每张图片,首先输出图片的名称,然后再输出标注框的位置信息,有几行就代表有多少个标注框。可以看到,每一行有6个数字,这6个数字依次是cx,cy,h,w,angle和置信度。
三、将解析出的信息写入到xml文件中
因为我最终需要得到每一张图片的标注文件,因此需要将这些标注信息写入到xml文件中。项目要求我将xml文件的格式统一为RSDD-SAR这个数据集的xml文件格式,格式如下:
继续修改代码如下:
for i in range(0, len(datas), 2):
objectList = []
# print(datas[i]) # 图片名称
print("{}中的标注信息为:".format(datas[i]))
for x in range(datas[i + 1][0].shape[0]):
arrys = []
for y in range(datas[i + 1][0].shape[1]):
print(datas[i + 1][0][x, y], end=" ")
arrys.append(datas[i + 1][0][x, y])
print() # 标注信息 cx cy h w angle 置信度
info = [{"cx": arrys[0], "cy": arrys[1], "h": arrys[2], "w": arrys[3], "angle": arrys[4]}]
if arrys[5] > 0.1: # 置信度大于0.1
objectList.append(info)
在这段代码中,我将每一张图片中的每一个标注信息以字典的形式保存在info中,然后将info添加到objectList数组中。这里我进行了一点处理,即首先判断arrys[5]的值,arrys[5]的值就是置信度,只有在置信度大于0.1时,才将其保存在objectList中。这里大家可以根据需要自行调整置信度。
之后就是将信息写入到xml文件中了,相关代码如下:
# 写入xml文件
def create_xml_file(file_dir, dataSetName, pictureName, width, height, depth, objectList):
# 创建根节点
with open(file_dir + '{}.xml'.format(pictureName), 'w') as xml_file:
xml_file.write('<annotation>\n')
xml_file.write(' <folder>{}</folder>\n'.format(dataSetName))
xml_file.write(' <filename>{}</filename>\n'.format(pictureName + ".xml"))
xml_file.write(' <size>\n')
xml_file.write(' <width>{}</width>\n'.format(width))
xml_file.write(' <height>{}</height>\n'.format(height))
xml_file.write(' <depth>{}</depth>\n'.format(depth))
xml_file.write(' </size>\n')
for objects in objectList:
xml_file.write(' <object>\n')
xml_file.write(' <type>{}</type>\n'.format('robndbox'))
xml_file.write(' <name>{}</name>\n'.format('ship'))
xml_file.write(' <difficult>{}</difficult>\n'.format('0'))
xml_file.write(' <robndbox>\n')
xml_file.write(' <cx>{}</cx>\n'.format(objects[0]["cx"]))
xml_file.write(' <cy>{}</cy>\n'.format(objects[0]["cy"]))
xml_file.write(' <h>{}</h>\n'.format(objects[0]["h"]))
xml_file.write(' <w>{}</w>\n'.format(objects[0]["w"]))
xml_file.write(' <angle>{}</angle>\n'.format(objects[0]["angle"]))
xml_file.write(' </robndbox>\n')
xml_file.write(' </object>\n')
xml_file.write('</annotation>')
这段代码会遍历objectList数组,将其中的信息都写入到object标签下的对应小标签中。
四、总结
感觉网上关于pkl文件解析的相关文章比较少(也许是我搜集资料的能力有限),希望能够帮到大家!有哪些不足之处也欢迎大家提出!