How to use LMDB in PyTorch

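This article explains how to use the lmdb library to convert the ImageNet image dataset into an efficient storage format, and provides the source code of folder2lmdb.py together with the key steps. By modifying this file, you can convert an image folder to LMDB format so that data loads quickly in PyTorch.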

Overview

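1. Source code for using lmdb, on GitHub: pytorch_lmdb_imagenet
2. Usage: only the folder2lmdb.py file needs to be modified
① First modify the folder2lmdb function to convert the image folder into an lmdb file;
② Then, in the actual experiment, modify the ImageFolderLMDB class to turn the existing lmdb file into a dataset for later reading.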

The complete source of folder2lmdb.py and the specific modifications are as follows:

  • Imports
import os
import os.path as osp
from PIL import Image
import six
import lmdb
import pickle
import numpy as np

import torch.utils.data as data
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
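  • A custom Dataset class. Once the image dataset has been converted to lmdb, ImageFolderLMDB wraps the lmdb file as a Dataset for later reading; to change what is read, modify the __getitem__ function. Note: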

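The input path should be the name of a folder that contains one data.mdb file and one lock.mdb file. If the senti folder holds data.mdb, lock.mdb, and other files, passing "senti\" is enough.
The code on this page writes a single-file database (subdir=False), so the generated files are named train.lmdb and train.lmdb-lock; rename them to data.mdb and lock.mdb respectively to get the folder layout described above.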
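The renaming step above can be sketched as a small helper. This is only an illustration: the function name `to_lmdb_dir` and its arguments are hypothetical, and the lock file is looked up under both a `-lock` and a `.lock` suffix because the exact name depends on how the database was written. Since ImageFolderLMDB opens the environment with `lock=False`, a missing lock file is treated as harmless here.

```python
import os
import shutil


def to_lmdb_dir(data_file, out_dir):
    """Move a single-file LMDB into out_dir as data.mdb (plus lock.mdb if a lock file exists)."""
    os.makedirs(out_dir, exist_ok=True)
    shutil.move(data_file, os.path.join(out_dir, "data.mdb"))
    # Assumption: the lock file sits next to the data file with one of these suffixes.
    for suffix in ("-lock", ".lock"):
        lock_file = data_file + suffix
        if os.path.exists(lock_file):
            shutil.move(lock_file, os.path.join(out_dir, "lock.mdb"))
            break
    return out_dir
```

After this, the folder passed as `out_dir` can be given to ImageFolderLMDB as `db_path`. Because the class calls `lmdb.open` with `subdir=osp.isdir(db_path)`, passing the original train.lmdb file path directly should also work without any renaming.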


def loads_data(buf):
    """
    Args:
        buf: the output of `dumps`.
    """
    return pickle.loads(buf)


class ImageFolderLMDB(data.Dataset):
    def __init__(self, db_path, transform=None, target_transform=None):
        self.db_path = db_path
        self.env = lmdb.open(db_path, subdir=osp.isdir(db_path),
                             readonly=True, lock=False,
                             readahead=False, meminit=False)
        with self.env.begin(write=False) as txn:
            self.length = loads_data(txn.get(b'__len__'))
            self.keys = loads_data(txn.get(b'__keys__'))

        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        env = self.env
        with env.begin(write=False) as txn:
            byteflow = txn.get(self.keys[index])

        unpacked = loads_data(byteflow)

        # load img
        imgbuf = unpacked[0]
        buf = six.BytesIO()
        buf.write(imgbuf)
        buf.seek(0)
        img = Image.open(buf).convert('RGB')

        # load label
        target = unpacked[1]

        if self.transform is not None:
            img = self.transform(img)

        im2arr = np.array(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        # return img, target
        return im2arr, target

    def __len__(self):
        return self.length

    def __repr__(self):
        return self.__class__.__name__ + ' (' + self.db_path + ')'


def raw_reader(path):
    with open(path, 'rb') as f:
        bin_data = f.read()
    return bin_data


def dumps_data(obj):
    """
    Serialize an object.
    Returns:
        Implementation-dependent bytes-like object
    """
    return pickle.dumps(obj)
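  • The function that converts an image folder into an lmdb file. It takes the folder containing the images and writes the .lmdb file and its lock file into that folder.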
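def identity_collate(batch):
    # Module-level stand-in for `lambda x: x`: a lambda cannot be pickled, which
    # breaks DataLoader workers on platforms that spawn processes (e.g. Windows).
    return batch


def folder2lmdb(dpath, name="train", write_frequency=5000):
    directory = osp.expanduser(osp.join(dpath, name))
    print("Loading dataset from %s" % directory)
    # raw_reader returns each image as its undecoded file bytes
    dataset = ImageFolder(directory, loader=raw_reader)
    data_loader = DataLoader(dataset, num_workers=16, collate_fn=identity_collate)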

    lmdb_path = osp.join(dpath, "%s.lmdb" % name)
    isdir = os.path.isdir(lmdb_path)

    print("Generate LMDB to %s" % lmdb_path)
    # map_size caps the database size; 1099511627776 bytes is 1 TiB, so this allows up to 2 TiB
    db = lmdb.open(lmdb_path, subdir=isdir,
                   map_size=1099511627776 * 2, readonly=False,
                   meminit=False, map_async=True)

    txn = db.begin(write=True)
    for idx, data in enumerate(data_loader):
        image, label = data[0]

        txn.put(u'{}'.format(idx).encode('ascii'), dumps_data((image, label)))
        if idx % write_frequency == 0:
            print("[%d/%d]" % (idx, len(data_loader)))
            txn.commit()
            txn = db.begin(write=True)

    # finish iterating through dataset
    txn.commit()
    keys = [u'{}'.format(k).encode('ascii') for k in range(idx + 1)]
    with db.begin(write=True) as txn:
        txn.put(b'__keys__', dumps_data(keys))
        txn.put(b'__len__', dumps_data(len(keys)))

    print("Flushing database ...")
    db.sync()
    db.close()
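
To make the key/value layout written by folder2lmdb easier to see, here is a minimal sketch that reproduces it with a plain dict standing in for the LMDB environment. The names `build_records` and `read_record` are hypothetical and exist only for this illustration; the keys (`b'0'`, `b'1'`, ..., `b'__keys__'`, `b'__len__'`) and the pickled `(image_bytes, label)` values follow the code above.

```python
import pickle


def build_records(samples):
    """Mimic the layout folder2lmdb writes, using a dict as a stand-in for LMDB."""
    store = {}
    for idx, (image_bytes, label) in enumerate(samples):
        # Each record: ASCII-encoded index -> pickled (raw image bytes, label)
        store[str(idx).encode('ascii')] = pickle.dumps((image_bytes, label))
    keys = [str(k).encode('ascii') for k in range(len(samples))]
    # Two bookkeeping entries let the reader recover the key list and the length
    store[b'__keys__'] = pickle.dumps(keys)
    store[b'__len__'] = pickle.dumps(len(keys))
    return store


def read_record(store, index):
    """Look up a record the way ImageFolderLMDB.__getitem__ does."""
    keys = pickle.loads(store[b'__keys__'])
    return pickle.loads(store[keys[index]])
```

Because the image is stored as raw file bytes rather than a decoded array, the reader has to decode it again with PIL, which is what `__getitem__` does with `Image.open`.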
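  • Calling the folder2lmdb function. Note:

If an image is stored at "dataset/image/xxx.jpg", the directory that gets read (dpath joined with name) should be "dataset/". When folder2lmdb reads the images through ImageFolder, the sub-folder name image is stored as each picture's target. Passing "dataset/image/" instead raises an error, because that directory contains no class sub-folders.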
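
The layout rule above can be checked before starting a long conversion. The sketch below does not call torchvision; it only mirrors the idea that every sub-folder of the root becomes one class, indexed in sorted order. The helper name `list_classes` is hypothetical.

```python
import os


def list_classes(root):
    """Return {class_name: index} for the sub-folders of root, sorted by name."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    if not classes:
        # This is the situation described above: root points at the images themselves
        raise FileNotFoundError("no class sub-folders found in %s" % root)
    return {name: i for i, name in enumerate(classes)}
```

For the example in the text, calling it on "dataset/" would yield a single class named image, while calling it on "dataset/image/" would raise.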

if __name__ == "__main__":
    # generate lmdb
    folder2lmdb("/home/jiang/dataset/imagenet/", name="train")
    folder2lmdb("/home/jiang/dataset/imagenet/", name="val")