The get_max_preds() function extracts keypoint coordinates from heatmaps. The complete function is shown below:
import numpy as np


def get_max_preds(batch_heatmaps):
    '''
    get predictions from score maps
    heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
    '''
    assert isinstance(batch_heatmaps, np.ndarray), \
        'batch_heatmaps should be numpy.ndarray'
    assert batch_heatmaps.ndim == 4, 'batch_images should be 4-ndim'

    batch_size = batch_heatmaps.shape[0]
    num_joints = batch_heatmaps.shape[1]
    width = batch_heatmaps.shape[3]
    heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1))
    idx = np.argmax(heatmaps_reshaped, 2)  # argmax returns the index of the maximum value
    maxvals = np.amax(heatmaps_reshaped, 2)  # maximum value along axis 2

    maxvals = maxvals.reshape((batch_size, num_joints, 1))
    idx = idx.reshape((batch_size, num_joints, 1))

    preds = np.tile(idx, (1, 1, 2)).astype(np.float32)  # np.tile repeats the array along the given axes

    preds[:, :, 0] = (preds[:, :, 0]) % width
    preds[:, :, 1] = np.floor((preds[:, :, 1]) / width)

    pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2))
    pred_mask = pred_mask.astype(np.float32)

    preds *= pred_mask
    return preds, maxvals
This function is called in evaluate.py, by accuracy().
accuracy() is in turn called in function.py, where output = model(input) is the model's output.
The train function is then called in train.py.
To make the data flow during training easy to follow, the batch size was set to 1 and the value of each variable in the code was printed.
assert isinstance(batch_heatmaps, np.ndarray), \
'batch_heatmaps should be numpy.ndarray'
assert batch_heatmaps.ndim == 4, 'batch_images should be 4-ndim'
batch_size = batch_heatmaps.shape[0]
num_joints = batch_heatmaps.shape[1]
width = batch_heatmaps.shape[3]
heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1))
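At this point batch_heatmaps = output, with shape torch.Size([1, 17, 64, 48]): 1 is the batch_size, 17 is the number of keypoints, and 64 and 48 are the height and width of the heatmap for each keypoint.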
batch_heatmaps.shape=torch.Size([1, 17, 64, 48])
batch_size=1
num_joints=17
width=48
heatmaps_reshaped.shape=torch.Size([1, 17, 3072])
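The reshape step can be checked in isolation. The following is a minimal sketch, assuming a zero-filled NumPy array as a stand-in for the real model output; it also shows that a pixel at row y, column x ends up at flat index y * width + x, which is what the later modulo and division steps rely on.

```python
import numpy as np

# Hypothetical stand-in for the model output: 1 image, 17 joints, 64x48 heatmaps.
batch_heatmaps = np.zeros((1, 17, 64, 48), dtype=np.float32)
batch_heatmaps[0, 0, 48, 39] = 1.0  # mark row 48, column 39 of the first heatmap

batch_size, num_joints, height, width = batch_heatmaps.shape

# Flatten each 64x48 heatmap into one row of 64 * 48 = 3072 values.
heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1))
print(heatmaps_reshaped.shape)  # (1, 17, 3072)

# Row-major layout: (row, col) maps to row * width + col.
flat_index = 48 * width + 39
print(flat_index)  # 2343
```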
idx = np.argmax(heatmaps_reshaped, 2)  # argmax returns the index of the maximum value
maxvals = np.amax(heatmaps_reshaped, 2)  # maximum value along axis 2
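idx holds the index of the maximum value in each heatmap, and maxvals holds that maximum value. Since the 17 keypoints correspond to 17 heatmaps and one maximum is taken per heatmap (that is, one keypoint each), there are 17 values in total.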
idx=[ [2343 2475 1715 1566 1861 2432 2197 2050 2102 1565 2287 2529 2344 2288 2768 2058 2910] ]
maxvals=[ [2.039913 1.7724268 1.2334661 0.97012115 1.1990129 1.5435572 0.8566811 1.2814575 1.0773002 0.91459084 1.4435731 1.5522699 1.7481376 1.8977993 1.3567188 1.2680566 1.0704352 ] ]
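A small toy case makes the behaviour of np.argmax and np.amax along axis 2 easier to see. The values below are made up for illustration (two joints, six "pixels" each) and are not taken from the run above.

```python
import numpy as np

# Toy input: 1 image, 2 joints, each heatmap already flattened to 6 values.
flat = np.array([[[0.1, 0.9, 0.3, 0.2, 0.0, 0.4],
                  [0.5, 0.2, 0.1, 0.7, 0.6, 0.3]]], dtype=np.float32)

idx = np.argmax(flat, 2)    # position of the peak in each flattened heatmap
maxvals = np.amax(flat, 2)  # value of that peak

print(idx.shape)     # (1, 2): one index per joint
print(idx.tolist())  # [[1, 3]]
```

Both results drop the last axis, which is why the function reshapes them back to (batch_size, num_joints, 1) in the next step.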
maxvals = maxvals.reshape((batch_size, num_joints, 1))
idx = idx.reshape((batch_size, num_joints, 1))
maxvals=[ [ [2.039913 ], [1.7724268 ], [1.2334661 ], [0.97012115], [1.1990129 ], [1.5435572 ], [0.8566811 ], [1.2814575 ], [1.0773002 ], [0.91459084], [1.4435731 ], [1.5522699 ], [1.7481376 ], [1.8977993 ], [1.3567188 ], [1.2680566 ], [1.0704352 ] ] ]
idx=[ [ [2343], [2475], [1715], [1566], [1861], [2432], [2197], [2050], [2102], [1565], [2287], [2529], [2344], [2288], [2768], [2058], [2910] ] ]
At this point maxvals.shape = idx.shape = ([1, 17, 1])
preds = np.tile(idx, (1, 1, 2)).astype(np.float32)  # np.tile repeats the array along the given axes
preds=[ [ [2343. 2343.], [2475. 2475.], [1715. 1715.], [1566. 1566.], [1861. 1861.], [2432. 2432.], [2197. 2197.], [2050. 2050.], [2102. 2102.], [1565. 1565.], [2287. 2287.], [2529. 2529.], [2344. 2344.], [2288. 2288.], [2768. 2768.], [2058. 2058.], [2910. 2910.] ] ]
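np.tile() repeats the array once along the first and second dimensions (leaving them unchanged) and twice along the third dimension, doubling its size.
At this point preds.shape = ([1, 17, 2])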
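The effect of np.tile with repetitions (1, 1, 2) can be reproduced on a shorter array. This sketch uses only the first two indices from the run above, so the shapes are (1, 2, 1) and (1, 2, 2) rather than the full 17 joints.

```python
import numpy as np

# First two flat indices from the run above, shaped (batch, joints, 1).
idx = np.array([[[2343], [2475]]])

# Repeat once along axes 0 and 1 (no change) and twice along axis 2.
preds = np.tile(idx, (1, 1, 2)).astype(np.float32)

print(preds.shape)     # (1, 2, 2)
print(preds.tolist())  # [[[2343.0, 2343.0], [2475.0, 2475.0]]]
```

Each joint now has two slots holding the same flat index; the following lines overwrite slot 0 with the x-coordinate and slot 1 with the y-coordinate.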
preds[:, :, 0] = (preds[:, :, 0]) % width
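This line computes the x-coordinate (column) of the keypoint in the heatmap (48 wide by 64 high) by taking the flat index modulo the heatmap width. Because both the flat index and the coordinates are zero-based, the remainder is exactly the x-coordinate with no further adjustment.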
preds=[ [ [ 39. 2343.], [ 27. 2475.], [ 35. 1715.], [ 30. 1566.], [ 37. 1861.], [ 32. 2432.], [ 37. 2197.], [ 34. 2050.], [ 38. 2102.], [ 29. 1565.], [ 31. 2287.], [ 33. 2529.], [ 40. 2344.], [ 32. 2288.], [ 32. 2768.], [ 42. 2058.], [ 30. 2910.] ] ]
preds[:, :, 1] = np.floor((preds[:, :, 1]) / width)
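This line computes the y-coordinate (row) of the keypoint in the heatmap (48 wide by 64 high) by dividing the flat index by the heatmap width and keeping the quotient. np.floor() rounds down, which is the right direction because coordinates are counted from 0.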
preds=[ [ [39. 48.], [27. 51.], [35. 35.], [30. 32.], [37. 38.], [32. 50.], [37. 45.], [34. 42.], [38. 43.], [29. 32.], [31. 47.], [33. 52.], [40. 48.], [32. 47.], [32. 57.], [42. 42.], [30. 60.] ] ]
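Taking the first keypoint [2343. 2343.] as an example: 2343 = 48 × 48 + 39, so the remainder 39 and the quotient 48 give its position in the heatmap as [39. 48.]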
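The conversion from flat index to (x, y) can be verified with plain integer arithmetic. As a cross-check, the sketch below also calls np.unravel_index, which performs the same conversion but returns (row, col), that is (y, x); it is shown for comparison and is not what the original function uses.

```python
import numpy as np

width, height = 48, 64
flat_index = 2343  # first keypoint from the run above

x = flat_index % width   # column: remainder after removing whole rows
y = flat_index // width  # row: number of whole rows before the peak
print(x, y)  # 39 48

# Same conversion via NumPy; note the (row, col) order of the result.
row, col = np.unravel_index(flat_index, (height, width))
```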
pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2))
np.greater(maxvals, 0.0)=tensor([ [ [1], [1], [1], [1], [1], [1], [1], [1], [1], [1], [1], [1], [1], [1], [1], [1], [1] ] ], dtype=torch.uint8)
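np.greater() tests whether each value in maxvals is greater than 0.0 and returns a result with the same shape as maxvals, holding 1 (True) at positions where the value is greater than 0.0 and 0 (False) elsewhere.
np.tile() repeats the result once along the first and second dimensions (leaving them unchanged) and twice along the third dimension, doubling its size.
pred_mask=[ [ [1 1], [1 1], [1 1], [1 1], [1 1], [1 1], [1 1], [1 1], [1 1], [1 1], [1 1], [1 1], [1 1], [1 1], [1 1], [1 1], [1 1] ] ]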
pred_mask = pred_mask.astype(np.float32)
pred_mask=[ [ [1. 1.], [1. 1.], [1. 1.], [1. 1.], [1. 1.], [1. 1.], [1. 1.], [1. 1.], [1. 1.], [1. 1.], [1. 1.], [1. 1.], [1. 1.], [1. 1.], [1. 1.], [1. 1.], [1. 1.] ] ]
preds *= pred_mask
preds=[ [ [39. 48.], [27. 51.], [35. 35.], [30. 32.], [37. 38.], [32. 50.], [37. 45.], [34. 42.], [38. 43.], [29. 32.], [31. 47.], [33. 52.], [40. 48.], [32. 47.], [32. 57.], [42. 42.], [30. 60.] ] ]
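In the run above every maximum is positive, so the mask is all ones and preds is unchanged. The sketch below uses made-up values in which two of three joints have a non-positive maximum, to show what the mask does in that case: their coordinates are reset to (0, 0).

```python
import numpy as np

# Illustrative values only: joint 0 has a positive peak, joints 1 and 2 do not.
maxvals = np.array([[[2.04], [0.0], [-0.3]]], dtype=np.float32)  # shape (1, 3, 1)
preds = np.array([[[39., 48.], [0., 0.], [12., 7.]]], dtype=np.float32)

# True where the peak is strictly positive, duplicated for the x and y slots.
pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2)).astype(np.float32)

preds *= pred_mask  # coordinates of masked-out joints become 0
print(preds.tolist())  # [[[39.0, 48.0], [0.0, 0.0], [0.0, 0.0]]]
```

Note that a masked joint is indistinguishable from a real detection at pixel (0, 0) by its coordinates alone; the returned maxvals is what tells the two apart.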
return preds, maxvals
Finally, the function returns the keypoint coordinates and the maximum value of each heatmap.
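To tie the steps together, here is a compact end-to-end sketch on synthetic data. The function name get_max_preds_unravel is hypothetical, and the body swaps the tile / modulo / floor sequence for np.unravel_index, so it is an alternative formulation intended to give the same result, not the code from the repository.

```python
import numpy as np


def get_max_preds_unravel(batch_heatmaps):
    """Hypothetical variant of get_max_preds built on np.unravel_index."""
    assert isinstance(batch_heatmaps, np.ndarray) and batch_heatmaps.ndim == 4
    batch_size, num_joints, height, width = batch_heatmaps.shape

    flat = batch_heatmaps.reshape((batch_size, num_joints, -1))
    idx = np.argmax(flat, 2)
    maxvals = np.amax(flat, 2).reshape((batch_size, num_joints, 1))

    # unravel_index returns (rows, cols); stack them as (x, y) = (col, row).
    rows, cols = np.unravel_index(idx, (height, width))
    preds = np.stack([cols, rows], axis=2).astype(np.float32)

    preds *= (maxvals > 0.0)  # zero out joints without a positive peak
    return preds, maxvals


# Synthetic input: joint 0 peaks at row 48, column 39; joint 1 is all zeros.
heatmaps = np.zeros((1, 2, 64, 48), dtype=np.float32)
heatmaps[0, 0, 48, 39] = 2.0

preds, maxvals = get_max_preds_unravel(heatmaps)
print(preds[0, 0])  # [39. 48.]
print(preds[0, 1])  # [0. 0.]
```

The first joint reproduces the [39. 48.] result from the walkthrough, and the second joint shows the masked case where no positive peak exists.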