【代码阅读】FairMOT多目标跟踪之数据加载(一)

1.了解数据集

【2DMOT15数据集为例】

2.数据集的loader

详细的:dataloader是一个迭代器,在FairMOT数据加载中定义了__getitem__,__next__,__iter__

2.1加载图片

初始化部分:

class LoadImages:  # for inference
    def __init__(self, path, img_size=(1088, 608)):
        if os.path.isdir(path):#给定数据集的地址,做有效性判断
            image_format = ['.jpg', '.jpeg', '.png', '.tif']
            self.files = sorted(glob.glob('%s/*.*' % path))#找到文件夹中的图片,并sorted
            self.files = list(filter(lambda x: os.path.splitext(x)[1].lower() in image_format, self.files))#
        """
        a=filter(function, iterable),当iterable中的元素满足function中的条件,就放到a中
        所以上述:当找到的图片文件的文件名,图片后缀为小写字母?才选择加载到图片list中。
        """
        elif os.path.isfile(path):#如果直接给的是图片地址,那么直接加载进图片的list中去
            self.files = [path]

        self.nF = len(self.files)  # number of image files 图片的总数
        self.width = img_size[0] 
        self.height = img_size[1] # 图片尺寸大小
        self.count = 0

        assert self.nF > 0, 'No images found in ' + path # 如果没有图片,就提示有错误啦

迭代器定义部分:

1.__iter__和__next__合在一起是一种实现方式:
    def __iter__(self):
        self.count = -1 #使用数据集开始将计数器置为-1
        return self

    def __next__(self):
        self.count += 1  #开始图片计数,从0开始
        if self.count == self.nF: #如果迭代完所有的图片,就抛出提醒
            raise StopIteration
        img_path = self.files[self.count]#加载图片的地址

        # Read image
        img0 = cv2.imread(img_path)  # BGR
        assert img0 is not None, 'Failed to load ' + img_path

        # Padded resize使用letterbox进行resize:保持图片的长宽比例,剩下的部分采用pad填充
        img, _, _, _ = letterbox(img0, height=self.height, width=self.width)

        # Normalize RGB 图片转RGB并归一化
        img = img[:, :, ::-1].transpose(2, 0, 1)
        img = np.ascontiguousarray(img, dtype=np.float32)#?一个内存不连续存储的数组转换为内存连续存储的数组,使得运行速度更快
        img /= 255.0

        # cv2.imwrite(img_path + '.letterbox.jpg', 255 * img.transpose((1, 2, 0))[:, :, ::-1])  # save letterbox image
        return img_path, img, img0 #返回图片的地址,resize并归一化的img以及原图

2.__getitem__独立实现一个迭代操作

    def __getitem__(self, idx):
        idx = idx % self.nF
        img_path = self.files[idx]
        #加载图片的步骤跟__next__一样
        # Read image
        img0 = cv2.imread(img_path)  # BGR
        assert img0 is not None, 'Failed to load ' + img_path

        # Padded resize
        img, _, _, _ = letterbox(img0, height=self.height, width=self.width)

        # Normalize RGB
        img = img[:, :, ::-1].transpose(2, 0, 1)
        img = np.ascontiguousarray(img, dtype=np.float32)
        img /= 255.0

        return img_path, img, img0

2.2加载视频

初始化部分:

class LoadVideo:  # for inference
    def __init__(self, path, img_size=(1088, 608)):
        #获得视频信息
        self.cap = cv2.VideoCapture(path)
        self.frame_rate = int(round(self.cap.get(cv2.CAP_PROP_FPS)))#获得视频帧率
        self.vw = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.vh = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))#获得视频帧的尺寸
        self.vn = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))#总的图片帧数

        self.width = img_size[0]
        self.height = img_size[1]
        self.count = 0

        self.w, self.h = 1920, 1080
        print('Lenth of the video: {:d} frames'.format(self.vn))

