# datasets.py
import glob
import random
import os
import sys
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
from utils.augmentations import horisontal_flip
from torch.utils.data import Dataset
import torchvision.transforms as transforms
"""
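Main purpose of this file: pad_to_square pads an image to a square, resize
changes the image size, and random_resize resizes images to a randomly chosen size.
ImageFolder reads every image under data/samples, pads it to a square, resizes it,
and produces the tensors that serve as input for detect.py.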
"""
def pad_to_square(img, pad_value):
    """Pad a ``(C, H, W)`` image tensor with a constant so that H equals W.

    The shorter spatial side is padded on both ends. When the size difference
    is odd, the extra pixel goes to the bottom (or right) side.

    Args:
        img: 3-D tensor whose last two dimensions are height and width.
        pad_value: Constant used to fill the padded area.

    Returns:
        A tuple ``(padded_img, pad)`` where ``pad`` is the
        ``(left, right, top, bottom)`` padding that was applied.
    """
    _, h, w = img.shape
    dim_diff = abs(h - w)
    # Split the difference as evenly as possible; pad2 takes the odd pixel.
    pad1, pad2 = dim_diff // 2, dim_diff - dim_diff // 2
    # F.pad reads the tuple starting from the LAST dimension:
    #   4 elements -> (left, right, top, bottom) pads the last two dims;
    #   6 elements -> (left, right, top, bottom, front, back) pads the last three.
    # Each value is the number of elements added on that side.
    pad = (0, 0, pad1, pad2) if h <= w else (pad1, pad2, 0, 0)
    img = F.pad(img, pad, "constant", value=pad_value)
    return img, pad
def resize(image, size):
    """Resize a single image tensor to ``size`` with nearest-neighbour sampling.

    A leading batch dimension is added for ``F.interpolate`` and removed
    again before returning.
    """
    batched = image[None]
    scaled = F.interpolate(batched, size=size, mode="nearest")
    return scaled[0]
def random_resize(images, min_size=288, max_size=448):
    """Resize a batch of images to one randomly drawn size.

    The size is drawn from ``min_size, min_size + 32, ...`` up to and
    including ``max_size``; the batch is resized with nearest-neighbour
    sampling.
    """
    candidates = list(range(min_size, max_size + 1, 32))
    (new_size,) = random.sample(candidates, 1)
    return F.interpolate(images, size=new_size, mode="nearest")
class ImageFolder(Dataset):
    """Dataset over every file matching ``<folder_path>/*.*``, in sorted order.

    Each item is loaded as an RGB tensor, padded to a square with zeros and
    resized to ``img_size``.
    """

    def __init__(self, folder_path, img_size=416):
        """Collect the file list.

        Args:
            folder_path: Directory to scan; only names containing a dot match.
            img_size: Side length every image is resized to.
        """
        self.files = sorted(glob.glob("%s/*.*" % folder_path))
        self.img_size = img_size

    def __getitem__(self, index):
        """Return ``(img_path, img)`` for the file at ``index`` (wraps around)."""
        img_path = self.files[index % len(self.files)]
        # Force three channels so grayscale / RGBA files yield the same tensor
        # layout as ListDataset; the context manager closes the file handle.
        with Image.open(img_path) as pil_img:
            img = transforms.ToTensor()(pil_img.convert("RGB"))
        img, _ = pad_to_square(img, 0)
        img = resize(img, self.img_size)
        return img_path, img

    def __len__(self):
        """Return the number of files found."""
        return len(self.files)
class ListDataset(Dataset):
    """Dataset of images and box labels listed in a text file.

    Every non-blank line of ``list_path`` is an image path. The label path is
    derived from it by replacing ``images`` with ``labels`` and the
    ``.png`` / ``.jpg`` extension with ``.txt``, e.g.
    ``data/custom/images/train.jpg`` -> ``data/custom/labels/train.txt``.

    A label file holds five numbers per row; columns 1-4 are read as
    ``x_center, y_center, width, height``. Column 0 is copied through
    unchanged (presumably the class id -- confirm against the label format).
    """

    def __init__(self, list_path, img_size=416, augment=True, multiscale=True, normalized_labels=True):
        """Read the image list and set up the loading options.

        Args:
            list_path: Text file with one image path per line.
            img_size: Initial side length used when batching images.
            augment: If true, each sample is horizontally flipped with
                probability 0.5.
            multiscale: If true, ``collate_fn`` draws a new ``img_size``
                every 10th batch.
            normalized_labels: If true, label coordinates are scaled by the
                unpadded image width/height; otherwise they are used as-is.
        """
        with open(list_path, "r") as file:
            # Skip blank lines (e.g. a trailing newline) so they do not turn
            # into samples with an empty path.
            self.img_files = [line for line in file if line.strip()]
        self.label_files = [
            path.replace("images", "labels").replace(".png", ".txt").replace(".jpg", ".txt")
            for path in self.img_files
        ]
        self.img_size = img_size
        self.max_objects = 100
        self.augment = augment
        self.multiscale = multiscale
        self.normalized_labels = normalized_labels
        # Multiscale range: three steps of 32 below and above the base size.
        self.min_size = self.img_size - 3 * 32
        self.max_size = self.img_size + 3 * 32
        self.batch_count = 0

    def __getitem__(self, index):
        """Return ``(img_path, img, targets)`` for one sample.

        ``img`` is the zero-padded square RGB tensor (not yet resized).
        ``targets`` is ``None`` when the label file does not exist, otherwise
        an ``(n, 6)`` tensor: column 0 is left at 0 for ``collate_fn`` to
        fill, columns 1-5 are the label row with the box centre and size
        re-expressed as fractions of the padded image.
        """
        position = index % len(self.img_files)
        img_path = self.img_files[position].rstrip()
        # convert("RGB") always yields a 3 x H x W tensor, so no separate
        # grayscale handling is needed; the context manager closes the file.
        with Image.open(img_path) as pil_img:
            img = transforms.ToTensor()(pil_img.convert("RGB"))

        _, h, w = img.shape
        h_factor, w_factor = (h, w) if self.normalized_labels else (1, 1)
        img, pad = pad_to_square(img, 0)
        _, padded_h, padded_w = img.shape

        label_path = self.label_files[position].rstrip()
        targets = None
        if os.path.exists(label_path):
            boxes = torch.from_numpy(np.loadtxt(label_path).reshape(-1, 5))
            # Box corners in pixels of the unpadded image.
            x1 = w_factor * (boxes[:, 1] - boxes[:, 3] / 2)
            y1 = h_factor * (boxes[:, 2] - boxes[:, 4] / 2)
            x2 = w_factor * (boxes[:, 1] + boxes[:, 3] / 2)
            y2 = h_factor * (boxes[:, 2] + boxes[:, 4] / 2)
            # pad is (left, right, top, bottom). Padding moves the whole box
            # by the left/top amount, so both corners get the same offset.
            x1 += pad[0]
            y1 += pad[2]
            x2 += pad[0]
            y2 += pad[2]
            # Back to centre/size form, as fractions of the padded image.
            boxes[:, 1] = ((x1 + x2) / 2) / padded_w
            boxes[:, 2] = ((y1 + y2) / 2) / padded_h
            boxes[:, 3] *= w_factor / padded_w
            boxes[:, 4] *= h_factor / padded_h
            targets = torch.zeros((len(boxes), 6))
            targets[:, 1:] = boxes

        # NOTE(review): targets may be None here; confirm horisontal_flip
        # accepts that for images without a label file.
        if self.augment:
            if np.random.random() < 0.5:
                img, targets = horisontal_flip(img, targets)
        return img_path, img, targets

    def collate_fn(self, batch):
        """Merge samples into ``(paths, imgs, targets)`` for one batch.

        The per-image target tensors are concatenated into a single tensor,
        so column 0 of each row is set to the image's position in the batch;
        that column is the only way to tell which image a row belongs to.
        Images are resized to the current ``img_size`` and stacked.
        """
        paths, imgs, targets = list(zip(*batch))
        labelled = []
        # Number the samples BEFORE dropping unlabelled ones so that the
        # index in column 0 keeps matching the image's position in imgs.
        for sample_idx, boxes in enumerate(targets):
            if boxes is None:
                continue
            boxes[:, 0] = sample_idx
            labelled.append(boxes)
        # torch.cat rejects an empty list; return an empty (0, 6) tensor when
        # no image in the batch has labels.
        targets = torch.cat(labelled, 0) if labelled else torch.zeros((0, 6))
        # Every 10th batch, pick a new square size in steps of 32.
        if self.multiscale and self.batch_count % 10 == 0:
            self.img_size = random.choice(range(self.min_size, self.max_size + 1, 32))
        imgs = torch.stack([resize(img, self.img_size) for img in imgs])
        self.batch_count += 1
        return paths, imgs, targets

    def __len__(self):
        """Return the number of listed images."""
        return len(self.img_files)