- 训练数据前需要对数据集进行处理,这里利用torch.utils.data下的两个类,DataLoader和Dataset
- Dataset 用于封装图像和标签,每次输出一张图像及其对应的标签,[数据增强也在此实现]
- DataLoader 指定要使用的 Dataset,以及 batch_size、shuffle 和 collate_fn 等参数
1.Dataset
1.1 基本结构
from torch.utils.data import DataLoader,Dataset
class LoadImagesAndLables(Dataset):
    """Minimal template showing the three methods a map-style Dataset needs.

    The class stores the list of image paths, reports its length, and
    returns one ``(image, label)`` pair of tensors per index. Loading is
    delegated to ``load_image`` and ``load_label``, which a concrete
    dataset must override.

    NOTE: the class name keeps its original spelling ("Lables") so that
    existing references continue to work.
    """

    def __init__(self, img_path):
        """Remember the sequence of image paths in ``self.imgs``."""
        self.imgs = img_path

    def __len__(self):
        """Return the number of samples, i.e. the number of image paths."""
        return len(self.imgs)

    def __getitem__(self, index):
        """Return ``(image_tensor, label_tensor)`` for sample ``index``.

        Raises:
            NotImplementedError: if the loading hooks were not overridden.
        """
        # The original skeleton referenced the undefined names ``imgs`` and
        # ``labels`` here; the arrays now come from the two hooks below.
        img = self.load_image(index)
        label = self.load_label(index)
        return torch.from_numpy(img), torch.from_numpy(label)

    def load_image(self, index):
        """Hook: return the image of sample ``index`` as a numpy array."""
        raise NotImplementedError("load_image must be implemented by a subclass")

    def load_label(self, index):
        """Hook: return the labels of sample ``index`` as a numpy array."""
        raise NotImplementedError("load_label must be implemented by a subclass")
1.2 实现基本功能
- 初步实现Dataset功能,获取图像,解析标签
- 图像返回格式为RGB,通道顺序为CHW(通道、高、宽),并且归一化到[0,1]
- 标签size= (标签数量,6) ,其中6列分别为 图像索引,类别,x,y,w,h
- collate_fn 是DataLoader中用到的,目的是对同一个batch内的数据进行打包,图像利用torch.stack(img,0)转换成(16,3,512,512)。标签利用torch.cat(label,0)转换成(N,6),其中N 为当前batch下的所有标注数量
import torchvision.datasets
import torch
from torch.utils.data import DataLoader,Dataset
from pathlib import Path
import numpy as np
import cv2
"""
***Dataset***
Create a Dataset subclass that produces the images and their labels
"""
class LoadImagesAndLabels(torch.utils.data.Dataset):
    """Detection dataset pairing each image with a same-named label file.

    v1.0 2021.12.30 by wjl

    1. Takes a ``.txt`` list file with one image path per line; the labels
       of an image are read from the file that has the same path but a
       ``.txt`` suffix.
    2. Images are returned at their stored size (512 by default in the
       original data set); no resizing is performed.
    3. Each sample is an image tensor plus a label tensor of shape
       ``[box_num, 6]`` whose rows are ``[img_index, cls, x, y, w, h]``.
       ``img_index`` is left at 0 here and filled in by ``collate_fn``.
    """

    def __init__(self, txt_path: str, img_size=512):
        """Index every image listed in ``txt_path`` and load all labels.

        Args:
            txt_path: path of the ``.txt`` file listing the image paths.
            img_size: expected image side length. It is only stored on the
                instance (previously it was silently discarded).
        """
        super().__init__()
        self.img_txt_path = txt_path
        self.img_size = img_size
        self.imgs_path = None
        self.labels_path = None
        self.labels = {}
        self.get_imgs_and_labels()

    def __len__(self):
        """Return the number of images in the list file."""
        return len(self.imgs_path)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            ``(img, label_out)`` where ``img`` is a float32 tensor in RGB
            channel-first layout scaled to ``[0, 1]`` and ``label_out`` is
            a ``(box_num, 6)`` float tensor with column 0 set to zero.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        img_name = self.imgs_path[index]
        img = cv2.imread(img_name)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError("Cannot read image: {}".format(img_name))
        # Reverse the last axis (BGR -> RGB) and move channels first. The
        # contiguous float32 copy removes the negative strides and matches
        # the default dtype of the label tensor and of model weights.
        img = np.ascontiguousarray(img[:, :, ::-1].transpose(2, 0, 1), dtype=np.float32)
        img /= 255.0
        label = self.labels[img_name]
        label_out = torch.zeros((len(label), 6))
        # Column 0 is reserved for the in-batch image index (see collate_fn).
        label_out[:, 1:] = torch.from_numpy(label)
        return torch.from_numpy(img), label_out

    def get_imgs_and_labels(self):
        """Populate ``imgs_path``, ``labels_path`` and ``labels``.

        ``labels`` maps each image path to a float32 array of shape
        ``(box_num, 5)``; an image with an empty label file gets a
        ``(0, 5)`` array.

        Raises:
            AssertionError: if the list file is not a ``.txt`` file or
                lists no images.
            ValueError: if a label file has rows that are not 5 columns.
        """
        p = Path(self.img_txt_path)
        assert p.suffix == ".txt", "image path not txt"
        with open(self.img_txt_path, 'r') as f:
            # Drop blank lines so they do not become empty image paths.
            self.imgs_path = [line.strip() for line in f if line.strip()]
        assert self.imgs_path, "No image found!"
        # Swap only the final suffix; str.replace('.jpg', '.txt') would also
        # rewrite '.jpg' inside directory names and ignore other extensions.
        self.labels_path = [str(Path(name).with_suffix('.txt')) for name in self.imgs_path]
        for img_name, label_path in zip(self.imgs_path, self.labels_path):
            with open(label_path, 'r') as f:
                rows = [line.split() for line in f if line.strip()]
            boxes = np.array(rows, dtype=np.float32)
            if boxes.size == 0:
                # An image without objects still needs a 2-D (0, 5) array so
                # that the slice assignment in __getitem__ is valid.
                boxes = np.zeros((0, 5), dtype=np.float32)
            if boxes.ndim != 2 or boxes.shape[1] != 5:
                raise ValueError(
                    "Expected 5 columns per row in label file: {}".format(label_path)
                )
            self.labels[img_name] = boxes

    @staticmethod
    def collate_fn(batch):
        """Merge samples into one batch for ``DataLoader``.

        Images are stacked along a new first dimension. Label tensors are
        concatenated along dimension 0 after column 0 of each one has been
        overwritten in place with the position of its image in the batch.
        """
        imgs, labels = zip(*batch)
        for i, boxes in enumerate(labels):
            boxes[:, 0] = i
        return torch.stack(imgs, 0), torch.cat(labels, 0)
