本科阶段最后一次竞赛Vlog——2024年智能车大赛智慧医疗组准备全过程——6Resnet实现黑线识别
比赛还有重要部分就是黑线的识别,这块地平线社区的帖子很多
本次我参考了社区吴超大佬写的文章,当然我们的步骤有所不同,也算是比较省事的一种做法
现在和大家一起聊聊我们的应用
1.代码介绍
社区上的代码主要是这几个py文件;其中数据制作部分我们采用了自己的方案,只是简单地用了OpenCV
2.准备工作
准备工作就是创建好数据集
这里建议直接运行超哥的代码
# Create the train/test image and label folders of the dataset.
DATASET_NMAE = "DataSet_3_1119"  # dataset name

from os import makedirs

path = "./" + DATASET_NMAE + "/"
try:
    for sub_dir in ("train/image/", "train/label/", "test/image/", "test/label/"):
        # exist_ok=True: running this step again must not fail just because
        # the folders were already created by an earlier run.
        makedirs(path + sub_dir, exist_ok=True)
    print("Dirs Success")
except OSError as err:
    # Only filesystem errors are expected here; show the cause instead of hiding it.
    print("Dirs Failed", err)
3. 录制视频
第一步是获取512*512大小的图片:这里我们先用OpenCV录制视频,再用代码抽帧并缩放到该尺寸
小车上摄像头的设备编号是8(即代码中的cv2.VideoCapture(8)),录制时按下Ctrl+C就会停止录制并自动保存视频
import cv2
import signal
import sys
def signal_handler(sig, frame):
    """Ask the recording loop to stop.

    Both arguments (signal number and interrupted stack frame) are accepted
    to match the signal-handler calling convention and are not used.
    """
    global stop_recording
    # The recording loop polls this flag after every frame.
    stop_recording = True
def record_video(output_file, width=640, height=480, fps=30, camera_index=8):
    """Record camera frames into ``output_file`` until asked to stop.

    Recording ends when the camera stops delivering frames, when SIGINT
    (Ctrl+C) sets the module-level ``stop_recording`` flag, or when ``q`` is
    read by ``cv2.waitKey``.

    Args:
        output_file: Path of the video file to write (XVID codec).
        width: Frame width handed to the writer and requested from the camera.
        height: Frame height handed to the writer and requested from the camera.
        fps: Frame rate handed to the writer and requested from the camera.
        camera_index: Index passed to ``cv2.VideoCapture``; defaults to 8.
    """
    global stop_recording
    stop_recording = False

    cap = cv2.VideoCapture(camera_index)
    if not cap.isOpened():
        print(f"Error: Could not open camera {camera_index}.")
        cap.release()
        return

    # Request the geometry and rate we are going to write; a driver may ignore
    # these, which is why the actual rate is printed and frames are checked below.
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
    cap.set(cv2.CAP_PROP_FPS, fps)
    actual_fps = cap.get(cv2.CAP_PROP_FPS)
    print(f"实际帧率: {actual_fps}")

    fourcc = cv2.VideoWriter_fourcc(*'XVID')  # XVID encoder
    out = cv2.VideoWriter(output_file, fourcc, fps, (width, height))
    if not out.isOpened():
        print(f"Error: Could not open video writer for {output_file}.")
        cap.release()
        out.release()
        return

    # Ctrl+C only raises the stop flag, so the loop can exit cleanly and the
    # file is finalized by the release calls below.
    signal.signal(signal.SIGINT, signal_handler)
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # The writer was opened for (width, height); bring any frame of a
            # different size to that size before writing it.
            if (frame.shape[1], frame.shape[0]) != (width, height):
                frame = cv2.resize(frame, (width, height))
            out.write(frame)
            if stop_recording:
                print("录制已结束")
                break
            # NOTE(review): no window is shown here, so the 'q' key is only
            # seen if a HighGUI window exists elsewhere - confirm.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Always release, otherwise the output file may be left unfinished.
        cap.release()
        out.release()
        cv2.destroyAllWindows()
