PyTorch Datasets & DataLoaders

Preface

This article is based on the official PyTorch documentation: Tutorials > Datasets & DataLoaders.

To decouple dataset code from model-training code, and to improve readability and modularity, PyTorch provides two data primitives (see the sketch after this list):

  • torch.utils.data.Dataset: stores the samples and their corresponding labels
  • torch.utils.data.DataLoader: wraps an iterable around the Dataset to enable easy access to the samples
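A minimal sketch of this split, assuming torchvision is installed; it uses the built-in FashionMNIST dataset from the official tutorial, and the download directory ./data is an arbitrary choice:

from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# Dataset: knows where the samples live and how to read a single (image, label) pair
training_data = datasets.FashionMNIST(
    root="./data",        # hypothetical local cache directory
    train=True,
    download=True,
    transform=ToTensor(),
)

# DataLoader: wraps the Dataset and handles batching and shuffling
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)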

Official explanation

Dataset

How to create a custom Dataset

A custom Dataset class must implement three functions: __init__, __len__, and __getitem__.

  • __init__: runs once when the Dataset object is instantiated.
  • __len__: returns the number of samples in the dataset.
  • __getitem__: loads and returns the sample at the given index idx.

The following demo code illustrates this:

import os
import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    """
    FashionMNIST图像存储在img_dir目录中,它们的标签分别存储在 CSV 文件 annotations_file 中。
    """
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        """
        包含图像、注释文件及对两者的transform
        """
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        """
        根据索引,该函数识别图像在磁盘上的位置,使用 read_image 将其转换为张量,
        从 self.img_labels 中的 csv 数据中检索相应的标签,调用它们的转换函数,
        并返回张量图像 和元组中的相应标签。
        """
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
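A short usage sketch, assuming a hypothetical labels.csv and image directory laid out as the class expects (file name in the first CSV column, label in the second):

dataset = CustomImageDataset(annotations_file="labels.csv", img_dir="images/")  # hypothetical paths
print(len(dataset))        # number of annotated samples
image, label = dataset[0]  # __getitem__ loads a single sample on demand
print(image.shape, label)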

DataLoader

The DataLoader class is a subclass of the generic abstract base class typing.Generic. Inspecting it with dir() shows that the class itself exposes only two callable methods, neither of which is used often; an initialized DataLoader instance, however, offers many callable methods, and in practice you simply iterate over it to get batches, as sketched below.
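A minimal sketch, reusing the training_data FashionMNIST Dataset from the earlier sketch; next(iter(...)) pulls a single batch:

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)

# Each iteration yields one batch of 64 images and the 64 matching labels
train_features, train_labels = next(iter(train_dataloader))
print(train_features.shape)  # torch.Size([64, 1, 28, 28]) for FashionMNIST
print(train_labels.shape)    # torch.Size([64])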

Practical examples

Loading LMDB data and emitting it in batches (the image data used by CLIP)

from torch.utils.data import Dataset, DataLoader
import lmdb
import os
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode
from PIL import Image
from io import BytesIO
import base64

class LMDBImgInferenceDataset(Dataset):
    def __init__(self, lmdb_path, base_idx):
        self.base_idx = base_idx
        assert os.path.isdir(lmdb_path), "The LMDB directory {} does not exist!".format(lmdb_path)
        self.env = lmdb.open(lmdb_path, readonly=True, create=False, lock=False, readahead=False, meminit=False)
        print("self.env.stat(): ", self.env.stat())
        self.txn = self.env.begin(buffers=True)
        self.transform = self._image_transform()
        super(LMDBImgInferenceDataset, self).__init__()
        
    def _image_transform(self, image_size=224):
        transform = Compose([
            Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            lambda image:image.convert('RGB'),
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073), 
                      (0.26862954, 0.26130258, 0.27577711)),
        ])
        return transform
    
    def __len__(self):
        return self.env.stat()["entries"]+200
    
    def __del__(self):
        if hasattr(self, 'env'):
            print("self.env close!")
            self.env.close()
                
    def __getitem__(self, idx):
        idx += self.base_idx
        # Try the requested key; if it is missing, fall back to base_idx and retry once
        for _ in range(2):
            img_mv = self.txn.get(str(idx).encode())
            if img_mv:
                # Values are stored as urlsafe-base64-encoded image bytes
                image = Image.open(BytesIO(base64.urlsafe_b64decode(img_mv)))
                img_arr = self.transform(image)
                return idx, img_arr
            else:
                idx = self.base_idx
                
lmdb_path = "/mnt/workspace/10_canyin_clip/data/lingshou/imgs"
lmdb_data = LMDBImgInferenceDataset(lmdb_path, 3100293)
lmdb_dataloader = DataLoader(lmdb_data, batch_size=64, shuffle=False)
data_iter = iter(lmdb_dataloader)
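A short usage sketch of pulling one batch from the iterator created above; the DataLoader collates the per-sample (idx, img_arr) tuples into batched tensors (assuming the first 64 keys exist in the LMDB):

idx_batch, img_batch = next(data_iter)
print(idx_batch.shape)  # torch.Size([64])
print(img_batch.shape)  # torch.Size([64, 3, 224, 224])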

Loading data from a url-txt list

__getitem__ without exception handling

from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode
from PIL import Image
from io import BytesIO
from urllib import request

def url2pil(img_url):
    user_agent = 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_13_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/66.0.3359.181 Safari/537.36'
    headers = {'user-agent': user_agent}
    req = request.Request(url=img_url, headers=headers)
    response = request.urlopen(req, timeout=30)
    img= Image.open(BytesIO(response.read())).convert('RGB')
    return img

class UrlImgInferenceDataset(Dataset):
    def __init__(self, url_txt_list):
        self.url_txt_list = url_txt_list
        self.transform = self._image_transform()
        super(UrlImgInferenceDataset, self).__init__()
        
    def _image_transform(self, image_size=224):
        transform = Compose([
            Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            lambda image:image.convert('RGB'),
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073), 
                      (0.26862954, 0.26130258, 0.27577711)),
        ])
        return transform
    
    def __len__(self):
        return len(self.url_txt_list)
                
    def __getitem__(self, idx):
        # Download the image for this index and pair it with its text
        image = url2pil(self.url_txt_list[idx][0])
        print("processing: {0}".format(idx))
        txt = self.url_txt_list[idx][1]
        img_arr = self.transform(image)
        return img_arr, txt
                
# url_list and txt_list are assumed to be defined elsewhere: parallel lists of image URLs and their texts
url_txt_list = list(zip(url_list, txt_list))
urlimg_data = UrlImgInferenceDataset(url_txt_list)
url_dataloader = DataLoader(urlimg_data, batch_size=20, shuffle=False)
data_iter = iter(url_dataloader)

__getitem__ with exception handling (without it, a single unreachable URL raises inside __getitem__ and aborts the whole DataLoader iteration)

from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode
from PIL import Image
from io import BytesIO
from urllib import request

def url2pil(img_url):
    user_agent = 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_13_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/66.0.3359.181 Safari/537.36'
    headers = {'user-agent': user_agent}
    req = request.Request(url=img_url, headers=headers)
    response = request.urlopen(req, timeout=30)
    img= Image.open(BytesIO(response.read())).convert('RGB')
    return img

class UrlImgInferenceDataset(Dataset):
    def __init__(self, url_txt_list):
        self.url_txt_list = url_txt_list
        self.transform = self._image_transform()
        super(UrlImgInferenceDataset, self).__init__()
        
    def _image_transform(self, image_size=224):
        transform = Compose([
            Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            lambda image:image.convert('RGB'),
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073), 
                      (0.26862954, 0.26130258, 0.27577711)),
        ])
        return transform
    
    def __len__(self):
        return len(self.url_txt_list)
                
    def __getitem__(self, idx):
        pic_url = self.url_txt_list[idx][0]
        try:
            image = url2pil(pic_url)
        except Exception as e:
            print("Failed to load url, error type: {0}, failing url: {1}, returning the default image and text".format(e, pic_url))
            image = url2pil("https://cube.elemecdn.com/0/00/c9431ffdc70fe9bb30e3697a3d306jpeg.jpeg")
            img_arr = self.transform(image)
            txt = "ALSC-寿司"
            return img_arr, txt

        try:
            txt = self.url_txt_list[idx][1]
            img_arr = self.transform(image)
        except Exception as e:
            print("Failed to get the text or transform the image, error type: {0}, failing url: {1}, returning the default image and text".format(e, pic_url))
            image = url2pil("https://cube.elemecdn.com/0/00/c9431ffdc70fe9bb30e3697a3d306jpeg.jpeg")
            img_arr = self.transform(image)
            txt = "ALSC-寿司"
            return img_arr, txt
        return img_arr, txt
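The DataLoader wiring is the same as in the version without exception handling; a short sketch, again assuming url_list and txt_list are defined elsewhere:

url_txt_list = list(zip(url_list, txt_list))  # url_list / txt_list assumed defined elsewhere
urlimg_data = UrlImgInferenceDataset(url_txt_list)
url_dataloader = DataLoader(urlimg_data, batch_size=20, shuffle=False)

# Bad URLs no longer crash a worker; they are replaced by the default image/text pair
for img_batch, txt_batch in url_dataloader:
    print(img_batch.shape, len(txt_batch))
    break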