OSnet代码复现

  最近小组合作做一个项目,我刚好负责Reid部分, 所以找到了CVPR上的这篇osnet来复现一下,代码可以从链接自取。:https://github.com/KaiyangZhou/deep-person-reid

代码叉下来之后先看里面的readme文件,里面已经写的很详细了。如何建环境,这里要插一嘴的是在安装torch的时候它没有指定版本torch和cuda的版本,torch版本过高很容易出现其他意想不到的错误,亲测了torch1.0.2是可以运行的,cuda查看自己对应的版本安装就好。

   另外还有一个有点坑的是作者的request.txt文件的包没有指定版本,所以很多包版本过高,这个根据提示消息按要求安装对应的版本就行。

进入第一个get started,在终端输入python,然后导入对应的torchreid的包,然后回车,后面的参数给定也可以用这种方法,也可以将这些参数都写一个sh文件,然后在终端用sh.demo.sh来命令,效果是一样的。

  然后在数据集market1501上来训练自己的模型,这里我用的是上面提供的第二种方法,

将python后面的都复制进一个txt文件中,这里的path参数需要自己更改,将自己存放martket1501数据集的绝对路径复制到这里即可。然后在终端运行demo.sh开始加载自己的模型。

  重点来了,运行的是scripts下的main文件,加载数据集一直都ok,到最后engine.run的时候就出错了,如下:

 网上百度不到这个错误,给的都是改网络的配置文件,这里是python的包出了问题,看最开始报错的地方开始看发现是building model的时候出现的错误,在pycharm上回到main中追一下building model看一下这个函数:

 这里发现模型使用了预训练,打印一下pretrained的内容为true,然后打印model.pretrain发现加载的预训练模型为osnet_x1_0,在building model下的字典中查看到了这个模型:

 ctrl追一下这个函数,看见如果预训练就就将模型赋值

 接着追一下init_pretrained_weights这个函数

def init_pretrained_weights(model, key=''):
    """Initializes model with pretrained weights.
    
    Layers that don't match with pretrained layers in name or size are kept unchanged.
    """
    import os
    import errno
    import gdown
    from collections import OrderedDict

    def _get_torch_home():
        ENV_TORCH_HOME = 'TORCH_HOME'
        ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
        DEFAULT_CACHE_DIR = '~/.cache'
        torch_home = os.path.expanduser(
            os.getenv(
                ENV_TORCH_HOME,
                os.path.join(
                    os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
                )
            )
        )
        return torch_home

    torch_home = _get_torch_home()
    model_dir = os.path.join(torch_home, 'checkpoints')
    try:
        os.makedirs(model_dir)
    except OSError as e:
        if e.errno == errno.EEXIST:
            # Directory already exists, ignore.
            pass
        else:
            # Unexpected OSError, re-raise.
            raise
    filename = key + '_imagenet.pth'
    cached_file = os.path.join(model_dir, filename)

    if not os.path.exists(cached_file):
        gdown.download(pretrained_urls[key], cached_file, quiet=False)

  代码最后显示如果path不存在配置文件,就调用下面的链接去下载!根据这个路径的命令,看出预训练文件应该放在.catch文件夹下,用vnc显示隐藏文件发现确实没有!然后它就去用这个链接下载了,点进连接看发现:

这是一个需要访问谷歌的网站,这也是为什么报网络错误的原因!!

 解决方法:点进链接将配置文件下载下来然后放到对应的文件下就可以了,然后运行sh脚本文件运行程序,可以发现代码已经开始训练了。

 

### OC-SORT 算法实现教程 OC-Sort 是一种改进版的多目标跟踪算法,旨在提升在线实时跟踪的效果和准确性。该算法继承了 Simple Online and Realtime Tracking (SORT) 的核心思想并进行了优化[^1]。 #### 代码环境搭建 为了顺利运行 OC-Sort,在本地环境中需安装必要的依赖库: ```bash pip install numpy opencv-python scikit-image filterpy torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu ``` #### 获取预训练模型 对于深度学习部分,建议使用已有的预训练权重文件来加速开发过程。例如 OSNet 模型可以按照如下方式获取: ```python import os from urllib.request import urlretrieve model_url = "https://drive.google.com/uc?id=1LaG1EJpHrxdAxKnSCJ_i0u-nbxSAeiFY" destination_path = "~/.cache/torch/checkpoints/osnet_x1_0_imagenet.pth" if not os.path.exists(os.path.expanduser(destination_path)): urlretrieve(model_url, filename=os.path.expanduser(destination_path)) ``` #### 主要模块解析 ##### 初始化追踪器 创建 `Tracker` 类实例时初始化参数设置,包括但不限于最大年龄、最小命中次数等超参配置。 ##### 数据关联逻辑 通过卡尔曼滤波预测状态估计值,并利用匈牙利算法完成检测框与现有轨迹之间的匹配操作。 ##### 轨迹更新机制 当新帧到来时,根据当前时刻观测到的对象位置信息调整已有轨迹的状态变量;如果某条轨迹连续丢失超过设定阈值,则将其标记为删除。 #### 完整代码示例 下面给出一段简化版本的 Python 实现作为参考: ```python class Tracker: def __init__(self): self.tracks = [] def predict(self): """Predict states of existing tracks.""" pass def update(self, detections): """Update tracked objects with new detection results.""" pass def main(video_source="/path/to/video"): tracker = Tracker() cap = cv2.VideoCapture(video_source) while True: ret, frame = cap.read() if not ret: break # Perform object detection here... dets = [] # Detection bounding boxes should be stored as list of tuples like [(x,y,w,h,score)] tracker.predict() # Predict the state of all active tracks. matches, unmatched_detections, unmatched_tracks = associate(detections=dets, trackers=self.tracks) # Update matched tracks using Kalman Filter predictions combined with measurements from current frame's detections. visualize(frame, self.tracks) if __name__ == "__main__": video_file = "/home/user/videos/sample_video.mp4" main(video_source=video_file) ``` 上述代码仅为框架示意,具体细节还需参照官方文档进一步完善。
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值