# Run the recorder: frames from the camera are written to output.avi
record_video('output.avi', width=640, height=480, fps=30)
4. 01_get512img
把小车录制的视频保存到本地后,使用下面的代码进行抽帧
代码里已经给大家写好注释了
from datetime import datetime
import cv2
import os
def extract_frames(video_path, output_folder, i, frame_size=(512, 512), skip_frames=3):
    """Save every ``skip_frames``-th frame of a video as a resized PNG.

    Files are written as ``<output_folder>/<i>_frame_<NNNN>.png`` where
    ``NNNN`` counts the frames saved so far.

    Args:
        video_path: Path of the video file to read.
        output_folder: Folder the PNG files go to; created when missing.
        i: Prefix put in front of every file name, so frames coming from
            different videos do not overwrite each other.
        frame_size: ``(width, height)`` each saved frame is resized to.
        skip_frames: Keep one frame out of this many; must be at least 1.

    Raises:
        ValueError: If ``skip_frames`` is smaller than 1.
    """
    # Checked first: 0 would otherwise fail with ZeroDivisionError in the loop,
    # after the folder and the capture had already been set up.
    if skip_frames < 1:
        raise ValueError("skip_frames must be a positive integer")

    os.makedirs(output_folder, exist_ok=True)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        cap.release()
        return

    frame_count = 0
    saved_frame_count = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break  # end of video or read error

            if frame_count % skip_frames == 0:
                resized_frame = cv2.resize(frame, frame_size, interpolation=cv2.INTER_AREA)
                frame_filename = f"{output_folder}/{i}_frame_{saved_frame_count:04d}.png"
                # imwrite reports failure through its return value.
                if cv2.imwrite(frame_filename, resized_frame):
                    print(f"Saved {frame_filename}")
                    saved_frame_count += 1
                else:
                    print(f"Error: Could not write {frame_filename}")
            frame_count += 1
    finally:
        # Release the capture even if resizing or writing raised.
        cap.release()
    print("Done extracting frames.")
# Usage example: extract frames from videos 1.avi .. 5.avi into the training image folder.
output_folder = './DataSet_3_1119/train/image'
for i in range(1, 6):
    video_path = rf'D:\cardata\{i}.avi'
    # Keep one frame out of every 5, resized to 512x512.
    # For 672x672 frames pass frame_size=(672, 672) instead.
    extract_frames(video_path, output_folder, i, frame_size=(512, 512), skip_frames=5)
5. 进行标记
如果前面是按照我贴出的代码操作的,这里就可以直接运行超哥的标注代码:鼠标左键点击黑线位置即可保存标签,按 a/d 切换上一张/下一张,按 z/c 快退/快进5张,按 q 退出
import cv2
import os
DATASET_NMAE = "DataSet_3_1119" # dataset name
ZOOM = 1.0 # display zoom factor; it does not affect the saved labels and only helps on high-DPI screens
def get_xy(file_path):
    """Read a label file and return its ``(x, y)`` pair as floats.

    The file is expected to start with two whitespace-separated numbers.

    Args:
        file_path: Path of the label ``.txt`` file.

    Returns:
        Tuple ``(x, y)`` of floats.

    Raises:
        ValueError: If the file holds fewer than two fields or a field is
            not a number.
    """
    with open(file_path) as f:
        # split() without an argument tolerates repeated spaces, tabs and a
        # trailing newline; split(" ") produced empty fields for those.
        fields = f.read().split()
    if len(fields) < 2:
        raise ValueError("label file %s does not contain two values" % file_path)
    return float(fields[0]), float(fields[1])
def mouse_callback(event, x, y, flags, param):
    """Save the clicked point as the label of the image currently shown.

    Reacts only to a left-button release. The click position is divided by
    the image size and by ZOOM, stored in the globals ``img_x`` / ``img_y``,
    drawn on a copy of the displayed image and written to
    ``label_path + txt_name``.
    """
    global img_x, img_y
    if event != cv2.EVENT_LBUTTONUP:
        return

    print(img_width,img_height)
    # x and y are window coordinates, i.e. already scaled by ZOOM.
    norm_x = float(x)/img_width/ZOOM
    norm_y = float(y)/img_height/ZOOM
    img_x, img_y = norm_x, norm_y

    marked = cv2.circle(img.copy(), (x, y), 10,(0,0,255), -1)
    cv2.imshow("img", marked)
    print("Mouse Click(%d, %d), Save as(%.8f, %.8f)"%(x,y,img_x,img_y))

    with open(label_path + txt_name,"w") as label_file:
        label_file.write("%.8f %.8f"%(img_x, img_y))
