Caffe2 创建数据库（代码实际写入的是 minidb 格式，文件存放在 lmdb/ 目录下）

import glob
import os
import shutil

import imageio
import lmdb
import numpy as np

from caffe2.proto import caffe2_pb2
from data.common import write_mat_to_npy
from caffe2.python import utils, core, workspace

# from data.proto import utils, tensor_pb2

# Root directory of the CAVE dataset; every split lives in a sub-folder of it.
data_root = '/data/zhwzhong/Data/CAVE/'

# Split name -> directory (with trailing slash) that is searched for *.png files.
data_info = {
    '{}_data'.format(split): os.path.join(data_root, '{}/'.format(split))
    for split in ('train', 'val', 'test')
}

# Split name -> byte budget, only referenced by the commented-out lmdb.open(map_size=...) call.
# The factors presumably are (#images x 512 x 512 x 31 bands) -- TODO confirm against the data.
data_size = dict(
    train_data=42 * 512 * 512 * 31,
    val_data=10 * 512 * 512 * 31,
    test_data=10 * 512 * 512 * 31,
)

def cave_mat():
    """Placeholder for CAVE .mat conversion; not implemented yet (does nothing, returns None)."""

def write_to_lmdb(max_images=5):
    """Write each split's PNG images to a Caffe2 minidb and mirror them as .npy files.

    The output folders ``<data_root>/lmdb`` and ``<data_root>/npy`` are deleted
    and recreated first. Then, for every split in ``data_info``, each ``*.png``
    found recursively under the split directory (sorted by path) is:

    * saved via ``write_mat_to_npy`` under ``<data_root>/npy/``, keeping the
      same relative sub-directory layout and with the extension changed to ``.npy``;
    * wrapped in a ``caffe2_pb2.TensorProtos`` holding one tensor and stored in
      ``<data_root>/lmdb/<split>.minidb`` under the key ``train_000``,
      ``train_001``, ... (the ``train_`` prefix is used for every split).

    Note: despite the ``lmdb`` directory name, the DB type written is ``minidb``.

    Args:
        max_images: maximum number of images written per split. Defaults to 5,
            the cap the original code hard-coded; pass ``None`` to write all.
    """
    lmdb_dir = os.path.join(data_root, 'lmdb')
    npy_dir = os.path.join(data_root, 'npy')
    # Start from empty output folders so results of earlier runs never mix in.
    for out_dir in (lmdb_dir, npy_dir):
        if os.path.exists(out_dir):
            shutil.rmtree(out_dir)
        os.makedirs(out_dir, exist_ok=False)

    for key, value in data_info.items():
        db = core.C.create_db("minidb", os.path.join(data_root, 'lmdb/{}.minidb'.format(key)), core.C.Mode.write)
        transaction = db.new_transaction()
        # Slicing with None keeps the whole list.
        img_names = sorted(glob.glob(value + '**/*.png', recursive=True))[:max_images]
        for num, img_name in enumerate(img_names):
            print('Write {} to LMDB'.format(img_name))

            tmp = np.array(imageio.imread(img_name))

            # Mirror the image's directory (relative to data_root) under data_root/npy/.
            # Path functions are used instead of str.replace so that a filename
            # occurring elsewhere in the path, or 'png' inside a name, is not rewritten.
            rel_dir = os.path.relpath(os.path.dirname(img_name), data_root)
            save_npy_dir = os.path.join(npy_dir, rel_dir)
            os.makedirs(save_npy_dir, exist_ok=True)
            npy_name = os.path.splitext(os.path.basename(img_name))[0] + '.npy'
            write_mat_to_npy(tmp, os.path.join(save_npy_dir, npy_name))

            tensor_proto = caffe2_pb2.TensorProtos()
            tensor_proto.protos.extend([utils.NumpyArrayToCaffe2Tensor(tmp)])
            # Use %-formatting: 'train_%03d'.format(num) left the placeholder
            # unexpanded, so every record was stored under the same key.
            transaction.put(('train_%03d' % num).encode('ascii'), tensor_proto.SerializeToString())

        # Release the transaction before the DB it was created from. Otherwise
        # rebinding `db` on the next split frees the old DB while its transaction
        # is still alive. This is the same del-order the Caffe2 dataset tutorial uses.
        del transaction
        del db


if __name__ == '__main__':
    # Uncomment to (re)build the databases before reading them back.
    # write_to_lmdb()
    # cave_mat()

    # Read one batch of 5 records back from the training DB produced by write_to_lmdb().
    # The path is derived from data_root so it stays in sync with the writer.
    net_proto = core.Net("example_reader")
    dbreader = net_proto.CreateDB(
        [], "dbreader",
        db=os.path.join(data_root, 'lmdb/train_data.minidb'),
        db_type="minidb",
    )

    net_proto.TensorProtosDBInput([dbreader], ["X"], batch_size=5)
    workspace.CreateNet(net_proto)
    workspace.RunNet(net_proto.Proto().name)

    # Fetch the blob once instead of re-fetching it for every print.
    batch = workspace.FetchBlob("X")
    print(batch.shape)
    print(batch.dtype)
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值