【Lane】Ultra-Fast-Lane-Detection(2)自定义模型测试

引言

笔者Ultra-Fast-Lane-Detection专栏链接🔗导航:

  1. 【Lane】 Ultra-Fast-Lane-Detection 复现
  2. 【Lane】Ultra-Fast-Lane-Detection(1)自定义数据集训练
  3. 【Lane】Ultra-Fast-Lane-Detection(2)自定义模型测试

该博客实现自定义模型的测试(首先参考链接2实现模型的训练)

4 模型测试

创建脚本文件demo_custom.py

"""
2022.04.20
author:alian
车道线检测
测试自定义的数据集,并保存成检测结果图
H,W:原图尺寸;h:行锚框数,w:单元格数,C:车道线数
"""
# 导入项目源码中的文件
from model.model import parsingNet
from utils.dist_utils import dist_print
from data.constant import tusimple_row_anchor
# 导入库
import scipy.special, tqdm
import torchvision.transforms as transforms
from PIL import Image
import os,glob,cv2,argparse
import numpy as np
import torch.utils.data


class TestDataset(torch.utils.data.Dataset):  # 加载测试数据集----------------------------------------------------------
    def __init__(self, path, img_transform=None):
        super(TestDataset, self).__init__()
        self.path = path
        self.img_transform = img_transform
        self.img_list = glob.glob('%s/*.jpg'%self.path)

    def __getitem__(self, index):
        name = glob.glob('%s/*.jpg'%self.path)[index]
        img = Image.open(name)

        if self.img_transform is not None:
            img = self.img_transform(img)
        return img, name

    def __len__(self):
        return len(self.img_list)


def parse_opt():  # 参数指定-------------------------------------------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--backbone', type=str, default='18', help='骨干网络')
    parser.add_argument('--model', type=str, default='', help='模型路径')  # 设置
    parser.add_argument('--dataset', type=str, default='dataset', help='数据集名称')
    parser.add_argument('--source', type=str, default=' ', help='测试路径')  # 设置
    parser.add_argument('--savepath', type=str, default=' ', help='保存路径')  # 设置
    parser.add_argument('--save_video', type=bool, default=False, help='保存为视频')
    parser.add_argument('--griding_num', type=int, default=100, help='网格数')
    parser.add_argument('--num_row_anchors', type=int, default=56, help='锚框行')
    parser.add_argument('--num_lanes', type=int, default=2, help='车道数')
    opt = parser.parse_args()
    return opt