# Create the cv2 window and bind the mouse-click callback to it.
img_x, img_y = -1,-1
cv2.namedWindow('img')
cv2.setMouseCallback('img', mouse_callback)
img_path = DATASET_NMAE + "/train/image/"
label_path = DATASET_NMAE + "/train/label/"
print("img path = %s"%img_path)
print("label path = %s"%label_path)
# Sorted so the keys step through the frames in a stable order;
# os.listdir returns entries in arbitrary order.
img_names = sorted(os.listdir(img_path))
# img size
img_width, img_height = 0, 0
# img control
img_control = 0
img_control_min = 0
img_control_max = len(img_names) - 1
# Navigation keys and how many images each one moves:
# a / d = one back / forward, z / c = five back / forward.
steps = {ord('a'): -1, ord('d'): 1, ord('z'): -5, ord('c'): 5}
if not img_names:
    print("No images found in %s"%img_path)
while img_names:
    name = img_names[img_control]
    print(name, end=" ")
    img = cv2.imread(img_path + name)
    if img is None:
        # imread returns None for files it cannot decode; drop the entry so
        # the loop neither crashes on it nor gets stuck on it.
        print("\033[31m" + "Unreadable image, skipped" + "\033[0m")
        del img_names[img_control]
        img_control_max = len(img_names) - 1
        img_control = max(img_control_min, min(img_control, img_control_max))
        continue
    img_height, img_width = img.shape[:2]
    img = cv2.resize(img, (0,0), fx=ZOOM, fy=ZOOM)
    cv2.imshow("img", img)
    print("height = %d, width = %d"%(img_height, img_width), end=" ")
    ## Draw the saved point when a label file exists, otherwise draw nothing
    txt_name = name.split(".")[0] + ".txt"
    if os.path.isfile(label_path + txt_name):
        img_x, img_y = get_xy(label_path + txt_name)
        cv2.imshow("img", cv2.circle(img.copy(), (int(ZOOM*img_width*img_x), int(ZOOM*img_height*img_y)), 10,(0,0,255), -1))
        print("\033[32;40m" + "Label Exist" + "\033[0m" + ": x = %.8f, y = %.8f"%(img_x, img_y))
    else:
        print("\033[31m" + "NO Label" + "\033[0m")
    ## Keyboard control of the loop
    command = cv2.waitKey(0) & 0xFF
    if command == ord('q'):
        break
    if command not in steps:
        print("Unknown Command")
        continue
    # Move, then clamp to the valid index range.
    img_control += steps[command]
    if img_control < img_control_min:
        img_control = img_control_min
        print("First img already")
    elif img_control > img_control_max:
        img_control = img_control_max
        print("Last img already")
cv2.destroyAllWindows()
6. 删除多余标签
上面打完标签后会生成很多txt,但有些图片里并没有黑线,因此不会被标注
这就会导致图片与txt不匹配,这里给大家写了一段代码,用来删除没有对应标签的多余图片
import os
def delete_unlabeled_images(image_dir, label_dir):
    """Delete every ``.png`` in ``image_dir`` that has no label file.

    An image ``name.png`` is kept when ``label_dir`` contains ``name.txt``;
    otherwise the image file is removed and its path is printed. Files that
    do not end in ``.png`` are left alone.
    """
    png_names = [entry for entry in os.listdir(image_dir) if entry.endswith(".png")]
    for png_name in png_names:
        stem, _ = os.path.splitext(png_name)
        if os.path.exists(os.path.join(label_dir, stem + ".txt")):
            continue  # labeled image: keep it
        img_path = os.path.join(image_dir, png_name)
        os.remove(img_path)
        print(f"Deleted image without label: {img_path}")
# Dataset root folder (the original "./taSet_3_1119" was a typo and pointed at
# a folder that the earlier steps never create)
DATASET_NAME = "./DataSet_3_1119"
# Image and label folders of the training split
img_path = os.path.join(DATASET_NAME, "train/image")
label_path = os.path.join(DATASET_NAME, "train/label")
# Delete the images that have no matching label file
delete_unlabeled_images(img_path, label_path)
7. 划分数据
上面打完标签后会生成很多txt,不过目前图片都还只在DataSet_3_1119/train/image这个目录下
我们现在需要把一部分图片及其对应标签划分到test里面
只要前面是按照上面的步骤做的,这里仍然可以直接运行下面的代码
# Move a random share of the labeled training images (and their labels) to the test split.
DATASET_NMAE = "DataSet_3_1119"  # dataset name
test_percent = 0.25  # 0.25 means 25% of the images become the test set

from random import sample
from shutil import move
from os import listdir
from os.path import isfile, splitext

