制作自己的LMDB数据

前言

记录一下在 PyTorch 中制作和读取 LMDB 数据的代码,自用。

制作部分的Code

这段代码是在 ASTER 中数据制作部分代码的基础上稍作修改得到的。aster_train.txt 里每行是一张图片的完整路径;图片同目录下有同名的 .txt 文件,里面记录着该 jpg 的标签。

import os
import lmdb # install lmdb by "pip install lmdb"
import cv2
import numpy as np
from tqdm import tqdm
import six
from PIL import Image
import scipy.io as sio
from tqdm import tqdm
import re

def checkImageIsValid(imageBin):
    """Return True if ``imageBin`` decodes to a non-empty image.

    Args:
        imageBin: raw encoded image bytes (e.g. the contents of a JPEG file),
            or None.

    Returns:
        False for None or empty input, for bytes OpenCV cannot decode, or for
        a decoded image with zero height or width; True otherwise.
    """
    if not imageBin:
        return False
    # np.frombuffer views the bytes without copying; np.fromstring on binary
    # data is deprecated.
    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
    img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
    # cv2.imdecode returns None (rather than raising) for undecodable data.
    if img is None:
        return False
    imgH, imgW = img.shape[0], img.shape[1]
    return imgH * imgW != 0


def writeCache(env, cache):
    """Store every entry of ``cache`` in ``env`` within one write transaction.

    Each str key is encoded with ``str.encode()`` (UTF-8) before storage;
    values are handed to ``put`` unchanged (callers in this file pass bytes).
    """
    with env.begin(write=True) as txn:
        for key in cache:
            txn.put(key.encode(), cache[key])


def _is_difficult(word):
    assert isinstance(word, str)
    return not re.match('^[\w]+$', word)


def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
    """
    Create LMDB dataset for CRNN training.

    Accepted samples are stored under 1-based, zero-padded keys
    ``image-%09d`` (raw image file bytes) and ``label-%09d`` (encoded label),
    plus ``lexicon-%09d`` (space-joined lexicon) when ``lexiconList`` is
    given. Samples with an empty label, a missing file, or (when
    ``checkValid``) an invalid image are skipped without consuming a key, and
    the count actually written is stored under ``num-samples``.

    ARGS:
            outputPath    : LMDB output path
            imagePathList : list of image path
            labelList     : list of corresponding groundtruth texts
            lexiconList   : (optional) list of lexicon lists, indexed like imagePathList
            checkValid    : if true, check the validity of every image
    """
    assert(len(imagePathList) == len(labelList))
    nSamples = len(imagePathList)
    # map_size is the upper bound on the database size in bytes:
    # 1099511627776 bytes = 1 TiB.
    env = lmdb.open(outputPath, map_size=1099511627776)
    try:
        cache = {}
        cnt = 1  # next 1-based key index; advanced only for accepted samples
        for i in range(nSamples):
            imagePath = imagePathList[i]
            label = labelList[i]
            if len(label) == 0:
                continue
            if not os.path.exists(imagePath):
                print('%s does not exist' % imagePath)
                continue
            with open(imagePath, 'rb') as f:
                imageBin = f.read()
            if checkValid and not checkImageIsValid(imageBin):
                print('%s is not a valid image' % imagePath)
                continue

            # Everything stored in the database is bytes; keys are zero-padded
            # to 9 digits.
            imageKey = 'image-%09d' % cnt
            labelKey = 'label-%09d' % cnt
            cache[imageKey] = imageBin
            cache[labelKey] = label.encode()
            if lexiconList:
                lexiconKey = 'lexicon-%09d' % cnt
                # Encode like the label: txn.put() does not accept str values
                # on Python 3.
                cache[lexiconKey] = ' '.join(lexiconList[i]).encode()
            if cnt % 1000 == 0:
                # Flush in batches so the in-memory cache stays bounded.
                writeCache(env, cache)
                cache = {}
                print('Written %d / %d' % (cnt, nSamples))
            cnt += 1
        nSamples = cnt - 1
        cache['num-samples'] = str(nSamples).encode()
        writeCache(env, cache)
    finally:
        # Release the environment even if reading or writing fails.
        env.close()
    print('Created dataset with %d samples' % nSamples)

