pytorch版本PSEnet训练并部署

概述

源码地址
torch版本
训练环境没有按照torch的readme一样的环境,自己部署环境为:

torch==1.9.1
torchvision==0.10.1
python==3.8.0
cuda==10.2
mmcv==0.2.12
editdistance==0.5.3
Polygon3==3.0.9.1
pyclipper==1.3.0
opencv-python==3.4.2.17
Cython==0.29.24
./compile.sh

制作数据集

  1. 训练的数据集,采用的是rolabelimg进行标注,需要转换为ic2015格式的数据。转换代码:
import os
from lxml import etree
import numpy as np
import math

src_xml = "ANN"
txt_dir = "gt"

xml_listdir = os.listdir(src_xml)

xml_listpath = [os.path.join(src_xml,xml_listdir1) for xml_listdir1 in xml_listdir]

def xml_out(xml_path):
    gt_lines = []
    ET = etree.parse(xml_path)
    objs = ET.findall("object")
    for ix,obj in enumerate(objs):
        name = obj.find("name").text
        robox = obj.find("robndbox")
        cx = int(float(robox.find("cx").text))
        cy = int(float(robox.find("cy").text))
        w = int(float(robox.find("w").text))
        h = int(float(robox.find("h").text))
        angle = float(robox.find("angle").text)
        # angle = math.degrees(angle1)


        wx1 = cx - int(0.5 * w)
        wy1 = cy - int(0.5 * h)

        wx2 = cx + int(0.5 * w)
        wy2 = cy - int(0.5 * h)

        wx3 = cx - int(0.5 * w)
        wy3 = cy + int(0.5 * h)

        wx4 = cx + int(0.5 * w)
        wy4 = cy + int(0.5 * h)

        x1 = int((wx1 - cx) * np.cos(angle) - (wy1 - cy) * np.sin(angle) + cx)
        y1 = int((wx1 - cx) * np.sin(angle) - (wy1 - cy) * np.cos(angle) + cy)

        x2 = int((wx2 - cx) * np.cos(angle) - (wy2 - cy) * np.sin(angle) + cx)
        y2 = int((wx2 - cx) * np.sin(angle) - (wy2 - cy) * np.cos(angle) + cy)

        x3 = int((wx3 - cx) * np.cos(angle) - (wy3 - cy) * np.sin(angle) + cx)
        y3 = int((wx3 - cx) * np.sin(angle) - (wy3 - cy) * np.cos(angle) + cy)

        x4 = int((wx4 - cx) * np.cos(angle) - (wy4 - cy) * np.sin(angle) + cx)
        y4 = int((wx4 - cx) * np.sin(angle) - (wy4 - cy) * np.cos(angle) + cy)

        lines = str(x1)+","+str(y1)+","+str(x2)+","+str(y2)+","+\
                str(x3)+","+str(y3)+","+str(x4)+","+str(y4)+","+str(name)+"\n"
        gt_lines.append(lines)
        return gt_lines

def main():
    count = 0
    for xml_dir in xml_listdir:
        gt_lines = xml_out(os.path.join(src_xml,xml_dir))
        txt_path = "gt_" + xml_dir[:-4] + ".txt"
        with open(os.path.join(txt_dir,txt_path),"a+") as fd:
            fd.writelines(gt_lines)
        count +=1
        print("Write file %s" % str(count))

if __name__ == "__main__":
    main()

rolabelimg标注后的xml文件和labelimg的xml有些区别,根据不同的标注软件,转换代码略有区别。转换后的格式为x1,y1,x2,y2,x3,y3,x4,y4,"classes",此处classes为检测的类别,如果是模糊训练的话,classes为“###”。但是重点,这个源代码对于模糊训练,loss一直为1。
2、将数据集分成训练集和测试集
数据集
这里可以按照源码路径存放数据集,也可以修改源码存放位置。PSENet-python3\dataset\psenet\psenet_ic15.py
修改下述代码为自己文件夹
在这里插入图片描述
3、训练

CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py config/psenet/psenet_r50_ic15_736.py