2. DataLoader
- DataLoader是在train中用到的,每次提供batchsize大小的数据,同时利用shuffle进行图像顺序打乱。
# Demo: iterate the dataset in shuffled batches of 16 for two epochs and
# print the shapes of each batch's image and label tensors.
path = "/home/.../ImageSets/Main/train.txt"
dataset = LoadImagesAndLabels(path)
train_loader = DataLoader(
    dataset=dataset,
    batch_size=16,
    shuffle=True,
    # The dataset's own collate function merges the variable-length labels.
    collate_fn=dataset.collate_fn,
)

for epoch in range(2):
    for i, data in enumerate(train_loader):
        img, label = data
        print(
            "epoch: {}, {} inputs size {} labels size {}".format(
                epoch, i, img.shape, label.shape
            )
        )
epoch: 0, 0 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([261, 6])
epoch: 0, 1 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([274, 6])
epoch: 0, 2 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([265, 6])
epoch: 0, 3 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([281, 6])
epoch: 0, 4 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([294, 6])
epoch: 0, 5 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([278, 6])
epoch: 0, 6 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([283, 6])
epoch: 0, 7 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([261, 6])
epoch: 0, 8 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([285, 6])
epoch: 0, 9 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([267, 6])
epoch: 0, 10 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([229, 6])
epoch: 0, 11 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([275, 6])
epoch: 0, 12 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([284, 6])
epoch: 0, 13 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([257, 6])
epoch: 0, 14 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([287, 6])
epoch: 0, 15 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([274, 6])
epoch: 0, 16 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([257, 6])
epoch: 0, 17 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([257, 6])
epoch: 0, 18 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([236, 6])
epoch: 0, 19 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([254, 6])
epoch: 0, 20 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([290, 6])
epoch: 0, 21 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([259, 6])
epoch: 0, 22 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([285, 6])
epoch: 0, 23 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([257, 6])
epoch: 0, 24 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([279, 6])
epoch: 0, 25 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([225, 6])
epoch: 0, 26 inputs size torch.Size([3, 3, 512, 512]) labels size torch.Size([43, 6])
epoch: 1, 0 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([263, 6])
epoch: 1, 1 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([293, 6])
epoch: 1, 2 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([264, 6])
epoch: 1, 3 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([265, 6])
epoch: 1, 4 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([305, 6])
epoch: 1, 5 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([264, 6])
epoch: 1, 6 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([226, 6])
epoch: 1, 7 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([255, 6])
epoch: 1, 8 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([247, 6])
epoch: 1, 9 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([303, 6])
epoch: 1, 10 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([279, 6])
epoch: 1, 11 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([247, 6])
epoch: 1, 12 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([257, 6])
epoch: 1, 13 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([255, 6])
epoch: 1, 14 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([298, 6])
epoch: 1, 15 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([259, 6])
epoch: 1, 16 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([258, 6])
epoch: 1, 17 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([227, 6])
epoch: 1, 18 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([244, 6])
epoch: 1, 19 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([256, 6])
epoch: 1, 20 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([260, 6])
epoch: 1, 21 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([296, 6])
epoch: 1, 22 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([273, 6])
epoch: 1, 23 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([315, 6])
epoch: 1, 24 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([282, 6])
epoch: 1, 25 inputs size torch.Size([16, 3, 512, 512]) labels size torch.Size([267, 6])
epoch: 1, 26 inputs size torch.Size([3, 3, 512, 512]) labels size torch.Size([39, 6])
3. 训练数据
- 利用DataLoader获取batchsize个数据后,送入网络进行推理,获得的结果与label进行损失计算,根据优化策略,反向传播梯度,更新一次weight,实现小批次的训练过程。