本章详细讲解数据的处理问题,将coco数据集读取,以及之后自定义数据集的处理,
数据预处理思想
yolo3的数据集处理也是一大亮点,由于yolo3对数据集的输入有要求,指定的照片输入大小必须是416,所有对于不满足照片的大小有一系列的操作,如果直接resize操作,将直接损失照片信息,网络在学习分类的过程还要适应照片尺寸的问题,导致训练效果不佳,在yolo3中是先进行高和宽的调整一样大,在进行上采样的resize,同时要修改label的坐标位置,随机水平翻转,再一次随机变化大小,之后再变化到416的大小尺寸作为输入。
代码
class ListDataset(Dataset): #继承Dataset
def __init__(self, list_path, img_size=416, augment=True, multiscale=True, normalized_labels=True):
with open(list_path, "r") as file:
self.img_files = file.readlines()
self.label_files = [
path.replace("images", "labels").replace(".png", ".txt").replace(".jpg", ".txt") #这一步是生成labels的位置
for path in self.img_files
]
self.img_size = img_size
self.max_objects = 100
self.augment = augment
self.multiscale = multiscale
self.normalized_labels = normalized_labels
self.min_size = self.img_size - 3 * 32
self.max_size = self.img_size + 3 * 32
self.batch_count = 0
def __getitem__(self, index):
# ---------
# Image
# ---------
img_path = self.img_files[index % len(self.img_files)].rstrip() #按照索引的方式找到对应的路径
# Extract image as PyTorch tensor
img = transforms.ToTensor()(Image.open(img_path).convert('RGB')) #读取照片
# Handle images with less than three channels
if len(img.shape) != 3:
img = img.unsqueeze(0)
img = img.expand((3, img.shape[1:]))
_, h, w = img.shape
h_factor, w_factor = (h, w) if self.normalized_labels else (1, 1) #直接理解为照片的宽度和高度
# Pad to square resolution
img, pad = pad_to_square(img, 0) #这一步就是将高和宽变成一样大小
_, padded_h, padded_w = img.shape
# ---------
# Label
# ---------
label_path = self.label_files[index % len(self.img_files)].rstrip() #照片对应的label路径
targets = None
if os.path.exists(label_path):
boxes = torch.from_numpy(np.loadtxt(label_path).reshape(-1, 5))
# Extract coordinates for unpadded + unscaled image
x1 = w_factor * (boxes[:, 1] - boxes[:, 3] / 2) #label的坐标点位置是xywh所以先进行转化
y1 = h_factor * (boxes[:, 2] - boxes[:, 4] / 2)
x2 = w_factor * (boxes[:, 1] + boxes[:, 3] / 2)
y2 = h_factor * (boxes[:, 2] + boxes[:, 4] / 2)
# Adjust for added padding
x1 += pad[0] #照片大小变化了所以框的坐标点需要修改
y1 += pad[2]
x2 += pad[1]
y2 += pad[3]
# Returns (x, y, w, h)
boxes[:, 1] = ((x1 + x2) / 2) / padded_w #在次重新转化xywh形式
boxes[:, 2] = ((y1 + y2) / 2) / padded_h
boxes[:, 3] *= w_factor / padded_w
boxes[:, 4] *= h_factor / padded_h
targets = torch.zeros((len(boxes), 6))
targets[:, 1:] = boxes
# Apply augmentations
if self.augment:
if np.random.random() < 0.5:
img, targets = horisontal_flip(img, targets) #随机水平翻转
return img_path, img, targets
def collate_fn(self, batch): #自定义类中的函数用于batch处理
paths, imgs, targets = list(zip(*batch)) #可以理解写这个函数必须写这个操作,就是将 __getitem__的输出作为列表,
# Remove empty placeholder targets
targets = [boxes for boxes in targets if boxes is not None]
# Add sample index to targets
for i, boxes in enumerate(targets):
boxes[:, 0] = i
targets = torch.cat(targets, 0) #增加一个维度,就可以是批次额处理
# Selects new image size every tenth batch
if self.multiscale and self.batch_count % 10 == 0:
self.img_size = random.choice(range(self.min_size, self.max_size + 1, 32)) #对照片随机变大变小
# Resize images to input shape
imgs = torch.stack([resize(img, self.img_size) for img in imgs]) #在一次将照片大小变化为原来的416
self.batch_count += 1
return paths, imgs, targets
def __len__(self):
return len(self.img_files)
这一步将数据读取封装成一个类,中间还有一起其他的函数,
def pad_to_square(img, pad_value):
c, h, w = img.shape
dim_diff = np.abs(h - w)
# (upper / left) padding and (lower / right) padding
pad1, pad2 = dim_diff // 2, dim_diff - dim_diff // 2
# Determine padding
pad = (0, 0, pad1, pad2) if h <= w else (pad1, pad2, 0, 0)
# Add padding
img = F.pad(img, pad, "constant", value=pad_value)
return img, pad
这一步就是将高和宽整成一样大小,比如500,300输出就是500,500的大小,用pad1,pad2记录是高或者宽拉长了多少,用于框的位置修改。
def resize(image, size):
image = F.interpolate(image.unsqueeze(0), size=size, mode="nearest").squeeze(0)
return image
通过上采样的形式就修改了照片的尺寸,比直接进行resize的效果要好
def random_resize(images, min_size=288, max_size=448):
new_size = random.sample(list(range(min_size, max_size + 1, 32)), 1)[0]
images = F.interpolate(images, size=new_size, mode="nearest")
return images
这一步是将大小随机的变化,大小变化设置了一定范围
def horisontal_flip(images, targets):
images = torch.flip(images, [-1])
targets[:, 2] = 1 - targets[:, 2]
return images, targets
进行水平翻转的代码
其中有一个重点是__getitem__的输出,理解输出值得形式,可以自己从新再写一个dataset类的读取,官方给出的代码很一般,后面我会自己写一个csv文件的读取
看下输出值,以一张照片为例
print(img_path)
print(img.shape)
print(targets)
/images/train2014/COCO_train2014_000000000009.jpg
torch.Size([3, 768, 768])
tensor([[ 0.0000, 45.0000, 0.4795, 0.6416, 0.9556, 0.4466],
[ 0.0000, 45.0000, 0.7365, 0.3104, 0.4989, 0.3573],
[ 0.0000, 50.0000, 0.6371, 0.6747, 0.4941, 0.3829],
[ 0.0000, 45.0000, 0.3394, 0.4392, 0.6789, 0.5861],
[ 0.0000, 49.0000, 0.6468, 0.2244, 0.1180, 0.0727],
[ 0.0000, 49.0000, 0.7731, 0.2224, 0.0907, 0.0729],
[ 0.0000, 49.0000, 0.6683, 0.2952, 0.1313, 0.1102],
[ 0.0000, 49.0000, 0.6429, 0.1844, 0.1481, 0.1110]])
先看下yolo3训练时的数据集在文件夹内的放置
一个训练集下的数据集信息,images是照片,labels是每张照片对应的label信息,
classes是全部分类的名称,train保存训练图片的路径,valid是测试照片的路径
只有一张,照片名字对应label的名字
这个是train的label 一张照片有一个txt文件保存信息,一个txt可能包含多种框,这种事为了读取一张照片就将所有的框作为处理,
label内的信息,储存方式为xywh方式,坐标点的位置进行归一化了
这个是训练的txt,用来保存全部需要训练的照片路径,通过读取这一个文件来加载照片