Window中使用FairMOT对KITTI中汽车类进行训练

数据格式的转变

1 将KITTI数据集中的标签文件转化为MOT中gt.txt格式,将Type又字符变为数字表示,空格用逗号替代

原标签文件如下:
原始标签文件
转换代码如下:

import os
import numpy as np
import pandas as pd
import os.path as osp


def replace(file, old_content, new_content):
    content = read_file(file)
    content = content.replace(old_content, new_content)
    rewrite_file(file, content)


# 读文件内容
def read_file(file):
    with open(file, encoding='UTF-8') as f:
        read_all = f.read()
        f.close()

    return read_all


# 写内容到文件
def rewrite_file(file, data):
    with open(file, 'w', encoding='UTF-8') as f:
        f.write(data)
        f.close()


src_data = r'G:\DataSet\data\training\label_02'

seqs = [s for s in os.listdir(src_data)]
# print(seqs):["0000.txt","0001.txt","0002.txt",...]
for seq in seqs:
    path = osp.join(src_data, seq)
    # seq_gt_path = osp.join(src_data, seq, 'gt/gt.txt')
    # print(seq_gt_path)
    # gt = np.loadtxt(seq_gt_path, dtype=np.str, delimiter=',')  # 加载成np格式
    # print(str(gt))
    replace(path, ' ', ',')
    replace(path, 'DontCare', '10')
    replace(path, 'Person', '1')
    replace(path, 'Pedestrian', '2')
    replace(path, 'Car', '3')
    replace(path, 'Person_sitting', '4')
    replace(path, 'Cyclist', '5')
    replace(path, 'Van', '6')
    replace(path, 'Truck', '7')
    replace(path, 'Tram', '8')
    replace(path, 'Misc', '9')

并将转换后的文件复制到对应的训练数据集中:如image2/train/0001/gt/gt.txt中,首先需要建立对应的文件夹,代码如下:

import os
from shutil import copy

def mkdirs(d):
    # if not osp.exists(d):
    if not os.path.isdir(d):
        os.makedirs(d)

seqs = ['0000', '0001', '0002', '0003',
        '0004', '0005', '0006', '0007',
        '0008', '0009', '0010', '0011',
        '0012', "0013", '0014', '0015', "0016", "0017", '0018',
        '0019', '0020']
dst_dir_root = r"G:\DataSet\data\kitti\MOT\images\train"
src_data = r'G:\DataSet\data\training\label_02'
for seq in seqs:
	gt = "gt"
    path = os.path.join(dst_dir_root , seq,gt)
    mkdir(path)
 for root, dirs, files in os.walk(src_data ):
 	for file in files:
 		src_file_path = os.path.join(src_data, file)
 		copy(src_file_path , os.path.join(dst_dir_root, os.path.splitext(file)[0], "gt", "gt.txt"))

最后效果如下:
在这里插入图片描述
标签文件中每列代表的含义

kitti tracking 标签含义:
The label files contain the following information, which can be read and
written using the matlab tools (readLabels.m) provided within this devkit. 
All values (numerical or strings) are separated via spaces, each row 
corresponds to one object. The 17 columns represent:

#Values    Name      Description
----------------------------------------------------------------------------
   1    frame        Frame within the sequence where the object appearers
   1    track id     Unique tracking id of this object within this sequence
   1    type         Describes the type of object: 'Car', 'Van', 'Truck',
                     'Pedestrian', 'Person_sitting', 'Cyclist', 'Tram',
                     'Misc' or 'DontCare'
   1    truncated    Integer (0,1,2) indicating the level of truncation.
                     Note that this is in contrast to the object detection
                     benchmark where truncation is a float in [0,1].
   1    occluded     Integer (0,1,2,3) indicating occlusion state:
                     0 = fully visible, 1 = partly occluded
                     2 = largely occluded, 3 = unknown
   1    alpha        Observation angle of object, ranging [-pi..pi]
   4    bbox         2D bounding box of object in the image (0-based index):
                     contains left, top, right, bottom pixel coordinates
   3    dimensions   3D object dimensions: height, width, length (in meters)
   3    location     3D object location x,y,z in camera coordinates (in meters)
   1    rotation_y   Rotation ry around Y-axis in camera coordinates [-pi..pi]
   1    score        Only for results: Float, indicating confidence in
                     detection, needed for p/r curves, higher is better.

2 加下来根据gt.txt,对每个数据集的每张图片生成对应的标签文件:

对了再次之前自己再每个数据集如0001中仿照MOT的格式制作seqinfo.ini文件
代码如下:

"""
按照FairMOT数据集的格式根据“gt.txt"给每个图片打上对应的标签
kitti数据集标签生成(Car或Van或Truck)
"""