def get_sample_list(txt_path:str):
    """Read an image list file and the label stored next to each image.

    ``txt_path`` contains one image path per line (surrounding whitespace is
    stripped, blank lines ignored). The label file of an image is the same
    path with its final extension replaced by ``.txt``. Images whose label
    file does not exist are skipped. The label is the first line of the label
    file, stripped. All files are read as UTF-8.

    Returns:
        (jpg_list, txt_content_list): parallel lists of image paths and labels.

    Raises:
        UnicodeDecodeError: if a label file is not valid UTF-8 (its path is
            printed before re-raising).
    """
    def _label_path(image_path):
        # Replace only the final extension. str.replace('.jpg', '.txt') also
        # rewrote '.jpg' inside directory names, and left other extensions
        # untouched so the image file itself was taken as its label file.
        return os.path.splitext(image_path)[0] + '.txt'

    with open(txt_path, 'r', encoding='utf-8') as fr:
        candidates = [line.strip() for line in fr]
    jpg_list = [p for p in candidates if p and os.path.exists(_label_path(p))]

    txt_content_list = []
    for jpg in jpg_list:
        label_path = _label_path(jpg)
        with open(label_path, 'r', encoding='utf-8') as fr:
            try:
                str_tmp = fr.readline()
            except UnicodeDecodeError:
                print(label_path)
                raise
        txt_content_list.append(str_tmp.strip())

    return jpg_list, txt_content_list

if __name__ == "__main__":
    txt_path='/home/gpu-server/disk/disk1/NumberData/8NumberSample/aster_train.txt'
    lmdb_output_path = '/home/gpu-server/project/aster/dataset/train'
    imagePathList,labelList=get_sample_list(txt_path)
    createDataset(lmdb_output_path, imagePathList, labelList)

读取部分

这里用的是 PyTorch 的 Dataset/DataLoader,简单记录一下。人比较懒,代码直接照搬过来,没有整理拆分,重点看 __getitem__。

from __future__ import absolute_import

# import sys
# sys.path.append('./')

import os
# import moxing as mox

import pickle
from tqdm import tqdm
from PIL import Image, ImageFile
import numpy as np
import random
import cv2
import lmdb
import sys
import six

import torch
from torch.utils import data
from torch.utils.data import sampler
from torchvision import transforms

from lib.utils.labelmaps import get_vocabulary, labels2strs
from lib.utils import to_numpy

# Let PIL load truncated image files instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True


# Command-line options are parsed at import time; `run_on_remote` is read
# below and in LmdbDataset.__init__.
from config import get_args
global_args = get_args(sys.argv[1:])

if global_args.run_on_remote:
    import moxing as mox
    # moxing is a distributed (remote-storage) framework; not covered in this post