# 执行测试---------------------------------------------------------------------------------------------------------------
def run(opt):
    dist_print('start testing...')
    backbone,model,dataset,source,savepath = opt.backbone,opt.model,opt.dataset,opt.source,opt.savepath
    save_video,griding_num,num_row_anchors,num_lanes = opt.save_video,opt.griding_num,opt.num_row_anchors,opt.num_lanes
    assert opt.backbone in ['18', '34', '50', '101', '152', '50next', '101next', '50wide', '101wide']  # 残差网络骨干
    # 网络解析(griding_num:网格数;num_row_anchors:锚框行;num_lanes:车道数)
    net = parsingNet(pretrained=False, backbone=backbone, cls_dim=(griding_num + 1, num_row_anchors, num_lanes),
                     use_aux=False).cuda()
    state_dict = torch.load(model, map_location='cpu')['model']
    compatible_state_dict = {}
    for k, v in state_dict.items():
        if 'module.' in k:
            compatible_state_dict[k[7:]] = v
        else:
            compatible_state_dict[k] = v
    net.load_state_dict(compatible_state_dict, strict=False)
    net.eval()
    # 图像格式统一:(288, 800),图像张量,归一化
    img_transforms = transforms.Compose([
        transforms.Resize((288, 800)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    # 自定义数据集
    datasets = TestDataset(source, img_transform=img_transforms)
    img_w, img_h = 1920, 1080
    row_anchor = tusimple_row_anchor

    for dataset in zip(datasets):  # splits:图片列表 datasets:统一格式之后的数据集
        loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)  # 加载数据集
        if save_video:
            fourcc = cv2.VideoWriter_fourcc(*'MJPG')
            vout = cv2.VideoWriter(dataset + '.avi', fourcc, 30.0, (img_w, img_h))  # 保存结果为视频文件
        else:vout=None
        for i, data in enumerate(tqdm.tqdm(loader)):  # 进度条显示进度
            imgs, names = data  # imgs:图像张量,图像相对路径:
            imgs = imgs.cuda()  # 使用GPU
            with torch.no_grad():  # 测试代码不计算梯度
                pred = net(imgs)  # 模型预测 输出张量:[1,101,56,C]
            # 解析预测结果-----------------------------------------------------------------------------------------------
            out_j = pred[0].data.cpu().numpy()  # 数据类型转换成numpy [101,56,C]
            out_j = out_j[:, ::-1, :]  # 将第二维度倒着取[101,56,C]
            prob = scipy.special.softmax(out_j[:-1, :, :], axis=0)  # [100,56,C] softmax 计算(概率映射到0-1之间且沿着维度0概率总和=1)
            idx = np.arange(griding_num) + 1  # 产生 1-100
            idx = idx.reshape(-1, 1, 1)  # [100,1,1]
            loc = np.sum(prob * idx, axis=0)  # [56,C]
            out_j = np.argmax(out_j, axis=0)  # 返回最大值的索引
            loc[out_j == griding_num] = 0  # 若最大值的索引=100,则说明改行为背景,不存在车道线,归零
            out_j = loc  # [56,4]

            # 将特征图上的车道线像素坐标映射到原始图像中--------------------------------------------------------------------
            grids = np.linspace(0, 800 - 1, griding_num)  # 单元格的分布
            grid = grids[1] - grids[0]  # 单元格的间隔
            img = cv2.imdecode(np.fromfile(os.path.join(source, names[0]), dtype=np.uint8),
                               cv2.IMREAD_COLOR)  # 图像读取 (1080,1920,3)
            list_point = []  # 车道线关键像素
            for i in range(out_j.shape[1]):  # C 车道线数
                dots = []
                if np.sum(out_j[:, i] != 0) > 2:  # 车道线像素数大于2
                    for k in range(out_j.shape[0]):  # 遍历行row_anchor:56
                        if out_j[k, i] > 0:
                            point = (int(out_j[k, i] * grid * img_w / 800) - 1,
                                     int(img_h * (row_anchor[opt.num_row_anchors - 1 - k] / 288)) - 1)
                            cv2.circle(img, point, 5, (0, 0, 255), -1)  # 在原始图像描述关键点

            if save_video:
                vout.write(img)  # 保存视频结果
            else:
                # 保存检测结果图
                cv2.imwrite(os.path.join(savepath, os.path.basename(names[0])), img)
        if save_video:vout.release()

if __name__ == "__main__":
	import torch.backends.cudnn
    torch.backends.cudnn.benchmark = True  # 加速
    opt = parse_opt()  # 指定参数
    run(opt)

终端运行测试指令

python demo_custom --model model_path --source img_path --savepath output

测试结果如下图所示
请添加图片描述

5 总结

笔者自定的代码如下:

Ultra-Fast-Lane-Detection-master   # 项目源码
├─utils_alian  # 自定义文件夹
     ├─check_label_num.py  # 检查标签数量
     ├─config.py  # 训练参数配置
     ├─dataloader_alian.py  # 数据加载器
     ├─dataset_alian.py  # 数据集构建
     └─instance.py  # 构建实例数据
├─demo_custom.py  # 自定义模型测试代码
│      
└─train_alian.py  # 训练代码

其余的直接调用项目中的源码文件
笔者将上述文件分享再百度云盘中,需要的可以自行下载,但一定要放在指定的位置哦!
百度云盘分享链接
提取码:UFLD

  • 6
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 17
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 17
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值