官方文档:
https://docs.ultralytics.com/yolov5/tutorials/test_time_augmentation/
文章目录
- 1. TTA概念介绍
- 2. TTA代码实现
- 3. TTA使用方法
一句话简单的介绍Test Time Augmentation(TTA)
就是测试过程中也使用数据增强,官方教程介绍:Test-Time Augmentation (TTA) Tutorial
1. TTA概念介绍
在训练过程中数据增强是非常常用的一种手段,目的是为了提高模型的泛化能力,以免出现大小不一样,图像选择一下就分辨不出来的尴尬。那么TTA就是想在推理阶段也进行数据增强。不过不会太复杂,因为会增加额外的计算量,在打比赛的时候可能会用到,因为打比赛不在意你的推理时长是多久,所以可以尽情瞎造;但是在实际部署的情况下,因为推理速度减慢很可能会达不到实时监测的效果,所以实际是没有必要在推理也进行数据增强的,会降低监测速度。
2. TTA代码实现
知道了其原理是在推理阶段使用数据增强,那么很明显,其将在model中的前向传播过程中实现。在yolov5中,TTA 自动集成到所有YOLOv5 PyTorch Hub模型中。具体的解析我已经写在了注释中。
yolov5实现代码:
class Model(nn.Module):
def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes
super().__init__()
...
# 如果直接传入的是dict则无需处理; 如果不是则使用yaml.safe_load加载yaml文件
with open(cfg, errors='ignore') as f:
self.yaml = yaml.safe_load(f) # model dict
...
# 创建网络模型
# self.model: 初始化的整个网络模型(包括Detect层结构)
# self.save: 所有层结构中from不等于-1的序号,并排好序 [4, 6, 10, 14, 17, 20, 23]
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
...
def forward(self, x, augment=False, profile=False, visualize=False): # debug同样需要第三次才能正常跳进来
if augment: # use Test Time Augmentation(TTA), 如果打开会对图片进行scale和flip
return self._forward_augment(x) # augmented inference, None
return self._forward_once(x, profile, visualize) # single-scale inference, train
# 使用TTA进行推理(当然还是会调用普通推理实现前向传播)
def _forward_augment(self, x):
img_size = x.shape[-2:] # height, width
s = [1, 0.83, 0.67] # scales
f = [None, 3, None] # flips (2-ud上下flip, 3-lr左右flip)
y = [] # outputs
# 这里相当于对输入x进行3次不同参数的测试数据增强推理, 每次的推理结构都保存在列表y中
for si, fi in zip(s, f):
# scale_img缩放图片尺寸
# 通过普通的双线性插值实现,根据ratio来控制图片的缩放比例,最后通过pad 0补齐到原图的尺寸
xi = scale_img(x.flip