path = "./" + DATASET_NMAE + "/"
train_image = "train/image/"
train_label = "train/label/"
test_image = "test/image/"
test_label = "test/label/"
images_names = listdir(path + train_image)
# Sample and move
test_number = int(len(images_names)*test_percent)
test_names = sample(images_names, test_number)
for name in test_names:
    stem = splitext(name)[0]
    image_old = path + train_image + name
    image_path = path + test_image + name
    label_old = path + train_label + stem + ".txt"
    label_path = path + test_label + stem + ".txt"

    # Move the image
    print(image_old, end=" ")
    if not isfile(label_old):
        # Moving an image whose label is missing would leave the test split
        # with an image that has no label, so the pair is left where it is.
        print("\033[31m" + "Failed! " + "\033[0m" + "(no label file)")
        continue
    try:
        move(image_old, image_path)
    except OSError as err:  # shutil.Error is a subclass of OSError
        print("\033[31m" + "Failed! " + "\033[0m", err)
        continue  # image stayed in train, so its label must stay too
    print("\033[32;40m" + "Success." + "\033[0m")

    # Move the label
    print(label_old, end=" ")
    try:
        move(label_old, label_path)
    except OSError as err:
        print("\033[31m" + "Failed! " + "\033[0m", err)
    else:
        print("\033[32;40m" + "Success." + "\033[0m")
8.训练
这个时候,就可以真正训练了
## 此Python脚本在开发机上运行 ##
# Step 5
# 训练ResNet18
# 如果是第一次训练会自动下载预训练权重,约40MB
# 训练结束后会在当前目录下生成一个名为BEST_MODEL_PATH的模型文件
# CPU就能训练,我的R7-4800H约12秒一个Epoch,不会太慢
import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
import glob
import PIL.Image
import os
import numpy as np
from threading import Thread
from time import time, sleep
DATASET_NMAE = "DataSet_3_1119" # dataset name
BEST_MODEL_PATH = './model_best1000.pth' # file the best training result is saved to
BATCH_SIZE = 256
NUM_EPOCHS = 1000 # number of training epochs
def main(args=None):
    """Train a ResNet18 that regresses an (x, y) point and keep the best weights.

    Loads the ``train`` and ``test`` splits of the dataset with ``XYDataset``,
    replaces the final layer of a pretrained ResNet18 with a 2-output linear
    layer and trains it with Adam on the MSE loss for ``NUM_EPOCHS`` epochs.
    Whenever the mean test loss improves, the model's ``state_dict`` is saved
    to ``BEST_MODEL_PATH``.

    Args:
        args: Unused.

    Raises:
        ValueError: If either split contains no images.
    """
    best_loss = 1e9
    train_image = "./" + DATASET_NMAE + "/train/"
    test_image = "./" + DATASET_NMAE + "/test/"
    train_dataset = XYDataset(train_image)
    test_dataset = XYDataset(test_image)
    # Fail early with a clear message; an empty split otherwise surfaces as an
    # unrelated-looking error from the loader or from the loss averaging.
    if len(train_dataset) == 0 or len(test_dataset) == 0:
        raise ValueError(
            "empty dataset: %d train images in %s, %d test images in %s"
            % (len(train_dataset), train_image, len(test_dataset), test_image))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0
    )
    # Order does not matter for evaluation, so the test split is not shuffled.
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=0
    )

    # Pretrained ResNet18 with the fc layer replaced by a 2-output layer
    # (the x and y coordinates).
    model = models.resnet18(pretrained=True)
    model.fc = torch.nn.Linear(512, 2)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    model = model.to(device)
    optimizer = optim.Adam(model.parameters())
    print("开始训练")

    for epoch in range(NUM_EPOCHS):
        print(epoch)
        epoch_time_begin = time()

        model.train()
        train_loss = 0.0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = F.mse_loss(outputs, labels)
            train_loss += float(loss)
            loss.backward()
            optimizer.step()
        train_loss /= len(train_loader)

        model.eval()
        test_loss = 0.0
        # No gradients are needed for evaluation; skipping them saves memory
        # and time.
        with torch.no_grad():
            for images, labels in test_loader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                test_loss += float(F.mse_loss(outputs, labels))
        test_loss /= len(test_loader)

        msgStr = "Epoch" + "\033[32;40m" + " %d " % epoch + "\033[0m"
        msgStr += "-> time: \033[32;40m%.3f\033[0m s, train_loss: \033[32;40m%f\033[0m, test_loss: \033[32;40m%f\033[0m" % (
            time() - epoch_time_begin, train_loss, test_loss)
        if test_loss < best_loss:
            msgStr += (" \033[31m" + " Saved" + "\033[0m")
            torch.save(model.state_dict(), BEST_MODEL_PATH)
            best_loss = test_loss
        else:
            msgStr += " Done"
        print(msgStr)
