PyTracking 训练 HOT(一)

一、修改计划

考虑到HOT是多光谱数据,通道数量为16,传统RGB算法输入通道为3,因此需要修改一下几个位置

1、数据读取方法

2、数据增强方法

3、网络第一层的inchannel

4、如何创建新的跟踪网络(以PCA实现降维+RGB算法为例)

二、/ltr/ 代码修改

1、修改 /ltr/admin/local.py,添加 self.hot_dir

2、修改 /ltr/dataset/__init__.py,添加 from .hot import HOT

3、在 /ltr/dataset/ 内创建 hot.py,和之前创建的 hot_false.py 基本保持一致,主要注意数据加载方式不再是 jpeg4py_loader,还要注意数据类型变成了png

4、修改 /ltr/run_training.py,参数中加入一行

parser.add_argument('--run_name', type=str, default='res18', help='run name')

修改保存路径,避免训练的模型覆盖

settings.project_path = 'ltr/{}/{}'.format(train_module, train_name+run_name)

5、修改 /ltr/train_settings/bbreg/atom.py

from ltr.dataset import HOT_False, HOT

修改 hot_train、hot_val、dataset_train、dataset_val

6、修改 /ltr/data/image_loader.py,加入def hotpng_loader 以处理hot的多光谱数据,根据官方提供的数据处理代码Hyperspectral Object Tracking Challenging 2023编写

def hotpng_loader(path):
    try:
        data = Image.open(path)
        img = np.array(data)
        
        # Parameters
        M, N = img.shape
        B = [4, 4]
        skip = [4, 4]
        bandNumber = 16
        col_extent = N - B[1] + 1
        row_extent = M - B[0] + 1
        # Get Starting block indices
        start_idx = np.arange(B[0])[:, None] * N + np.arange(B[1])
        # Generate Depth indeces
        didx = M * N * np.arange(1)
        start_idx = (didx[:, None] + start_idx.ravel()).reshape((-1, B[0], B[1]))
        # Get offsetted indices across the height and width of input array
        offset_idx = np.arange(row_extent)[:, None] * N + np.arange(col_extent)
        # Get all actual indices & index into input array for final output
        out = np.take(img, start_idx.ravel()[:, None] + offset_idx[::skip[0], ::skip[1]].ravel())
        out = np.transpose(out)
        DataCube = out.reshape(M//4, N//4, bandNumber)  # (272, 512, 16)    <class 'numpy.int32'>
        
        cube = ((DataCube - DataCube.min()) / (DataCube.max() - DataCube.min()) * 255).astype(np.uint8)[:,:,:3] # 先取前三个方便测试整体流程
        return cube
    except Exception as e:
        print('ERROR: Could not read image "{}"'.format(path))
        print(e)
        return None

 ***** 处理速度太慢,怀疑是2D转3D的问题,考虑全都存储为3D的mat,之后直接读取mat

7、在 /ltr/data/image_loader.py 中加入def hotmat_loader

def hotmat_loader(path):
    try:
        img_mat = scio.loadmat(path)['data']
        cube = ((DataCube - DataCube.min()) / (DataCube.max() - DataCube.min()) * 255).astype(np.uint8)[:,:,:3]
        return cube
    except Exception as e:
        print('ERROR: Could not read image "{}"'.format(path))
        print(e)
        return None
  • 3
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值