迭代器定义部分:

    def __iter__(self):
        self.count = -1 #同样的计数器重置
        return self

    def __next__(self):
        self.count += 1
        if self.count == len(self):
            raise StopIteration
        # Read image 开始读视频
        res, img0 = self.cap.read()  # BGR
        assert img0 is not None, 'Failed to load frame {:d}'.format(self.count)
        img0 = cv2.resize(img0, (self.w, self.h)) #resize视频帧

        # Padded resize #不懂为什么要在letterbox之前resize一下
        img, _, _, _ = letterbox(img0, height=self.height, width=self.width)

        # Normalize RGB 一样的通道变换+归一化
        img = img[:, :, ::-1].transpose(2, 0, 1)
        img = np.ascontiguousarray(img, dtype=np.float32)
        img /= 255.0

        # cv2.imwrite(img_path + '.letterbox.jpg', 255 * img.transpose((1, 2, 0))[:, :, ::-1])  # save letterbox image
        return self.count, img, img0 #返回当前帧数,当前帧归一化的img 以及原帧

    def __len__(self):
        return self.vn  # number of files 返回视频的总帧数

2.3 加载images和labels

这个类是后续加载数据的继承类,简单的实现了从相应文件中加载图片和标签文件的过程

初始化:

class LoadImagesAndLabels:  # for training
    def __init__(self, path, img_size=(1088, 608), augment=False, transforms=None):
        with open(path, 'r') as file:
            self.img_files = file.readlines()
            self.img_files = [x.replace('\n', '') for x in self.img_files]
            self.img_files = list(filter(lambda x: len(x) > 0, self.img_files))

        #上述操作都是在获得相应的图片相关信息
        #如果是训练的话,那么举个例子--MOT17/images/train/MOT17-02-SDP/img1/000009.jpg
        #list里面存放的应该是这种图片名称

        #当图片名称得到以后就开始准备标签名称,把相应的文件夹替换掉
        #上述例子得到的就应该是--MOT17/labels_with_ids/train/MOT17-02-SDP/img1/000009.txt
        #就是上述图片对应的标签文件
        self.label_files = [x.replace('images', 'labels_with_ids').replace('.png', '.txt').replace('.jpg', '.txt')
                            for x in self.img_files]
        
        #下述操作一样的,统计图片信息,包括图片数目,图片尺寸,以及数据预处理相关信息
        self.nF = len(self.img_files)  # number of image files
        self.width = img_size[0]
        self.height = img_size[1]
        self.augment = augment
        self.transforms = transforms

