gaitpart模块粗略分析
今天我们来分析第二个模块,水平金字塔池化
class HorizontalPoolingPyramid():
"""
Horizontal Pyramid Matching for Person Re-identification
Arxiv: https://arxiv.org/abs/1804.05275
Github: https://github.com/SHI-Labs/Horizontal-Pyramid-Matching
"""
def __init__(self, bin_num=None):
if bin_num is None:
bin_num = [16, 8, 4, 2, 1]
self.bin_num = bin_num
def __call__(self, x):
"""
x : [n, c, h, w]
ret: [n, c, p]
"""
n, c = x.size()[:2]
features = []
for b in self.bin_num:
z = x.view(n, c, b, -1)
z = z.mean(-1) + z.max(-1)[0]
features.append(z)
return torch.cat(features, -1)
首先,定义金字塔的层数和每一层的特征数量。
要么从配置文件中获取,如果没有定义,则自动赋值为五层,分别是16,8,4,2,1。
将输入的尺寸[n,c,h,w]转换为[n,c,b,-1]。将输入张量x重新排列成形状为[n, c, b, -1]的张量z。
然后,对于每个张量,计算其沿着最后一个维度(即每个子区域)的均值和最大值,并将它们相加,得到该特征表示。
最后,将所有层的特征表示连接起来,形成最终的输出张量,其维度为[n, c, p],其中p表示所有金字塔层的特征向量连接在一起的长度。
水平金字塔池化的作用在处理不同尺度的输入图像时,可以使用它来获得更加鲁棒和多尺度的特征表示