import glob
import os
import shutil
import imageio
import lmdb
import numpy as np
from caffe2.proto import caffe2_pb2
from data.common import write_mat_to_npy
from caffe2.python import utils, core, workspace
# from data.proto import utils, tensor_pb2
# Root directory of the CAVE dataset; the per-split paths below hang off it.
data_root = '/data/zhwzhong/Data/CAVE/'

# Image directory of each split, keyed '<split>_data'.
# Insertion order (train, val, test) is preserved and is the order in which
# the splits are iterated.
data_info = {
    split_key: os.path.join(data_root, sub_dir)
    for split_key, sub_dir in (
        ('train_data', 'train/'),
        ('val_data', 'val/'),
        ('test_data', 'test/'),
    )
}

# Per-split byte sizes, keyed like data_info.
# NOTE(review): these look like LMDB ``map_size`` budgets. No live code in
# this file reads them -- TODO confirm whether they are still needed.
data_size = dict([
    ('train_data', 341311488),
    ('val_data', 81264640),
    ('test_data', 81264640),
])
def cave_mat():
    """Placeholder step for the CAVE dataset; currently does nothing.

    NOTE(review): the name suggests a ``.mat`` conversion step, but no logic
    is implemented yet -- TODO confirm the intended behaviour.
    """
def _npy_path_for(img_name, root):
    """Return the ``.npy`` path mirroring *img_name* under ``<root>/npy/``.

    ``<root>/train/a/x.png`` maps to ``<root>/npy/train/a/x.npy``. The path is
    computed on the path relative to *root* (not by ``str.replace`` on the
    whole string). This avoids rewriting directory names that happen to
    contain the file's basename, and avoids rewriting every ``png`` substring
    of the filename.
    """
    rel = os.path.relpath(img_name, root)
    stem = os.path.splitext(rel)[0]
    return os.path.join(root, 'npy', stem + '.npy')


def write_to_lmdb(limit=5):
    """Write each split's PNG images to a caffe2 minidb and to ``.npy`` files.

    Any existing ``<data_root>/lmdb/`` and ``<data_root>/npy/`` directories are
    deleted and recreated. Then, for every ``(key, directory)`` in
    ``data_info``:

    * ``<data_root>/lmdb/<key>.minidb`` is created. It holds one serialized
      ``caffe2_pb2.TensorProtos`` record per image, keyed ``b'train_000'``,
      ``b'train_001'``, ...
    * each image array is also written with ``write_mat_to_npy`` to a ``.npy``
      file under ``<data_root>/npy/`` that mirrors its relative path.

    Args:
        limit: maximum number of images (in sorted path order) written per
            split. ``None`` writes all of them. The default of 5 keeps the
            previously hard-coded cap.
    """
    lmdb_dir = os.path.join(data_root, 'lmdb')
    npy_dir = os.path.join(data_root, 'npy')
    for out_dir in (lmdb_dir, npy_dir):
        if os.path.exists(out_dir):
            shutil.rmtree(out_dir)
        os.makedirs(out_dir, exist_ok=False)

    for key, value in data_info.items():
        db_path = os.path.join(lmdb_dir, '{}.minidb'.format(key))
        db = core.C.create_db("minidb", db_path, core.C.Mode.write)
        transaction = db.new_transaction()
        try:
            img_names = sorted(glob.glob(value + '**/*.png', recursive=True))
            for num, img_name in enumerate(img_names[:limit]):
                print('Write {} to LMDB'.format(img_name))
                img = np.array(imageio.imread(img_name))

                npy_path = _npy_path_for(img_name, data_root)
                os.makedirs(os.path.dirname(npy_path), exist_ok=True)
                write_mat_to_npy(img, npy_path)

                tensor_proto = caffe2_pb2.TensorProtos()
                tensor_proto.protos.extend([utils.NumpyArrayToCaffe2Tensor(img)])
                # Use %-formatting. ``'train_%03d'.format(num)`` leaves the
                # placeholder unsubstituted, giving every record the same key.
                record_key = ('train_%03d' % num).encode('ascii')
                transaction.put(record_key, tensor_proto.SerializeToString())
        finally:
            # Release the transaction explicitly, and before the DB (the order
            # used in the caffe2 dataset tutorial). Left implicit, the
            # destruction order depends on rebinding/frame teardown and can
            # free the DB first.
            # NOTE(review): presumably the minidb transaction writes through
            # the DB's file handle -- confirm against the caffe2 sources.
            del transaction
            del db
if __name__ == '__main__':
    # Uncomment to (re)build the databases before reading them back.
    # write_to_lmdb()
    # cave_mat()

    # Read one batch back from the training minidb. The path is derived from
    # data_root (instead of a hard-coded copy of it), so it follows the same
    # '<data_root>/lmdb/<key>.minidb' layout write_to_lmdb() produces.
    train_db = os.path.join(data_root, 'lmdb', 'train_data.minidb')
    net_proto = core.Net("example_reader")
    dbreader = net_proto.CreateDB([], "dbreader", db=train_db, db_type="minidb")
    net_proto.TensorProtosDBInput([dbreader], ["X"], batch_size=5)
    workspace.CreateNet(net_proto)
    workspace.RunNet(net_proto.Proto().name)

    # A single RunNet fills "X" once. Fetch it once rather than re-fetching
    # (and re-printing) the same blob repeatedly.
    batch = workspace.FetchBlob("X")
    print(batch.shape)
    print(batch.dtype)
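# caffe2 创建lmdb数据库 -- "caffe2: creating an LMDB database" (apparently a scraped web-page title)
# 最新推荐文章于 2019-08-01 10:23:56 发布 -- "latest recommended article published 2019-08-01 10:23:56" (scraped page footer)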