使用S2anet模型进行预测并解析pkl文件

        之前写了一篇文章介绍了如何在旋转目标检测框架MMRotate下训练S2anet模型,文章相关内容在网上也是有挺多资料的,但是关于如何利用训练好的模型进行预测,并且导出标注文件,网上相关资料偏少。由于项目需要,作者需要利用训练好的模型对多张图片进行预测,并导出标注文件(主要是需要标注框的位置信息),从而对图片进行斜框标注(当然后期还需要人工更正存在错误的标注)。本文记录了作者利用S2anet模型对图片进行预测并且导出相关的标注文件的全过程。

一、对多张图片进行批量预测

        在MMRotate框架下给出了对单张图片进行预测的代码,具体位置在mmrotate/demo下的image_demo.py。我们需要对其代码进行修改从而使其能够进行批量预测,为了不影响其他代码,我们在mmrotate/demo下新建一个名为images_demo的py程序,在image_demo.py的基础上进行修改,具体代码如下:

from argparse import ArgumentParser
from mmdet.apis import inference_detector, init_detector, show_result_pyplot
import mmrotate  # noqa: F401
import os
import cv2
import mmcv


def parse_args():
    parser = ArgumentParser()
    parser.add_argument('--img', help='Image file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument('--out-file', default=None, help='Path to output file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--palette',
        default='dota',
        choices=['dota', 'sar', 'hrsc', 'hrsc_classwise', 'random'],
        help='Color palette used for visualization')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='bbox score threshold')
    parser.add_argument('--in-folder', help='in-images')
    parser.add_argument('--out-folder', help='out-images')
    args = parser.parse_args()
    return args


def main(args):
    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)

    results = []  # result file

    # test images
    if not os.path.exists(args.out_folder):
        os.makedirs(args.out_folder)
    for file_name in os.listdir(args.in_folder):
        img_path = os.path.join(args.in_folder, file_name)
        img = cv2.imread(img_path)
        result = inference_detector(model, img)

        results.append(file_name)
        results.append(result)
        out_file = os.path.join(args.out_folder, file_name)
        # show the results
        show_result_pyplot(
            model,
            img,
            result,
            palette=args.palette,
            score_thr=args.score_thr,
            out_file=out_file)

    mmcv.dump(results, args.out_folder + '.pkl') #输出pkl文件


if __name__ == '__main__':
    args = parse_args()
    main(args)

        在这段代码中,我主要修改了以下内容:

  • 添加了以下两个参数,in-folder表示图片所在的路径。out-folder表示输出带有标注框的图片的路径。
  • 新建一个results数组,用于将预测结果保存在其中,对图片所在的文件夹进行遍历,依次预测,并将预测结果保存在results中。但是这里我首先是将当前正在预测的图片名称先保存在results中,然后再保存当前图片的预测结果,这是为了再后续流程中清晰每个标注信息所对应的图片名称。最后将预测结果保存为一个pkl文件。​​​​​

        在终端中运行上述代码,文件路径需要自行修改:

python /home/hfut-dl/xzh/mmrotate/demo/images_demo.py --in-folder /home/hfut-dl/xzh/SAR-SHIP-2.0/images --out-folder /home/hfut-dl/xzh/SAR-SHIP-2.0/out /home/hfut-dl/xzh/mmrotate/configs/s2anet/s2anet_r50_fpn_1x_dota_le135.py /home/hfut-dl/xzh/model/output/latest.pth

        运行之后并等待预测完毕,你会得到out.pkl文件,即存放标注信息的文件。

二、解析pkl文件

        得到pkl文件后,我们可以先输出以下看看里面有哪些内容。

        首先利用下面的代码输出pkl文件中的内容:

path = 'D:/Alinshi/SAR-Ship-Dataset/out.pkl'  # pkl文件所在路径---修改1

f = open(path, 'rb')
datas = pickle.load(f)
print(datas)

        我这边输出的结果是下图。输出结果中有很多数组还有图片的路径。

        输出结果看起来非常杂乱,我们可以继续修改代码,让输出结果看起来更清晰:

for i in range(0, len(datas), 2):
    objectList = []
    # print(datas[i])  # 图片名称
    print("{}中的标注信息为:".format(datas[i]))
    for x in range(datas[i + 1][0].shape[0]):
        arrys = []
        for y in range(datas[i + 1][0].shape[1]):
            print(datas[i + 1][0][x, y], end=" ")
            arrys.append(datas[i + 1][0][x, y])
        print()  # 标注信息 cx cy h w angle 置信度

        此时,输出的结果如下图,感觉清晰了很多。(图中是两个图片的预测结果)

        可以看到,对于每张图片,首先输出图片的名称,然后再输出标注框的位置信息,有几行就代表有多少个标注框。可以看到,每一行有6个数字,这6个数字依次是cx,cy,h,w,angle和置信度。

三、将解析出的信息写入到xml文件中

        因为我最终需要得到每一张图片的标注文件,因此需要将这些标注信息写入到xml文件中。项目要求我将xml文件的格式统一为RSDD-SAR这个数据集的xml文件格式,格式如下:

        继续修改代码如下:

for i in range(0, len(datas), 2):
    objectList = []
    # print(datas[i])  # 图片名称
    print("{}中的标注信息为:".format(datas[i]))
    for x in range(datas[i + 1][0].shape[0]):
        arrys = []
        for y in range(datas[i + 1][0].shape[1]):
            print(datas[i + 1][0][x, y], end=" ")
            arrys.append(datas[i + 1][0][x, y])
        print()  # 标注信息 cx cy h w angle 置信度
        info = [{"cx": arrys[0], "cy": arrys[1], "h": arrys[2], "w": arrys[3], "angle": arrys[4]}]
        if arrys[5] > 0.1:  # 置信度大于0.1
            objectList.append(info)

        在这段代码中,我将每一张图片中的每一个标注信息以字典的形式保存在info中,然后将info添加到objectList数组中。这里我进行了一点处理,即首先判断arrys[5]的值,arrys[5]的值就是置信度,只有在置信度大于0.1时,才将其保存在objectList中。这里大家可以根据需要自行调整置信度。

        之后就是将信息写入到xml文件中了,相关代码如下:

# 写入xml文件
def create_xml_file(file_dir, dataSetName, pictureName, width, height, depth, objectList):
    # 创建根节点
    with open(file_dir + '{}.xml'.format(pictureName), 'w') as xml_file:
        xml_file.write('<annotation>\n')
        xml_file.write('  <folder>{}</folder>\n'.format(dataSetName))
        xml_file.write('  <filename>{}</filename>\n'.format(pictureName + ".xml"))
        xml_file.write('  <size>\n')
        xml_file.write('    <width>{}</width>\n'.format(width))
        xml_file.write('    <height>{}</height>\n'.format(height))
        xml_file.write('    <depth>{}</depth>\n'.format(depth))
        xml_file.write('  </size>\n')

        for objects in objectList:
            xml_file.write('  <object>\n')
            xml_file.write('    <type>{}</type>\n'.format('robndbox'))
            xml_file.write('    <name>{}</name>\n'.format('ship'))
            xml_file.write('    <difficult>{}</difficult>\n'.format('0'))
            xml_file.write('    <robndbox>\n')
            xml_file.write('      <cx>{}</cx>\n'.format(objects[0]["cx"]))
            xml_file.write('      <cy>{}</cy>\n'.format(objects[0]["cy"]))
            xml_file.write('      <h>{}</h>\n'.format(objects[0]["h"]))
            xml_file.write('      <w>{}</w>\n'.format(objects[0]["w"]))
            xml_file.write('      <angle>{}</angle>\n'.format(objects[0]["angle"]))
            xml_file.write('    </robndbox>\n')
            xml_file.write('  </object>\n')
        xml_file.write('</annotation>')

        这段代码会遍历objectList数组,将其中的信息都写入到object标签下的对应小标签中。

四、总结

        感觉网上关于pkl文件解析的相关文章比较少(也许是我搜集资料的能力有限),希望能够帮到大家!有哪些不足之处也欢迎大家提出!

  • 14
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值