class LmdbDataset(data.Dataset):
    """Map-style dataset reading (image, label) pairs from an LMDB database.

    Expected layout: ``num-samples`` plus 1-based keys ``image-%09d``
    (encoded image bytes) and ``label-%09d`` (UTF-8 transcription).

    Each item is ``(img, label, label_len)``:
      * ``img``: RGB PIL image, or whatever ``transform`` returns for it;
      * ``label``: int array of length ``max_len`` holding vocabulary ids,
        terminated by the EOS id and padded with the PADDING id;
      * ``label_len``: number of ids before padding (EOS included).
    """

    def __init__(self, root, voc_type, max_len, num_samples, transform=None):
        super(LmdbDataset, self).__init__()

        if global_args.run_on_remote:
            # Copy `root` into a local /cache directory via moxing, then open
            # the local copy.
            dataset_name = os.path.basename(root)
            data_cache_url = "/cache/%s" % dataset_name
            if not os.path.exists(data_cache_url):
                os.makedirs(data_cache_url)
            if mox.file.exists(root):
                mox.file.copy_parallel(root, data_cache_url)
            else:
                raise ValueError("%s not exists!" % root)

            self.env = lmdb.open(data_cache_url, max_readers=32, readonly=True)
        else:
            self.env = lmdb.open(root, max_readers=32, readonly=True)

        assert self.env is not None, "cannot create lmdb from %s" % root
        # NOTE(review): a single read transaction is opened here and shared by
        # every __getitem__ call; confirm this is safe with DataLoader worker
        # processes (num_workers > 0).
        self.txn = self.env.begin()

        self.voc_type = voc_type
        self.transform = transform
        self.max_len = max_len
        # Use at most `num_samples` of the stored samples.
        self.nSamples = int(self.txn.get(b"num-samples"))
        self.nSamples = min(self.nSamples, num_samples)

        assert voc_type in ['LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS','DIGITS']
        self.EOS = 'EOS'
        self.PADDING = 'PADDING'
        self.UNKNOWN = 'UNKNOWN'
        self.voc = get_vocabulary(voc_type, EOS=self.EOS, PADDING=self.PADDING, UNKNOWN=self.UNKNOWN)
        self.char2id = dict(zip(self.voc, range(len(self.voc))))
        self.id2char = dict(zip(range(len(self.voc)), self.voc))

        self.rec_num_classes = len(self.voc)
        self.lowercase = (voc_type == 'LOWERCASE')

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        """Return ``(img, label, label_len)`` for the 0-based ``index``.

        A sample whose image cannot be decoded is replaced by the next one
        (wrapping around at the end).

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``.
        """
        if not 0 <= index < len(self):
            raise IndexError('index range error')
        # Keys are 1-based, so key_id is also the 0-based index of the next
        # sample.
        key_id = index + 1
        img_key = b'image-%09d' % key_id
        imgbuf = self.txn.get(img_key)

        # Image.open needs a file-like object, so wrap the raw bytes in one.
        buf = six.BytesIO()
        buf.write(imgbuf)
        buf.seek(0)
        try:
            img = Image.open(buf).convert('RGB')
        except IOError:
            print('Corrupted image for %d' % key_id)
            # Fall back to the immediately following sample (wrapping at the
            # end). NOTE: recurses until a readable image is found, so a
            # database with no readable images hits the recursion limit.
            return self[(index + 1) % len(self)]

        # recognition labels
        label_key = b'label-%09d' % key_id
        word = self.txn.get(label_key).decode()
        label, label_len = self._encode_label(word)

        if self.transform is not None:
            img = self.transform(img)
        return img, label, label_len

    def _encode_label(self, word):
        """Encode ``word`` as a fixed-length array of vocabulary ids.

        The word is lowercased first for the LOWERCASE vocabulary. Characters
        missing from the vocabulary map to the UNKNOWN id (and are printed).
        An EOS id is appended, and the remaining ``max_len`` slots hold the
        PADDING id.

        Returns:
            (label, label_len): int array of shape ``(max_len,)`` and the
            number of ids written, EOS included.
        """
        if self.lowercase:
            word = word.lower()
        # dtype=int: `np.int` was only an alias of the builtin `int` and was
        # removed in NumPy 1.24.
        label = np.full((self.max_len,), self.char2id[self.PADDING], dtype=int)
        label_list = []
        for char in word:
            if char in self.char2id:
                label_list.append(self.char2id[char])
            else:
                # add the unknown token
                print('{0} is out of vocabulary.'.format(char))
                label_list.append(self.char2id[self.UNKNOWN])
        # add a stop token
        label_list.append(self.char2id[self.EOS])
        assert len(label_list) <= self.max_len, 'label longer than max_len'
        label[:len(label_list)] = np.array(label_list)
        return label, len(label_list)
  • 2
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值