获取数据的关键

    def get_data(self, img_path, label_path):
        height = self.height
        width = self.width
        img = cv2.imread(img_path)  # BGR
        if img is None:
            raise ValueError('File corrupt {}'.format(img_path))
        #以上是读取图片信息的相关操作
        augment_hsv = True
        #如果需要hsv数据增强
        if self.augment and augment_hsv:
            # SV augmentation by 50%
            fraction = 0.50 
            img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) #将图片转为hsv格式
            S = img_hsv[:, :, 1].astype(np.float32)
            V = img_hsv[:, :, 2].astype(np.float32) #获得S饱和度,V亮度两个通道的相应分量

            a = (random.random() * 2 - 1) * fraction + 1
            S *= a #对饱和度进行操作
            if a > 1:
                np.clip(S, a_min=0, a_max=255, out=S)
                #如果随即因子a大于1,那么可能存在像素值经过变换后大于255,所以要约束一下
            
            #亮度通道一样的操作
            a = (random.random() * 2 - 1) * fraction + 1
            V *= a
            if a > 1:
                np.clip(V, a_min=0, a_max=255, out=V)

            img_hsv[:, :, 1] = S.astype(np.uint8)
            img_hsv[:, :, 2] = V.astype(np.uint8)
            #把随机调整增强的像素,赋给原图,再转回RGB通道放入dst中
            cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)

        h, w, _ = img.shape
        img, ratio, padw, padh = letterbox(img, height=height, width=width)
        #图片进行letterbox操作

        # Load labels
        if os.path.isfile(label_path):
            #读取图片的标签,读进来是一个N*6的矩阵,N是该张图片中存在的目标数目
            labels0 = np.loadtxt(label_path, dtype=np.float32).reshape(-1, 6)

            # Normalized xywh to pixel xyxy format
            #因为刚刚图片做了letterbox操作,并且label中的box是xywh的存放方式
            #所以下述操作是将box进行转换,转换为xyxy的格式并适应letterbox后的图片
            labels = labels0.copy()
            labels[:, 2] = ratio * w * (labels0[:, 2] - labels0[:, 4] / 2) + padw
            labels[:, 3] = ratio * h * (labels0[:, 3] - labels0[:, 5] / 2) + padh
            labels[:, 4] = ratio * w * (labels0[:, 2] + labels0[:, 4] / 2) + padw
            labels[:, 5] = ratio * h * (labels0[:, 3] + labels0[:, 5] / 2) + padh
        else:
            labels = np.array([])

        # Augment image and labels
        if self.augment: #仿射变换的数据增强操作
            img, labels, M = random_affine(img, labels, degrees=(-5, 5), translate=(0.10, 0.10), scale=(0.50, 1.20))

        plotFlag = False
        if plotFlag: #可视化
            import matplotlib
            matplotlib.use('Agg')
            import matplotlib.pyplot as plt
            plt.figure(figsize=(50, 50))
            plt.imshow(img[:, :, ::-1])
            plt.plot(labels[:, [1, 3, 3, 1, 1]].T, labels[:, [2, 2, 4, 4, 2]].T, '.-')
            plt.axis('off')
            plt.savefig('test.jpg')
            time.sleep(10)

        nL = len(labels) #目标数目的统计
        if nL > 0:
            # convert xyxy to xywh

            """
            def xyxy2xywh(x):
                # Convert bounding box format from [x1, y1, x2, y2] to [x, y, w, h]
                y = torch.zeros(x.shape) if x.dtype is torch.float32 else np.zeros(x.shape)
                y[:, 0] = (x[:, 0] + x[:, 2]) / 2
                y[:, 1] = (x[:, 1] + x[:, 3]) / 2
                y[:, 2] = x[:, 2] - x[:, 0]
                y[:, 3] = x[:, 3] - x[:, 1]
            return y
            转换坐标信息,并归一化
            """
            labels[:, 2:6] = xyxy2xywh(labels[:, 2:6].copy())  # / height
            labels[:, 2] /= width
            labels[:, 3] /= height
            labels[:, 4] /= width
            labels[:, 5] /= height
        if self.augment: #数据增强之图片翻转
            # random left-right flip
            lr_flip = True
            if lr_flip & (random.random() > 0.5):#0.5的概率翻转
                img = np.fliplr(img)
                if nL > 0:#如果该张图片帧里面有目标
                    labels[:, 2] = 1 - labels[:, 2] 把标签信息翻转

        img = np.ascontiguousarray(img[:, :, ::-1])  # BGR to RGB

        if self.transforms is not None:#前预处理操作
            img = self.transforms(img)

        return img, labels, img_path, (h, w)#返回处理好的图片和的对应的标签信息,还有图片尺寸

 

  • 5
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
FairMOT是一种基于深度学习的目标跟踪算法,它采用了一种多任务学习(MTL)的方法来实现目标的准确跟踪。 首先,FairMOT使用一个强大的卷积神经网络(CNN)来提取目标图片的特征。这个网络通过多层卷积和池化等操作来逐步抽取图片中目标的高级特征信息。这些特征能够表达目标的形状、纹理、边缘等重要信息,从而帮助算法进行目标的识别和跟踪。 接下来,FairMOT使用一个匈牙利算法来将帧与帧之间的目标进行匹配。具体来说,匈牙利算法通过计算每个目标之间的相似度得分,并根据最小权重匹配的原则来确定每个目标在不同帧之间的对应关系。这样一来,就可以在连续的帧中准确地跟踪目标的位置和运动轨迹。 为了提高跟踪的鲁棒性和准确性,FairMOT还采取了一个多任务学习(MTL)的策略。这意味着在网络中有多个并行的任务。除了目标跟踪之外,还有目标检测、姿态估计等任务。这种设计的好处是,不同任务之间可以相互促进,使得网络能够更好地理解目标的特征和运动规律,从而提高跟踪的效果。 综上所述,FairMOT通过使用卷积神经网络提取目标的特征,采用匈牙利算法进行目标的匹配,同时引入多任务学习的策略,来实现对目标的准确跟踪。这种方法在一些比赛和实验中已经取得了非常好的效果,对于视频监控、行人追踪等领域具有很大的应用潜力。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值