class XYDataset(torch.utils.data.Dataset):
    """Images paired with an (x, y) point label.

    Expects ``<directory>/image/*.png`` and, for every image, a text file
    ``<directory>/label/<same stem>.txt`` holding exactly two
    whitespace-separated numbers.

    Args:
        directory: Folder that contains the ``image`` and ``label`` folders.
        random_hflips: When true, each sample is mirrored horizontally with
            probability 0.5.
    """

    def __init__(self, directory, random_hflips=False):
        self.directory = directory
        self.random_hflips = random_hflips
        # Sorted so the sample order does not depend on the filesystem.
        self.image_paths = sorted(glob.glob(os.path.join(
            self.directory + "/image", '*.png')))
        self.color_jitter = transforms.ColorJitter(0.3, 0.3, 0.3, 0.3)

    def __len__(self):
        return len(self.image_paths)

    def _read_label(self, image_path):
        """Return the (x, y) floats from the label file that matches ``image_path``."""
        stem = os.path.splitext(os.path.basename(image_path))[0]
        label_path = os.path.join(self.directory + "/label", stem + ".txt")
        with open(label_path, 'r') as label_file:
            values = label_file.read().split()
        if len(values) != 2:
            # Previously this case only printed a message and then failed on
            # unassigned variables; raise an error that names the file.
            raise ValueError(
                "label file %s must contain exactly two values" % label_path)
        return float(values[0]), float(values[1])

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor)`` for sample ``idx``.

        The image is colour-jittered, resized to 224x224, converted to a
        tensor and normalized; the label is a float tensor ``[x, y]``.

        Raises:
            ValueError: If the label file does not hold exactly two values.
        """
        image_path = self.image_paths[idx]
        # The label is read first so a broken label is reported before the
        # image is decoded.
        x, y = self._read_label(image_path)
        # Force 3 channels: normalize() below is given 3-channel statistics.
        image = PIL.Image.open(image_path).convert("RGB")

        if self.random_hflips and np.random.rand() > 0.5:
            image = transforms.functional.hflip(image)
            # The labeling tool in this project saves x as a fraction of the
            # image width, so mirroring maps x to 1 - x.
            # NOTE(review): use -x instead if labels are centred on 0.
            x = 1.0 - x

        image = self.color_jitter(image)
        image = transforms.functional.resize(image, (224, 224))
        image = transforms.functional.to_tensor(image)
        image = transforms.functional.normalize(image,
                                                [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        return image, torch.tensor([x, y]).float()
# Start training only when this file is run as a script.
if __name__ == '__main__':
    main()
9.转ONNX
直接无脑运行哈哈,超超佬的代码太好用了
!!!当然无脑也得改路径
import torchvision
import torch
BEST_MODEL_PATH = r'C:\Users\jszjg\Desktop\ResNet18\model_000056_all.pth' # best training result (the weights file that gets converted)
def main(args=None):
    """Export the trained ResNet18 weights at ``BEST_MODEL_PATH`` to ONNX.

    Rebuilds ResNet18 with a 2-output final layer, loads the saved
    ``state_dict`` on the CPU and exports the model next to the weights file,
    using the same name with an ``.onnx`` extension.

    Args:
        args: Unused.
    """
    import os  # local import: the header of this script does not import os

    model = torchvision.models.resnet18(pretrained=False)
    model.fc = torch.nn.Linear(512,2)
    # NOTE(review): torch.load unpickles the file - only load checkpoints
    # that come from a trusted source.
    model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location='cpu'))
    device = torch.device('cpu')
    model = model.to(device)
    model.eval()

    # Dummy input handed to the exporter: one 3-channel 224x224 image.
    x = torch.randn(1, 3, 224, 224, requires_grad=True)
    # splitext instead of [:-4]: correct for any extension length
    # (.pt, .ckpt, ...), not only for 4-character ones such as ".pth".
    onnx_path = os.path.splitext(BEST_MODEL_PATH)[0] + ".onnx"
    torch.onnx.export(model,
                      x,
                      onnx_path,
                      export_params=True,
                      opset_version=11,
                      do_constant_folding=True,
                      input_names=['input'],
                      output_names=['output'])
# Run the export only when this file is run as a script.
if __name__ == '__main__':
    main()
10.总结与下期预告
现在按照地平线大佬的教程,训练一段时间已经可以获得一个onnx模型了
后面将把resnet转模型步骤给大家进行演示