一般选择的未标记图片及其伪标签不直接送入网络进行retraining ,在此之前,我们需要对未标记图片及其伪标签做相应的增强操作,以便网络能学习额外的特征并且缓解对噪音的过拟合。
常见的增强操作:
1:裁剪
直接裁剪或者填充为指定大小
2:翻转
上下翻转、左右翻转、旋转指定角度
3:标准化
4:滤波
5:resize(插值法填充)
6:cutout
cutout是2017年提出的一种数据增强方法,想法比较简单,即在训练时随机裁剪掉图像的一部分,也可以看作是一种类似dropout的正则化方法。
Improved Regularization of Convolutional Neural Networks with Cutout
paper: https://arxiv.org/pdf/1708.04552.pdf
code: https://github.com/uoguelph-mlrg/Cutout
将以上操作使用PIL库实现并封装,方便后续调用。
import numpy as np
from PIL import Image, ImageOps, ImageFilter
import random
import torch
from torchvision import transforms
import cv2
class DataAugmentations():
def __init__(self):
pass
def crop(self,img, mask, size):
# padding height or width if smaller than cropping size
w, h = img.size
padw = size - w if w < size else 0
padh = size - h if h < size else 0
img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=255