目录
前言
本文参考了Pytorch官方文档:Tutorials > Datasets & DataLoaders
为了使数据集代码与模型训练代码分离,提升可读性和模块化,PyTorch 提供了两个数据原语:
- torch.utils.data.DataLoader:将一个可迭代对象包装在Dataset周围,以便更容易访问样本
- torch.utils.data.Dataset:存储样本及其对应的标签
官方解释
Dataset
如何创建自定义的Dataset
一个自定义的Dataset类必须实现三个函数:__init__、__len__ 和 __getitem__。
- __init__:该函数在实例化 Dataset 对象时运行一次。
- __len__:该函数返回数据集中的样本数。
- __getitem__:该函数根据给定的索引 idx 从数据集中加载并返回样本。
用如下代码demo进行解释(代码中的 Dataset 基类需先通过 from torch.utils.data import Dataset 导入):
import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
    """Map-style dataset pairing image files with labels from a CSV file.

    The FashionMNIST images are stored in the ``img_dir`` directory, while
    their labels are stored separately in the CSV file ``annotations_file``.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        """Read the annotation table and keep the directory and transforms.

        Args:
            annotations_file: CSV source handed to ``pd.read_csv``. Column 0
                is used as the image file name and column 1 as the label.
            img_dir: Directory the image file names are joined onto.
            transform: Optional callable applied to each loaded image.
            target_transform: Optional callable applied to each label.
        """
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir, self.transform, self.target_transform = (
            img_dir,
            transform,
            target_transform,
        )

    def __len__(self):
        """Return the number of rows in the annotation table."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Load the image and label stored at row ``idx``.

        The image path is built from ``img_dir`` and the file name in
        column 0, the file is decoded with ``read_image``, and the label is
        taken from column 1. Each configured transform is then applied to
        its respective value.

        Returns:
            An ``(image, label)`` tuple.
        """
        file_name = self.img_labels.iloc[idx, 0]
        picture = read_image(os.path.join(self.img_dir, file_name))
        target = self.img_labels.iloc[idx, 1]
        # A falsy transform (e.g. None) means "leave the value untouched".
        picture = self.transform(picture) if self.transform else picture
        target = self.target_transform(target) if self.target_transform else target
        return picture, target
DataLoader
DataLoader 类是继承自泛型抽象基类 typing.Generic 的子类。通过 dir(DataLoader) 可以看到,它只暴露出两个可供调用的类方法,且这两个方法并不常用;而它实例化之后的对象则具有很多可以调用的方法。
实战例子
加载lmdb数据,并batch输出(CLIP使用的图像数据)
from torch.utils.data import Dataset, DataLoader
import lmdb
import os
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode
from PIL import Image
from io import BytesIO
import base64
class LMDBImgInferenceDataset(Dataset):
    """Inference dataset that reads base64-encoded images from an LMDB store.

    Every record is looked up under the key ``str(base_idx + idx)``; when that
    key is absent the record stored under ``str(base_idx)`` is returned
    instead.
    """

    def __init__(self, lmdb_path, base_idx):
        """Open the LMDB environment read-only and prepare the transform.

        Args:
            lmdb_path: Path of the LMDB directory; it must already exist.
            base_idx: Offset added to each dataset index to form the key.
        """
        super().__init__()
        self.base_idx = base_idx
        assert os.path.isdir(lmdb_path), "The LMDB directory {} does not exist!".format(lmdb_path)
        self.env = lmdb.open(lmdb_path, readonly=True, create=False, lock=False, readahead=False, meminit=False)
        print("self.env.stat(): ", self.env.stat())
        # A single transaction is kept open for the lifetime of the dataset
        # and reused by every __getitem__ call.
        self.txn = self.env.begin(buffers=True)
        self.transform = self._image_transform()

    def _image_transform(self, image_size=224):
        """Build the preprocessing pipeline applied to every decoded image.

        Args:
            image_size: Side length of the square the image is resized to.

        Returns:
            A ``Compose`` that resizes (bicubic), converts to RGB, converts to
            a tensor and normalizes each channel.
        """
        transform = Compose([
            Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            lambda image: image.convert('RGB'),
            ToTensor(),
            # NOTE(review): these look like the CLIP mean/std values - confirm.
            Normalize((0.48145466, 0.4578275, 0.40821073),
                      (0.26862954, 0.26130258, 0.27577711)),
        ])
        return transform

    def __len__(self):
        """Return the LMDB entry count plus 200.

        NOTE(review): the reported length exceeds the number of stored
        entries by 200; indices without a record fall back to the
        ``base_idx`` record in ``__getitem__``. Confirm the padding is
        intentional before relying on ``len()`` as the true sample count.
        """
        return self.env.stat()["entries"] + 200

    def __del__(self):
        """Close the LMDB environment if ``__init__`` got far enough to open it."""
        if hasattr(self, 'env'):
            print("self.env close!")
            self.env.close()

    def __getitem__(self, idx):
        """Return ``(key_index, image_tensor)`` for dataset index ``idx``.

        Args:
            idx: Zero-based dataset index; ``base_idx`` is added to form the key.

        Returns:
            A tuple of the key index that was actually read (the requested one
            or ``base_idx`` when falling back) and the transformed image.

        Raises:
            KeyError: If neither the requested key nor the fallback
                ``base_idx`` key holds a record. Previously the method fell
                through and returned ``None``, which only failed later during
                batch collation.
        """
        requested = idx + self.base_idx
        # Try the requested record first, then the base_idx record as fallback.
        for key_idx in (requested, self.base_idx):
            img_mv = self.txn.get(str(key_idx).encode())
            if img_mv:
                image = Image.open(BytesIO(base64.urlsafe_b64decode(img_mv)))
                img_arr = self.transform(image)
                return key_idx, img_arr
        raise KeyError(
            "No LMDB record for key {} or fallback key {}".format(requested, self.base_idx)
        )
# Directory of the LMDB image store to read from.
lmdb_path = "/mnt/workspace/10_canyin_clip/data/lingshou/imgs"
# 3100293 is the offset added to every dataset index to form the LMDB key.
lmdb_data = LMDBImgInferenceDataset(lmdb_path=lmdb_path, base_idx=3100293)
# Sequential (unshuffled) batches of 64 samples.
lmdb_dataloader = DataLoader(dataset=lmdb_data, batch_size=64, shuffle=False)
# Batches are pulled on demand with next(data_iter).
data_iter = iter(lmdb_dataloader)
加载url-txt list数据
__getitem__不带异常处理的
from torch.utils.data import Dataset, DataLoader
import lmdb
import os
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode
from PIL import Image
from io import BytesIO
import base64
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO
from urllib import request
def url2pil(img_url):
    """Download the image at ``img_url`` and return it as an RGB PIL image.

    Args:
        img_url: URL of the image to fetch.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        Exception: Errors raised by ``urlopen`` or by PIL while decoding are
            not caught here and propagate to the caller.
    """
    user_agent = 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_13_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/66.0.3359.181 Safari/537.36'
    headers = {'user-agent': user_agent}
    req = request.Request(url=img_url, headers=headers)
    # Use the response as a context manager so the connection is closed even
    # if reading fails; the payload is fully read before the block exits.
    with request.urlopen(req, timeout=30) as response:
        payload = response.read()
    img = Image.open(BytesIO(payload)).convert('RGB')
    return img
class UrlImgInferenceDataset(Dataset):
    """Inference dataset over ``(image_url, text)`` pairs.

    Images are downloaded when a sample is requested; any download or
    decoding error propagates to the caller.
    """

    def __init__(self, url_txt_list):
        """Store the pair list and build the image preprocessing pipeline.

        Args:
            url_txt_list: Indexable collection whose items hold the image URL
                at position 0 and the text at position 1.
        """
        super().__init__()
        self.url_txt_list = url_txt_list
        self.transform = self._image_transform()

    def _image_transform(self, image_size=224):
        """Return the resize / RGB / tensor / normalize pipeline.

        Args:
            image_size: Side length of the square the image is resized to.
        """
        target_size = (image_size, image_size)
        # NOTE(review): these look like the CLIP mean/std values - confirm.
        channel_mean = (0.48145466, 0.4578275, 0.40821073)
        channel_std = (0.26862954, 0.26130258, 0.27577711)
        steps = [
            Resize(target_size, interpolation=InterpolationMode.BICUBIC),
            lambda image: image.convert('RGB'),
            ToTensor(),
            Normalize(channel_mean, channel_std),
        ]
        return Compose(steps)

    def __len__(self):
        """Return the number of ``(url, text)`` pairs."""
        return len(self.url_txt_list)

    def __getitem__(self, idx):
        """Download, preprocess and return the sample at ``idx``.

        Returns:
            A tuple ``(image_tensor, text)``.
        """
        picture = url2pil(self.url_txt_list[idx][0])
        print("processing: {0}".format(idx))
        caption = self.url_txt_list[idx][1]
        return self.transform(picture), caption
# Pair every URL with its text; zip stops at the shorter of the two lists.
url_txt_list = list(zip(url_list, txt_list))
urlimg_data = UrlImgInferenceDataset(url_txt_list=url_txt_list)
# The loader keeps the name lmdb_dataloader although it serves URL data here.
lmdb_dataloader = DataLoader(dataset=urlimg_data, batch_size=20, shuffle=False)
# Batches are pulled on demand with next(data_iter).
data_iter = iter(lmdb_dataloader)
__getitem__带异常处理的
from torch.utils.data import Dataset, DataLoader
import lmdb
import os
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode
from PIL import Image
from io import BytesIO
import base64
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO
from urllib import request
def url2pil(img_url):
    """Download the image at ``img_url`` and return it as an RGB PIL image.

    Args:
        img_url: URL of the image to fetch.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        Exception: Errors raised by ``urlopen`` or by PIL while decoding are
            not caught here and propagate to the caller.
    """
    user_agent = 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_13_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/66.0.3359.181 Safari/537.36'
    headers = {'user-agent': user_agent}
    req = request.Request(url=img_url, headers=headers)
    # Use the response as a context manager so the connection is closed even
    # if reading fails; the payload is fully read before the block exits.
    with request.urlopen(req, timeout=30) as response:
        payload = response.read()
    img = Image.open(BytesIO(payload)).convert('RGB')
    return img
class UrlImgInferenceDataset(Dataset):
    """Inference dataset over ``(image_url, text)`` pairs with a fallback sample.

    Images are downloaded when a sample is requested. If the download, the
    text lookup or the preprocessing fails, a default image/text pair is
    returned instead so that one bad record does not abort the whole run.
    """

    # Default sample substituted for any record that cannot be loaded.
    _FALLBACK_IMAGE_URL = "https://cube.elemecdn.com/0/00/c9431ffdc70fe9bb30e3697a3d306jpeg.jpeg"
    _FALLBACK_TEXT = "ALSC-寿司"

    def __init__(self, url_txt_list):
        """Store the pair list and build the image preprocessing pipeline.

        Args:
            url_txt_list: Indexable collection whose items hold the image URL
                at position 0 and the text at position 1.
        """
        self.url_txt_list = url_txt_list
        self.transform = self._image_transform()
        super(UrlImgInferenceDataset, self).__init__()

    def _image_transform(self, image_size=224):
        """Return the resize / RGB / tensor / normalize pipeline.

        Args:
            image_size: Side length of the square the image is resized to.
        """
        transform = Compose([
            Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            lambda image: image.convert('RGB'),
            ToTensor(),
            # NOTE(review): these look like the CLIP mean/std values - confirm.
            Normalize((0.48145466, 0.4578275, 0.40821073),
                      (0.26862954, 0.26130258, 0.27577711)),
        ])
        return transform

    def __len__(self):
        """Return the number of ``(url, text)`` pairs."""
        return len(self.url_txt_list)

    def _fallback_sample(self):
        """Download and preprocess the default image; return it with the default text.

        Errors raised while fetching the default image are not caught and
        propagate to the caller.
        """
        image = url2pil(self._FALLBACK_IMAGE_URL)
        return self.transform(image), self._FALLBACK_TEXT

    def __getitem__(self, idx):
        """Return ``(image_tensor, text)`` for ``idx``, or the fallback sample.

        An out-of-range ``idx`` still raises, because the URL lookup happens
        outside the guarded sections.
        """
        pic_url = self.url_txt_list[idx][0]
        # Broad catches are deliberate: any per-sample failure is logged and
        # replaced by the default sample (best-effort loading).
        try:
            image = url2pil(pic_url)
        except Exception as e:
            print("加载url出错, 错误类型: {0}, 错误url: {1}, 返回默认图像和文本".format(e, pic_url))
            return self._fallback_sample()
        try:
            txt = self.url_txt_list[idx][1]
            img_arr = self.transform(image)
        except Exception as e:
            print("获取文本或图像, 错误类型: {0}, 错误url: {1}, 返回默认图像和文本".format(e, pic_url))
            return self._fallback_sample()
        return img_arr, txt