import os.path as osp
import os
import shutil
import numpy as np


def mkdirs(d):
    # if not osp.exists(d):
    if not osp.isdir(d):
        os.makedirs(d)


data_root = r'G:/DataSet/data/kitti/'
seq_root = data_root + 'MOT/images/train'
label_root = data_root + 'MOT/labels_with_ids/train'

if not os.path.isdir(label_root):
    mkdirs(label_root)
else:  # 如果之前已经生成过: 递归删除目录和文件, 重新生成目录
    shutil.rmtree(label_root)
    os.makedirs(label_root)

cls_map = {
    'Person=1'
    'Pedestrian=2'
    'Car=3'
    'Person_sitting=4'
    'Cyclist=5'
    'Van=6'
    'Truck=7'
    'Tram=8'
    'Misc=9'
    'DontCare=10'
}

print("Dir %s made" % label_root)
# seqs = [s for s in os.listdir(seq_root)]
# seqs=['0000']

seqs = ['0000', '0001', '0002', '0003',
        '0004', '0005', '0006', '0007',
        '0008', '0009', '0010', '0011',
        '0012', "0013", '0014', '0015', "0016", "0017", '0018',
        '0019', '0020']
# 打印序列
print(seqs)

tid_curr = 0
tid_last = -1
total_track_id_num = 0  # 计算数据集中的所有跟踪目标的数目
for seq in seqs:  # 每段视频都对应一个gt.txt
    print("Process %s, " % seq, end='')
    # seq_root = data_root + 'MOT/images/train'
    # label_root = data_root + 'MOT/labels_with_ids/train'
    seq_info_path = osp.join(seq_root, seq, 'seqinfo.ini')  # 提取每个数据的info信息 /media/ckq/data/kitti/MOT/images/train
    # print(seq_info_path)
    with open(seq_info_path) as seq_info_h:  # 读取 *.ini 文件
        seq_info = seq_info_h.read()
        seq_width = int(seq_info[seq_info.find('imWidth=') + 8:seq_info.find('\nimHeight')])  # 视频的宽
        seq_height = int(seq_info[seq_info.find('imHeight=') + 9:seq_info.find('\nimExt')])  # 视频的高
        # print('seq_width:',seq_width)
        # print('seq_height:', seq_height)

    gt_txt = osp.join(seq_root, seq, 'gt', 'gt.txt')  # 读取GT文件
    # print(gt_txt)  #打印路径
    # gt = np.loadtxt(gt_txt, dtype=np.str, delimiter=',')  # 加载成np格式

    gt = np.loadtxt(gt_txt, dtype=np.float64, delimiter=',')  # 加载成np格式
    print(gt)  # 打印文本内容
    print('gt.T')
    print(gt.T)  # 也是打印文本内容
    idx = np.lexsort(gt.T[:2, :])  # 优先按照track id排序(对视频帧进行排序, 而后对轨迹ID进行排序)
    # print(idx)
    gt = gt[idx, :]

    tr_ids = set(gt[:, 1])
    print("%d track ids in seq %s" % (len(tr_ids), seq))
    total_track_id_num += len(tr_ids)  # track id统计数量如何正确计算?

    seq_label_root = osp.join(label_root, seq, 'img1')
    mkdirs(seq_label_root)

    # 读取GT数据的每一行(一行即一条数据)
    # for fid, tid, x, y, w, h, mark, cls, vis_ratio in gt:
    for fid, tid, type, truncated, occluded, alpha, \
        bbox_left, bbox_top, bbox_right, bbox_bottom, _, _, _, _, _, _, _ in gt:
        # height, width, length , location_x,location_y,location_z , rotation_y in gt:
        # frame_id, track_id, top, left, width, height, mark, class, visibility ratio
        # if cls != 3:  # 我们需要Car的标注数据
        if type != 3:  # 我们需要Car的标注数据
            continue

        # if mark == 0:  # mark为0时忽略(不在当前帧的考虑范围)
        #     continue

        # if vis_ratio <= 0.2:
        #     continue

        fid = int(fid)
        tid = int(tid)

        # 判断是否是同一个track, 记录上一个track和当前track
        if not tid == tid_last:  # not 的优先级比 == 高
            tid_curr += 1
            tid_last = tid
        # 由于kitti标签与训练标签参数有点不同 需要自己计算 x y w h
        w = float(bbox_right - bbox_left)
        h = float(bbox_bottom - bbox_top)
        x = int(bbox_left + 0.5)
        y = int(bbox_top + 0.5)

        # bbox中心点坐标
        x += w / 2
        y += h / 2

        # 网label中写入track id, bbox中心点坐标和宽高(归一化到0~1)
        # 第一列的0是默认只对一种类别进行多目标检测跟踪(0是类别)
        label_str = '0 {:d} {:.6f} {:.6f} {:.6f} {:.6f}\n'.format(
            tid_curr,
            x / seq_width,  # center_x
            y / seq_height,  # center_y
            w / seq_width,  # bbox_w
            h / seq_height)  # bbox_h
        # print(label_str.strip())

        label_f_path = osp.join(seq_label_root, '{:06d}.txt'.format(fid))
        with open(label_f_path, 'a') as f:  # 以追加的方式添加每一帧的label
            f.write(label_str)

