😎😎😎物体检测-系列教程 总目录
有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码
7、LoadImagesAndLabels类的cache_labels函数
cache_labels方法可以缓存数据集的标签信息,同时验证图像和标签文件的有效性,这对于加快后续训练过程中的数据加载速度非常有帮助
cache_labels方法通过扫描数据集的图像和标签文件,验证它们的完整性,并将相关信息(包括图像尺寸、标签坐标、分割信息等)缓存起来,从而加快后续访练过程中数据的加载速度。通过预先检测和处理可能的问题(如损坏的文件、缺失的标签等),这个方法还有助于提前发现数据集中的潜在问题,避免在训练过程中出现意外中断。此外,通过计算文件列表的哈希值,可以轻松检测到数据集的变动,确保缓存信息的有效性
7.1 cache_labels函数
def cache_labels(self, path=Path('./labels.cache'), prefix=''):
# Cache dataset labels, check images and read shapes
x = {} # dict
nm, nf, ne, nc = 0, 0, 0, 0 # number missing, found, empty, duplicate
pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
for i, (im_file, lb_file) in enumerate(pbar):
try:
# verify images
im = Image.open(im_file)
im.verify() # PIL verify
shape = exif_size(im) # image size
segments = [] # instance segments
assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
assert im.format.lower() in img_formats, f'invalid image format {im.format}'
# verify labels
if os.path.isfile(lb_file):
nf += 1 # label found
with open(lb_file, 'r') as f:
l = [x.split() for x in f.read().strip().splitlines()]
if any([len(x) > 8 for x in l]): # is segment
classes = np.array([x[0] for x in l], dtype=np.float32)
segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l] # (cls, xy1...)
l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
l = np.array(l, dtype=np.float32)
if len(l):
assert l.shape[1] == 5, 'labels require 5 columns each'
assert (l >= 0).all(), 'negative labels'
assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'
else:
ne += 1 # label empty
l = np.zeros((0, 5), dtype=np.float32)
else:
nm += 1 # label missing
l = np.zeros((0, 5), dtype=np.float32)
x[im_file] = [l, shape, segments]
except Exception as e:
nc += 1
print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels... " \
f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
pbar.close()
if nf == 0:
print(f'{prefix}WARNING: No labels found in {path}. See {help_url}')
x['hash'] = get_hash(self.label_files + self.img_files)
x['results'] = nf, nm, ne, nc, i + 1
x['version'] = 0.1 # cache version
torch.save(x, path) # save for next time
logging.info(f'{prefix}New cache created: {path}')
return x
- 定义函数,接受缓存文件路径和前缀字符串作为参数
- x,初始化一个空字典x,用于存储缓存的数据
- nm, nf, ne, nc,初始化计数器:nm(缺失的标签数量),nf(找到的标签数量),ne(空标签文件数量),nc(损坏的图像或标签数量)
- pbar ,创建一个进度条,迭代图像文件和标签文件的元组列表
- 遍历图像文件和标签文件
- try
- im ,读取图片
- 验证图像文件的完整性
- shape ,使用exif_size函数获取图像的尺寸
- segments ,初始化实例分割信息列表
- 确保图像尺寸大于9
- 确保图像在允许的格式列表内
- 如果存在标签文件:
- nf,找到的标签数量+1
- 打开标签文件
- l,按行读取标签文件,将每行使用空格分割成一个list,所以得list封装成一个list
- 如果存在分割信息(有任何标签行包含超过8个元素,这表示该行可能包含分割信息而不仅仅是简单的边界框信息):
- classes ,从每行标签中提取类别信息,并将其存储为一个NumPy数组
- segments ,对于包含分割信息的标签行,提取分割坐标并将其重新组织为(x,y)对的列表。每个分割信息被转换为一个NumPy数组
- l,将类别信息和通过segments2boxes函数从分割信息计算得到的边界框信息合并,以形成完整的标签信息。segments2boxes函数负责将分割信息转换为边界框格式
- l,确保标签转化为ndarray
- 如果处理后的标签数组l非空:
- 确保标签数组每行应该有5列,对应于(类别, x, y, 宽度, 高度)
- 确保所有标签值都应该非负
- 确保所有坐标标签值都应该在0到1之间,表示它们已经被归一化
- 确保没有重复的标签行,确保每个标签是唯一的
- 如果处理后的标签数组l为空:
- ne ,空标签计数器ne加
- l,对于空或缺失的标签,创建一个形状为(0, 5)的全为0的NumPy数组
- 将每个图像文件的标签信息、图像尺寸和分割信息存储在字典x中,以图像文件路径作为键
- 捕获处理过程中发生的任何异常
- nc ,损坏的图像或标签文件计数器nc加一
- 打印警告信息,指出忽略了损坏的图像或标签文件
- 更新进度条的描述信息,反映当前扫描的状态
- 完成所有文件的扫描后,关闭进度条
- 如果没有找到任何标签文件
- 打印警告信息,指出在指定路径中未找到任何标签文件,并可能提供一个帮助链接
- x[‘hash’],调用get_hash函数为所有的标签文件和图像文件列表生成哈希值,以检测数据集是否发生变化。这个哈希值被存储在字典x中,用于之后验证缓存数据的有效性
- x[‘results’],将找到的标签文件数量、缺失的标签数量、空的标签文件数量、损坏的文件数量以及处理的文件总数保存到字典x中的results键
- x[‘version’],设置缓存信息的版本号为0.1。这个版本号可以在未来用于管理缓存数据格式的变化
- 使用PyTorch的torch.save方法将包含缓存信息的字典x保存到指定的文件路径。这样,在后续的数据加载过程中,可以直接读取这个缓存文件,避免重复的文件检查和处理工作,从而加快数据准备速度
- 使用日志记录功能,输出一条信息,说明一个新的缓存文件已经被创建。这对于跟踪数据处理过程和调试是有帮助的
- 返回包含缓存信息的字典x。这个字典包含了重要的信息,如处理过的图像和标签文件的哈希值、找到的和缺失的标签文件数量、空的和损坏的文件数量、以及缓存的版本号等
7.2 exif_size函数
def exif_size(img):
# Returns exif-corrected PIL size
s = img.size # (width, height)
try:
rotation = dict(img._getexif().items())[orientation]
if rotation == 6: # rotation 270
s = (s[1], s[0])
elif rotation == 8: # rotation 90
s = (s[1], s[0])
except:
pass
return s
7.3 segments2boxes函数
def segments2boxes(segments):
# Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
boxes = []
for s in segments:
x, y = s.T # segment xy
boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
return xyxy2xywh(np.array(boxes)) # cls, xywh
7.4 get_hash函数
这两个collate_fn
和collate_fn4
方法提供了灵活的批次数据处理机制,支持不同的数据增强策略。collate_fn
方法适用于常规的批次数据处理,而collate_fn4
则专门设计用于处理通过特定增强技术(如马赛克拼接)生成的四倍图像尺寸的数据。这种高度定制化的数据处理方法有助于提高模型对于不同尺寸和组合图像的适应能力,进而提升模型的泛化性能。
def get_hash(files):
# Returns a single hash value of a list of files
return sum(os.path.getsize(f) for f in files if os.path.isfile(f))
返回图像文件列表self.img_files
的长度,即数据集中的图像总数
8、LoadImagesAndLabels类的collate_fn、collate_fn4函数
8.1 len
def __len__(self):
return len(self.img_files)
8.2 collate_fn
@staticmethod
def collate_fn(batch):
img, label, path, shapes = zip(*batch) # transposed
for i, l in enumerate(label):
l[:, 0] = i # add target image index for build_targets()
return torch.stack(img, 0), torch.cat(label, 0), path, shapes
collate_fn
方法为静态方法,这意味着它可以不通过实例调用,而是直接通过类调用- 定义了
collate_fn
方法,它用于自定义如何将多个数据样本组合成一个批次。这在使用PyTorch的DataLoader
时非常有用 - 将批次中的数据(图像、标签、路径、形状)解压,使得相同类型的数据被组织在一起
- 遍历批次中的每个标签
- 将标签中目标图像的索引设置为批次中的索引
i
,这对于之后构建目标时识别每个标签属于哪张图像是必要的 - 返回一个包含所有图像堆叠成一个张量、所有标签连接成一个张量、所有路径和所有形状的元组
8.3 collate_fn4
@staticmethod
def collate_fn4(batch):
img, label, path, shapes = zip(*batch) # transposed
n = len(shapes) // 4
img4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
ho = torch.tensor([[0., 0, 0, 1, 0, 0]])
wo = torch.tensor([[0., 0, 1, 0, 0, 0]])
s = torch.tensor([[1, 1, .5, .5, .5, .5]]) # scale
for i in range(n): # zidane torch.zeros(16,3,720,1280) # BCHW
i *= 4
if random.random() < 0.5:
im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2., mode='bilinear', align_corners=False)[
0].type(img[i].type())
l = label[i]
else:
im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
l = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
img4.append(im)
label4.append(l)
for i, l in enumerate(label4):
l[:, 0] = i # add target image index for build_targets()
return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4
- 声明
collate_fn4
为静态方法 - 定义了
collate_fn4
方法,它是一个专门为处理四倍图像尺寸设计的数据组合函数 - 同
collate_fn
方法中的操作,将批次中的数据解压 - 计算新批次的大小,由于是四倍图像尺寸,所以新批次的大小是原批次大小的四分之一
- 初始化新的图像列表、标签列表,并从原批次的路径和形状列表中获取前
n
个元素 - 遍历新批次的每个元素
- 随机决定是直接放大第一个图像还是将四个图像拼接成一个更大的图像
- 如果决定放大图像,则使用双线性插值方法将图像尺寸放大两倍
- 如果决定拼接图像,则将四个图像拼接成一个更大的图像
- 将对应的标签也进行相应的调整和拼接,以匹配拼接后的图像
- 遍历处理后的所有标签集合
label4
。在这一步,每个l
代表一个标签集,其中包含了可能经过拼接或放大图像后对应的所有标签信息 - 对每个标签集内的标签,将第一列(通常用于指定目标图像索引)设置为新的批次索引
i
。这是为了确保在后续处理中,每个标签能够被正确关联到它所属的图像上 - 返回处理后的批次数据。使用
torch.stack(img4, 0)
将所有处理后的图像堆叠成一个新的张量,torch.cat(label4, 0)
将所有标签集连接成一个新的张量。path4
和shapes4
分别包含了处理后图像的路径和形状信息。这样,collate_fn4
方法生成了适合于训练过程的批次数据,其中图像可能经过了拼接或放大处理,以及相应的标签也进行了调整