# This file presents a simple custom Dataset, intended for use as a template.
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
class MyDataset(Dataset):
    """Minimal map-style dataset backed by lines of an annotation text file.

    Each entry of ``image_lines`` is expected to look like
    ``"/path/to/image 100,200,300,400,label"``: an image path, whitespace,
    then a comma-separated annotation whose last field is the label and
    whose preceding fields are integer box coordinates.
    """

    def __init__(self, image_lines):
        """Store the annotation lines.

        Args:
            image_lines: Iterable of annotation strings, typically the
                result of ``readlines()`` on the annotation file.
        """
        # Whitespace-only lines (e.g. a trailing empty line in the file)
        # carry no sample. Dropping them here keeps __len__ accurate and
        # prevents __getitem__ from failing on them.
        self.image_lines = [line for line in image_lines if line.strip()]

    def __len__(self):
        """Return the number of usable annotation lines."""
        return len(self.image_lines)

    def __getitem__(self, index):
        """Return the sample ``dataset[index]`` as ``(image, box, label)``.

        Cropping, resizing and similar transforms can be added here; this
        implementation only loads and returns the raw values.

        Example for the line ``"/path/root 100,200,300,400,label"``:
            image: the picture read from ``/path/root`` as a NumPy array
            box:   ``[100, 200, 300, 400]`` as a NumPy integer array
            label: ``"label"``

        Raises:
            IndexError: If ``index`` is out of range.
            ValueError: If the line has no annotation field, or a box
                coordinate is not an integer.
        """
        fields = self.image_lines[index].split()
        if len(fields) < 2:
            # Deliberately not IndexError: an IndexError escaping
            # __getitem__ is interpreted by Python's legacy iteration
            # protocol as "end of sequence", which would silently truncate
            # iteration instead of reporting the malformed line.
            raise ValueError(
                "malformed annotation line at index {}: {!r}".format(
                    index, self.image_lines[index]
                )
            )
        path, annotation = fields[0], fields[1]

        # Image.open keeps the file handle open; converting inside the
        # context manager forces the pixel data to load, after which the
        # handle is released deterministically.
        with Image.open(path) as img:
            image = np.array(img)

        # Split once: everything before the last comma is a coordinate,
        # the final field is the label.
        *coords, label = annotation.split(",")
        box = np.array([int(value) for value in coords])
        return image, box, label
if __name__ == '__main__':
    # Demo: read the annotation file, wrap it in MyDataset and iterate over
    # it in shuffled batches.
    # Each line of the text file is stored as:
    #   /path/root 100,200,300,400,label
    annotation_file = "2022_train.txt"
    with open(annotation_file) as handle:
        annotation_lines = list(handle)

    train_set = MyDataset(annotation_lines)
    loader = DataLoader(train_set, batch_size=128, shuffle=True)

    # NOTE(review): no collate_fn is given, so the default collation is
    # used; that presumably requires every image in a batch to have the
    # same shape -- confirm against the actual data.
    for images, boxes, labels in loader:
        print(images.size())
        print(boxes[:3])
        print(labels[:3])