print("Total %d track ids in this dataset" % total_track_id_num)
print('Done')

最终生成标签文件格式如下:
在这里插入图片描述
备注:有些数据集中某些图片中不包含我们需要得汽车类,所以没有生成对应得标签文件,如0001中0-108中不包含汽车标签

3 测试转换后得标签文件是否正确:

效果如下:
在这里插入图片描述

代码如下:

# -*- coding:utf-8 -*-
import os
import cv2
import os.path as osp

'''
显示跟踪训练数据集标注
时间:2020年9月27日
验证转换后的标签是否成功
'''
root_path = r"G:\DataSet\data\kitti\MOT"
img_dir = "images/train"  # 图片数据集位置
label_dir = "labels_with_ids/train"  # label数据集位置

imgs = os.listdir(root_path + "/" + img_dir)  # 遍历图片数据集列表  0000 0001........
imgs.sort()
for i, img in enumerate(imgs):  # 一个一个遍历
    # img_name = img[:-1]   #img[:-1] -1代表从右往左 第一个不取
    # print(img)
    img_name = img  # 每个图片集名字 如:0000
    print(img_name)
    # 例如:G:\DataSet\data\kitti\MOT\labels_with_ids\train\0000\img1
    label_path = osp.join(root_path, label_dir, img_name, 'img1')
    print(label_path)
    # 列出单个子图片集中每个图片对应的标签
    label_gts_name = os.listdir(label_path)
    label_gts_name.sort()
    print(label_gts_name)
    # 对单个子图片数据集的标签文件进行遍历
    for frame_gt in label_gts_name:
        # print(frame_gt)
        frame_gt_name = frame_gt[:6]
        # print(frame_gt_name)
        label_f = open(label_path + "/" + frame_gt_name + ".txt", "r")  # 路劲标签名
        # print(label_f)
        lines = label_f.readlines()
        print(lines)
        # print(root_path + "/" + img_dir + "/" + img+"/img/"+frame_gt_name) 目前为None,目前需要把图片整过来
        img_data = cv2.imread(root_path + "/" + img_dir + "/" + img + "/img1/" + frame_gt_name + ".png")  # gt对应的图片序号
        # print(img_data)
        H, W, C = img_data.shape
        # print(H)
        # print(W)
        # print(C)
        for line in lines:
            line_list = line.strip().split()
            class_num = int(line_list[0])  # 类别号
            obj_ID = int(line_list[1])  # 目标ID
            x, y, w, h = line_list[2:]  # 中心坐标,宽高(经过原图宽高归一化后)
            x = int(float(x) * W)
            y = int(float(y) * H)
            w = int(float(w) * W)
            h = int(float(h) * H)
            left = int(x - w / 2)
            top = int(y - h / 2)
            right = left + w
            bottom = top + h
            cv2.circle(img_data, (x, y), 1, (0, 0, 255))
            cv2.rectangle(img_data, (left, top), (right, bottom), (0, 255, 0), 2)
            cv2.putText(img_data, str(obj_ID), (left, top), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 255), 1)
        resized_img = cv2.resize(img_data, (800, 416))
        cv2.imshow("label", resized_img)
        cv2.waitKey(100)
        print("加载成功")

4 最后生成训练文件

注意需要根据标签文件生成训练文件
代码如下:

import os
import os.path as osp

image_flder = r"G:\DataSet\data\kitti\MOT\labels_with_ids\train" # 标签文件夹,与图片文件夹对应
imgs = os.listdir(image_flder) 
# print(imgs)
train_f = open(r"G:/DataSet/data/kitti/MOT/kitti_car.train", "w")

for img_name in imgs:
    image_path = osp.join(image_flder, img_name, 'img1')
    print(image_path)
    image_names = os.listdir(image_path)  # 各个图片的名字
    image_names.sort()
    print(image_names)
    for image_name in image_names:
        relative_path = "MOT/images/train"
        relative_path = os.path.join(relative_path, img_name, "img1")
        save_str = relative_path + '/' + os.path.splitext(image_name)[0] + ".png" + "\n"
        print(save_str)
        train_f.write(save_str)

train_f.close()

参考链接

FairMOT训练kitti tracking数据集的汽车类

评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值