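I. Modification plan
HOT is multispectral data with 16 channels, while conventional RGB trackers expect 3 input channels, so the following places need to change:
1. Data loading
2. Data augmentation
3. The in_channels of the network's first layer
4. How to create a new tracking network (using PCA dimensionality reduction + an RGB tracker as the example)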
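The PCA route in item 4 could be sketched as below. This is only a minimal illustration, assuming the frame is already an (H, W, 16) array; the function name pca_reduce and the random test cube are hypothetical and not part of the repository.

```python
import numpy as np

def pca_reduce(cube, n_components=3):
    """Project an (H, W, C) cube onto its top principal components."""
    h, w, c = cube.shape
    flat = cube.reshape(-1, c).astype(np.float64)
    flat -= flat.mean(axis=0)  # center each band
    # Rows of vt are the principal axes, sorted by decreasing variance.
    _, _, vt = np.linalg.svd(flat, full_matrices=False)
    reduced = flat @ vt[:n_components].T
    return reduced.reshape(h, w, n_components)

# Hypothetical stand-in for one 16-band frame.
rng = np.random.default_rng(0)
cube = rng.random((32, 48, 16))
rgb_like = pca_reduce(cube)  # shape (32, 48, 3)
```

Fitting the projection per frame, as here, gives a basis that can differ from frame to frame; fitting it once per sequence or per dataset may be the better choice for tracking, but that is a design decision these notes do not settle.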
II. Code changes under /ltr/
1. Modify /ltr/admin/local.py: add self.hot_dir
2. Modify /ltr/dataset/__init__.py: add from .hot import HOT
3. Create hot.py in /ltr/dataset/. It is largely the same as the hot_false.py created earlier; the main differences are that the image loader is no longer jpeg4py_loader and that the image files are now png
4. Modify /ltr/run_training.py: add one line to the arguments
    parser.add_argument('--run_name', type=str, default='res18', help='run name')
Change the save path so that trained models do not overwrite each other
    settings.project_path = 'ltr/{}/{}'.format(train_module, train_name+run_name)
5. Modify /ltr/train_settings/bbreg/atom.py
    from ltr.dataset import HOT_False, HOT
Update hot_train, hot_val, dataset_train, and dataset_val
6. Modify /ltr/data/image_loader.py: add def hotpng_loader to handle HOT's multispectral data, written following the official data-processing code of the Hyperspectral Object Tracking Challenge 2023
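As a rough illustration of how the new argument feeds into the save path, the sketch below assumes train_module and train_name are positional arguments; build_project_path and the value '_hot16' are hypothetical names used only here.

```python
import argparse

def build_project_path(train_module, train_name, run_name):
    # Same format string as in the note: train_name and run_name are
    # concatenated directly, with no separator between them.
    return 'ltr/{}/{}'.format(train_module, train_name + run_name)

parser = argparse.ArgumentParser()
parser.add_argument('train_module', type=str)
parser.add_argument('train_name', type=str)
parser.add_argument('--run_name', type=str, default='res18', help='run name')

# Parse an explicit argument list instead of sys.argv for the example.
args = parser.parse_args(['bbreg', 'atom', '--run_name', '_hot16'])
project_path = build_project_path(args.train_module, args.train_name, args.run_name)
```

Because the two names are joined without a separator, starting run_name with an underscore or similar keeps the resulting directory name readable.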
import numpy as np
from PIL import Image

def hotpng_loader(path):
    try:
        data = Image.open(path)
        img = np.array(data)
        # Parameters
        M, N = img.shape
        B = [4, 4]        # mosaic block size: 4 x 4 pixels = 16 bands
        skip = [4, 4]     # stride between blocks
        bandNumber = 16
        col_extent = N - B[1] + 1
        row_extent = M - B[0] + 1
        # Get starting block indices
        start_idx = np.arange(B[0])[:, None] * N + np.arange(B[1])
        # Generate depth indices
        didx = M * N * np.arange(1)
        start_idx = (didx[:, None] + start_idx.ravel()).reshape((-1, B[0], B[1]))
        # Get offset indices across the height and width of the input array
        offset_idx = np.arange(row_extent)[:, None] * N + np.arange(col_extent)
        # Get all actual indices and index into the input array for the final output
        out = np.take(img, start_idx.ravel()[:, None] + offset_idx[::skip[0], ::skip[1]].ravel())
        out = np.transpose(out)
        DataCube = out.reshape(M//4, N//4, bandNumber)  # (272, 512, 16) <class 'numpy.int32'>
        # Min-max scale to uint8; keep only the first three bands for now, to test the overall pipeline
        cube = ((DataCube - DataCube.min()) / (DataCube.max() - DataCube.min()) * 255).astype(np.uint8)[:, :, :3]
        return cube
    except Exception as e:
        print('ERROR: Could not read image "{}"'.format(path))
        print(e)
        return None
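***** Processing is too slow. The suspected cause is the 2D-to-3D conversion, so consider storing every frame as a 3D .mat file and reading the .mat directly afterwards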
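If the 2D-to-3D step really is the bottleneck, one possible alternative is to do the rearrangement with reshape and transpose instead of building index arrays. The sketch below is not the official conversion code; mosaic_to_cube is a hypothetical name, and it assumes both image dimensions are exact multiples of the block size. Under that assumption it yields the same band order, band = i * 4 + j for the pixel at row i, column j inside each 4 x 4 block.

```python
import numpy as np

def mosaic_to_cube(img, block=4):
    """Rearrange an (M, N) mosaic into an (M/block, N/block, block*block) cube."""
    m, n = img.shape
    if m % block or n % block:
        raise ValueError('image size must be a multiple of the block size')
    # Axes after reshape: (block row, row in block, block col, col in block).
    cube = img.reshape(m // block, block, n // block, block)
    # Bring the two in-block axes to the end, then flatten them into bands.
    cube = cube.transpose(0, 2, 1, 3)
    return cube.reshape(m // block, n // block, block * block)

# Small hypothetical mosaic: 2 x 3 blocks of 4 x 4 pixels.
mosaic = np.arange(8 * 12).reshape(8, 12)
cube = mosaic_to_cube(mosaic)  # shape (2, 3, 16)
```

Whether this removes the slowdown depends on where the time actually goes; PNG decoding or the float normalization may dominate, so it is worth timing each stage before switching storage formats.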
7. Add def hotmat_loader to /ltr/data/image_loader.py
import scipy.io as scio

def hotmat_loader(path):
    try:
        # The cube is stored under the 'data' key of the .mat file
        DataCube = scio.loadmat(path)['data']
        # Min-max scale to uint8; keep only the first three bands for now
        cube = ((DataCube - DataCube.min()) / (DataCube.max() - DataCube.min()) * 255).astype(np.uint8)[:, :, :3]
        return cube
    except Exception as e:
        print('ERROR: Could not read image "{}"'.format(path))
        print(e)
        return None
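Both loaders use the same min-max scaling expression, which divides by zero when a frame is constant. A shared helper along these lines could cover that case; to_uint8 is a hypothetical name, and returning an all-zero frame for constant input is an arbitrary choice made for this sketch.

```python
import numpy as np

def to_uint8(cube):
    """Min-max scale an array to the 0..255 range as uint8."""
    cube = cube.astype(np.float64)
    lo, hi = cube.min(), cube.max()
    if hi == lo:
        # Constant frame: avoid dividing by zero (assumed fallback).
        return np.zeros(cube.shape, dtype=np.uint8)
    return ((cube - lo) / (hi - lo) * 255).astype(np.uint8)

flat_frame = to_uint8(np.full((4, 4, 16), 7))
ramp = to_uint8(np.arange(16).reshape(4, 4, 1))
```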