前言
项目需要,做了一小段时间的车道线检测,复现了《Ultra Fast Structure-aware Deep Lane Detection》论文中的开源工程,记录到这篇博文中!
原理简述
在这篇论文中把车道线检测转看作行分类任务:首先设置一系列anchor,这些anchor代表y坐标值(先验信息),再由模型判断每个y值的哪一行属于车道线,获得x坐标值;最终由一些列的点(x,y)构成一条完整的车道线。更详细的论文解读可以参考原作者的博文Link。
模型训练复现
模型重训练按照作者提供的gayhub地址的参考操作流程非常快就能开启你自己的炼丹之路。
原作者定义的网络模型高达200+M,对于车道线检测任务来说这个模型太过庞大,因此需要对网络结构进行魔改,参考博文Link可以非常快速的了解这个工程,并设计专属的网络结构。
除此之外,还有非常重要的一环没有没有看到其他人做过,就是生成这个工程所需要的标签文件(以CULane数据为例):这个工程的标签文件就是与原图分辨率一致的图像,如下图所示:
原图:
标签图:由于像素值设置太小(第0条车道线对应像素值为1、第1条车道线对应像素值为2),因此肉眼无法看出
标签可视化:
还有一个需要注意的点,本车道的左右两条车道线永远对应2和3标签
CULane数据标签制作源码,设置好root路径并提前建好文件夹即可
import os
import cv2
import numpy as np
root = "Y:/data/CULane"
def draw(LabelPath, Idx):
LabelImg = np.zeros((590, 1640), dtype=np.int8)
if Idx == [0, 0, 0, 0]:
return LabelImg
else:
id = 0
color = 0
with open(LabelPath, 'r') as f:
Pst = f.readlines()
for i in range(len(Pst)):
Lane = Pst[i].strip("\n").split(" ")
# print(Lane)
l = []
for j in range(0, len(Lane) - 1, 2):
px = float(Lane[j])
py = float(Lane[j+1])
if py < 0.:
continue
else:
l.append([int(px), int(py)])
l = np.array(l)
for k in range(id, len(Idx)):
if Idx[k] != 0:
color = Idx[k]
break
else:
continue
LabelImg = cv2.polylines(LabelImg, [l], False, color, 16)
id = color
return LabelImg
with open(root + "/list/train_gt.txt", "r") as f:
trainlist = f.readlines()
for i in range(len(trainlist)):
print(i)
Idx = [0, 0, 0, 0]
ImgPath = root + trainlist[i].split(" ")[0]
LableName = trainlist[i].split(" ")[0].split("/")[-1].split(".")[0] + ".lines.txt"
LabelPath = root + trainlist[i].split(" ")[0][0:-9] + LableName
# print(ImgPath)
# print(LabelPath)
if trainlist[i].split(" ")[2] == '1':
Idx[0] = 1
else:
Idx[0] = 0
if trainlist[i].split(" ")[3] == '1':
Idx[1] = 2
else:
Idx[1] = 0
if trainlist[i].split(" ")[4] == '1':
Idx[2] = 3
else:
Idx[2] = 0
if trainlist[i].split(" ")[5].strip('\n') == '1':
Idx[3] = 4
else:
Idx[3] = 0
LabelImg = draw(LabelPath, Idx)
cv2.imwrite(root + "/laneseg_label_w16" + trainlist[i].split(" ")[0][0:-4] + ".png", LabelImg)
print(root + "/laneseg_label_w16" + trainlist[i].split(" ")[0][0:-4] + ".png")
# print("*********************")
onnx模型导出源码
import torch
import torch.onnx
from model.model import parsingNet
from torchsummary import summary
# model = torch.load("./log/all_model.pth", map_location=torch.device('cpu'))
#
# with torch.no_grad():
# model.eval()
# summary(model, (3,288,800))
# 65,18,4仅针对我自己模型
net = parsingNet(pretrained=False, backbone="18", cls_dim=(65, 18, 4), use_aux=False).cpu()
state_dict = torch.load("./out/1103/UFAST_CULane_1103.pth", map_location='cpu')['model']
compatible_state_dict = {}
for k, v in state_dict.items():
if 'module.' in k:
compatible_state_dict[k[7:]] = v
else:
compatible_state_dict[k] = v
net.load_state_dict(state_dict) # 加载模型
net.eval()
dummy_input = torch.randn(1, 3, 288, 800, device='cpu')
torch.onnx.export(net, dummy_input, "./out/1103/UFAST_1103.onnx", verbose=True, input_names=["input"], output_names=["output"])
下篇文章会详细介绍如何基于ncnn使用该模型,OVER!