自己写的的代码,本人水平实在有限欢迎需要的人继续优化(要是能教教我咋改就更好了)目前在我自己电脑上能跑起来,速度能满足需求。
直接上代码:
import cv2
from demo import *
import torch
cap = cv2.VideoCapture(0)
def predict(x):
x_ = transforms.ToTensor()(x)
x1 = transforms.Resize((288, 800))(x_)
x2 = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))(x1)
x3 = x2.unsqueeze(0).cuda() + 1
out = net(x3)
# net需要的是(1, 3, 288, 800)
# out(1, 101, 56, 4)
col_sample = np.linspace(0, 800 - 1, cfg.griding_num)
col_sample_w = col_sample[1] - col_sample[0] # 不知道干啥的
out_j = out[0].data.cpu().numpy()
out_j = out_j[:, ::-1, :]
prob = scipy.special.softmax(out_j[:-1, :, :], axis=0)
idx = np.arange(cfg.griding_num) + 1
idx = idx.reshape(-1, 1, 1)
loc = np.sum(prob * idx, axis=0)
out_j = np.argmax(out_j, axis=0)
loc[out_j == cfg.griding_num] = 0
out_j = loc #一样不知道干啥的
vis = x
for i in range(out_j.shape[1]):
if np.sum(out_j[:, i] != 0) > 2:
for k in range(out_j.shape[0]):
if out_j[k, i] > 0:
ppp = (int(out_j[k, i] * col_sample_w * img_w / 800) - 1,
int(img_h * (row_anchor[cls_num_per_lane - 1 - k] / 288)) - 1)
cv2.circle(vis, ppp, 5, (0, 255, 0), -1)
return vis
def windows():
while(cap.isOpened()):
retval, frame = cap.read() # 读进来的图片是(480, 640, 3)ndarry格式
output = predict(frame)
cv2.imshow('Live', output)
if cv2.waitKey(5) >= 0:
break
if __name__ == '__main__':
torch.backends.cudnn.benchmark = True
args, cfg = merge_config()
# dist_print('start testing...')
assert cfg.backbone in ['18', '34', '50', '101', '152', '50next', '101next', '50wide', '101wide']
if cfg.dataset == 'CULane':
cls_num_per_lane = 18
elif cfg.dataset == 'Tusimple':
cls_num_per_lane = 56
else:
raise NotImplementedError
net = parsingNet(pretrained=False, backbone=cfg.backbone, cls_dim=(cfg.griding_num + 1, cls_num_per_lane, 4),
use_aux=False).cuda() # we dont need auxiliary segmentation in testing
state_dict = torch.load(cfg.test_model, map_location='cpu')['model']
compatible_state_dict = {}
for k, v in state_dict.items():
if 'module.' in k:
compatible_state_dict[k[7:]] = v
else:
compatible_state_dict[k] = v
net.load_state_dict(compatible_state_dict, strict=False)
net.eval()
img_w, img_h = 640, 480
row_anchor = tusimple_row_anchor
windows()
设备:Windows系统笔记本电脑
前提是已经配好环境了啊! 好多博主都有复现的教程
操作流程:
1复制这代码
2.在Ultra-Fast-Lane-Detection-master目录下新建一个.py文件
3.粘贴
4.把这行代码加到形参里
configs/tusimple.py --test_model tusimple_18.pth
5.点运行
目前只做了tusimple数据集的实时显示,culane数据集实时显示原理也差不多,稍微改改就能用。
检测结果如下: