这里记录一下《动手学深度学习pytorch版》里关于生成多个锚框的代码注释,不一定正确,但我不想再动脑子了。
一、生成锚框的宽和高的计算
这一块参考了哔哩哔哩沐神视频评论区一位大佬的回答。
如这位大佬所言,这里的宽高比 r 指的是锚框的长宽比,而缩放比 s 实际是以图片高度的 s 倍生成一个 hs*hs 的基准框,最后根据宽高比 r 来resize 成最后的锚框。
这个计算过程大概是这样的,w和h是原图像的宽和高,w0和h0是我们得到的锚框的宽和高,计算结果和教材上的结果是相同的。
这里做了归一化操作,和后面的代码对应上,做了h0/h,w0/w 这一步,第一次看其实不理解,要对应后面的代码来看,在调用 d2l.show_bboxes 这个函数时,乘了一个bbox_scale,想不到吧,前面除的h和w,最后又给你乘回来了。
二、每个像素点有多少个锚框(n+m-1)
这里n-1就是因为(s1,r1)重复了两次。
三、代码分析
#@save
def multibox_prior(data, sizes, ratios):
"""生成以每个像素为中心具有不同形状的锚框"""
# 取data的最后两个维度,这里data的shape为(batch_size,channels,height,width)
in_height, in_width = data.shape[-2:]
# num_sizes和num_ratios就是长宽比和纵横比的个数,教材上提到就是n和m
device, num_sizes, num_ratios = data.device, len(sizes), len(ratios)
# 对应一个像素就有n+m-1个锚框
boxes_per_pixel = (num_sizes + num_ratios - 1)
size_tensor = torch.tensor(sizes, device=device)
ratio_tensor = torch.tensor(ratios, device=device)
# 为了将锚点移动到像素的中心,需要设置偏移量。
# 因为一个像素的高为1且宽为1,我们选择偏移我们的中心0.5
offset_h, offset_w = 0.5, 0.5
steps_h = 1.0 / in_height # 在y轴上缩放步长
steps_w = 1.0 / in_width # 在x轴上缩放步长
# 生成锚框的所有中心点
center_h = (torch.arange(in_height, device=device) + offset_h) * steps_h
center_w = (torch.arange(in_width, device=device) + offset_w) * steps_w
# torch.meshgrid(tensor1,tensor2) 生成坐标
# 最后shift_x和shift_y的形状为size(h*w)
shift_y, shift_x = torch.meshgrid(center_h, center_w, indexing='ij')
shift_y, shift_x = shift_y.reshape(-1), shift_x.reshape(-1)
# 生成“boxes_per_pixel”个高和宽,对应前面的公式
# 之后用于创建锚框的四角坐标(xmin,xmax,ymin,ymax)
w = torch.cat((size_tensor * torch.sqrt(ratio_tensor[0]),
sizes[0] * torch.sqrt(ratio_tensor[1:])))\
* in_height / in_width
h = torch.cat((size_tensor / torch.sqrt(ratio_tensor[0]),
sizes[0] / torch.sqrt(ratio_tensor[1:])))
# 除以2来获得半高和半宽
anchor_manipulations = torch.stack((-w, -h, w, h)).T.repeat(
in_height * in_width, 1) / 2
来解释一下这句代码的含义,分解一下。
torch.stack((-w, -h, w, h))
这里将 -w,-h,w,h 四个张量在新增的维度dim=0上进行拼接,最后输出张量的shape为(4,n+m-1)
torch.stack((-w, -h, w, h)).T
对张量进行转置,最后的shape为 (n+m-1, 4)
torch.stack((-w, -h, w, h)).T.repeat(in_height * in_width, 1) / 2
在dim=0上重复(in_height * in_width)次,即h*w次,最后输出的shape为 ((n+m-1) *h * w,4)
# 每个中心点都将有“boxes_per_pixel”个锚框,
# 所以生成含所有锚框中心的网格,重复了“boxes_per_pixel”次
out_grid = torch.stack([shift_x, shift_y, shift_x, shift_y],
dim=1).repeat_interleave(boxes_per_pixel, dim=0)
再来解释一下这句代码的含义,分解一下。
torch.stack([shift_x, shift_y, shift_x, shift_y], dim=1)
这里将 中心点shift_x,shift_y,shift_x,shift_y 四个张量在新增的维度dim=1上进行拼接,最后输出张量的shape为(h*w,4)
torch.stack([shift_x, shift_y, shift_x, shift_y], dim=1) .repeat_interleave(boxes_per_pixel, dim=0)
这里在dim=0上按输入张量每行元素重复 n+m-1 次,注意是一行元素重复了n+m-1次后,再对下一行进行重复,这里达到的效果就是生成对于每个像素生成n+m-1个锚框的中心点,最后输出的shape((n+m-1) *h * w,4)
output = out_grid + anchor_manipulations
return output.unsqueeze(0) # 在dim=0新增一个维度,代表batch_size所在的维度
最后返回的shape为(1,(n+m-1) *h * w,4),即(批量大小,锚框数量,4)。这里4代表的就是锚框的坐标,(左上x,左上y,右下x,右下y)。