其中根据源码中的readme,
在这里插入图片描述
可以根据自己的需要,自行选择配置文件。

在这里插入图片描述
3、部署测试

import torch
import numpy as np
import argparse
import os
import os.path as osp
import sys
import time
import json
from mmcv import Config
import cv2
from torchvision import transforms
from dataset import build_data_loader
from models import build_model
from models.utils import fuse_module
from utils import ResultFormat, AverageMeter


def prepare_image(image, target_size):
    """Do image preprocessing before prediction on any data.

    :param image:       original image
    :param target_size: target image size
    :return:
                        preprocessed image
    """

    #assert os.path.exists(img), 'file is not exists'
    #img = cv2.imread(img)
    img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # h, w = image.shape[:2]

    # scale = long_size / max(h, w)
    img = cv2.resize(img, target_size)
    # 将图片由(w,h)变为(1,img_channel,h,w)
    tensor = transforms.ToTensor()(img)
    tensor = tensor.unsqueeze_(0)

    tensor = tensor.to(torch.device("cuda:0"))
    return tensor

def report_speed(outputs, speed_meters):
    total_time = 0
    for key in outputs:
        if 'time' in key:
            total_time += outputs[key]
            speed_meters[key].update(outputs[key])
            print('%s: %.4f' % (key, speed_meters[key].avg))

    speed_meters['total_time'].update(total_time)
    print('FPS: %.1f' % (1.0 / speed_meters['total_time'].avg))

def load_model(cfg):

    model = build_model(cfg.model)
    model = model.cuda()
    model.eval()
    checkpoint = "psenet_r50_ic15_1024_finetune/checkpoint_580ep.pth.tar"
    if checkpoint is not None:
        if os.path.isfile(checkpoint):
            print("Loading model and optimizer from checkpoint '{}'".format(checkpoint))
            sys.stdout.flush()

            checkpoint = torch.load(checkpoint)

            d = dict()
            for key, value in checkpoint['state_dict'].items():
                tmp = key[7:]
                d[tmp] = value
            model.load_state_dict(d)
        else:
            print("No checkpoint found at")
            raise

        # fuse conv and bn
    model = fuse_module(model)
    return model


if __name__ == '__main__':
    src_dir = "testimg/"
    save_dir = "test_save/"
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    cfg = Config.fromfile("PSENet/config/psenet/psenet_r50_ic15_1024_finetune.py")
    for d in [cfg, cfg.data.test]:
        d.update(dict(
            report_speed=False
        ))
    if cfg.report_speed:
        speed_meters = dict(
            backbone_time=AverageMeter(500),
            neck_time=AverageMeter(500),
            det_head_time=AverageMeter(500),
            det_pse_time=AverageMeter(500),
            rec_time=AverageMeter(500),
            total_time=AverageMeter(500)
        )
    model = load_model(cfg)
    model.eval()
    count = 0
    for img_name in os.listdir(src_dir):
        img = cv2.imread(src_dir + img_name)
        tensor = prepare_image(img, target_size=(1376, 1024))
        data = dict()
        img_metas = dict()

        data['imgs'] = tensor
        img_metas['org_img_size'] = torch.tensor([[img.shape[0], img.shape[1]]])
        img_metas['img_size'] = torch.tensor([[1376, 1024]])
        data['img_metas'] = img_metas
        data.update(dict(
            cfg=cfg
        ))
        with torch.no_grad():
            outputs = model(**data)

        if cfg.report_speed:
            report_speed(outputs, speed_meters)
        for bboxes in outputs['bboxes']:
            x1 = bboxes[0]
            y1 = bboxes[1]
            x2 = bboxes[4]
            y2 = bboxes[5]
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 3)
        count = count + 1
        cv2.imwrite(save_dir + img_name, img)
        print("img test:", count)
from dataset import build_data_loader
from models import build_model
from models.utils import fuse_module
from utils import ResultFormat, AverageMeter

训练代码里含有。

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

__JDM__

希望能得到您的打赏

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值