caffe lmdb数据集写入与读取

单标签写入读取

# -*- coding: utf-8 -*-
import lmdb
import caffe
from matplotlib import pyplot as plt
import numpy as np

def write_lmdb(filename, X, y):
    """Write samples and their integer labels to an LMDB database.

    Each sample ``X[i]`` is converted to a Caffe ``Datum`` with
    ``caffe.io.array_to_datum``, tagged with ``int(y[i])`` as its label,
    serialized, and stored under a zero-padded 10-digit ASCII key.

    Args:
        filename: Path of the LMDB environment to open (created if missing).
        X: Array indexed as ``X[i, :, :, :]``; must expose ``nbytes``.
        y: Sequence of labels; exactly ``len(y)`` records are written.
    """
    N = len(y)
    # LMDB needs an upper bound on the database size; ten times the raw
    # array size leaves headroom for serialization overhead.
    map_size = X.nbytes * 10

    # Using the environment as a context manager closes it after the write
    # transaction has committed, instead of leaking the open handle.
    with lmdb.open(filename, map_size=map_size) as env:
        with env.begin(write=True) as txn:
            for i in range(N):
                datum = caffe.io.array_to_datum(X[i, :, :, :])
                datum.label = int(y[i])
                # Zero-padded keys keep LMDB's lexicographic key order equal
                # to the insertion order.
                key = '{:0>10d}'.format(i).encode('ascii')
                txn.put(key, datum.SerializeToString())

def read_lmdb(filename):
    """Read the last record of an LMDB database written as Caffe Datums.

    Only the final record (in key order) is decoded; this matches the
    original behavior, which iterated over every record but parsed only the
    last value it saw.

    Args:
        filename: Path of the LMDB environment to open read-only.

    Returns:
        A tuple ``(x, y)`` where ``x`` is the array produced by
        ``caffe.io.datum_to_array`` and ``y`` is the datum's ``label``.

    Raises:
        ValueError: If the database contains no records.
    """
    # The context manager closes the environment once reading is done.
    with lmdb.open(filename, readonly=True) as env:
        with env.begin(write=False) as txn:
            cursor = txn.cursor()
            # Jump straight to the last key instead of scanning the whole
            # database; last() returns False when there are no records.
            if not cursor.last():
                raise ValueError(
                    'LMDB database {!r} contains no records'.format(filename))

            datum = caffe.proto.caffe_pb2.Datum()
            datum.ParseFromString(cursor.value())

            x = caffe.io.datum_to_array(datum)
            y = datum.label

            return x, y


def main():
    """Build a synthetic 4-class dataset, write train/test LMDBs, read one back.

    Four groups of ``N`` random samples of shape (3, 32, 32) are generated;
    group ``k`` has its values shifted by ``10 * k`` and is labelled ``k``.
    The shuffled data is split 80/20 into "hbk_lmdb_train" and
    "hbk_lmdb_test".
    """
    N = 1000
    x1 = np.random.randint(1, 10, (N, 3, 32, 32))
    y1 = np.zeros(N, dtype=np.int64)
    x2 = np.random.randint(1, 10, (N, 3, 32, 32)) + 10
    y2 = np.ones(N, dtype=np.int64)
    x3 = np.random.randint(1, 10, (N, 3, 32, 32)) + 20
    y3 = np.ones(N, dtype=np.int64) * 2
    x4 = np.random.randint(1, 10, (N, 3, 32, 32)) + 30
    y4 = np.ones(N, dtype=np.int64) * 3

    X = np.vstack((x1, x2, x3, x4))
    y = np.hstack((y1, y2, y3, y4))

    # Shuffle sample indices so both splits contain a mix of all classes.
    idx = np.arange(len(y))
    np.random.shuffle(idx)

    TRAIN_NUM = int(4 * len(y) / 5)
    train_idx = idx[0:TRAIN_NUM]
    test_idx = idx[TRAIN_NUM:]

    write_lmdb("hbk_lmdb_train", X[train_idx, :, :, :], y[train_idx])
    # The test split must take its samples from the same indices as its
    # labels; previously the training samples were paired with test labels.
    write_lmdb("hbk_lmdb_test", X[test_idx, :, :, :], y[test_idx])

    sample, label = read_lmdb("hbk_lmdb_train")

    print (sample.shape, label)
    print (np.mean(X))



main()


多标签写入

import numpy as np
import lmdb
import caffe

def write_lmdb_data(filename, X):
    """Write every item of ``X`` to an LMDB database as a Caffe Datum.

    No label is attached to the datums; the script in this file calls this
    function once for the sample array and once for the label array, storing
    them in separate databases.

    Args:
        filename: Path of the LMDB environment to open (created if missing).
        X: Array indexed as ``X[i, :, :, :]``; ``X.shape[0]`` records are
            written and ``X.nbytes`` is used to size the database.
    """
    N = X.shape[0]
    # LMDB needs an upper bound on the database size; ten times the raw
    # array size leaves headroom for serialization overhead.
    map_size = X.nbytes * 10

    # Using the environment as a context manager closes it after the write
    # transaction has committed, instead of leaking the open handle.
    with lmdb.open(filename, map_size=map_size) as env:
        with env.begin(write=True) as txn:
            for i in range(N):
                datum = caffe.io.array_to_datum(X[i, :, :, :])
                # Zero-padded keys keep LMDB's lexicographic key order equal
                # to the insertion order.
                key = '{:0>10d}'.format(i).encode('ascii')
                txn.put(key, datum.SerializeToString())




if __name__ == '__main__':
    # Synthetic multi-label dataset: four groups of samples, each with its
    # own pixel offset and its own 10-wide 0/1 label vector.
    SAMPLES_PER_GROUP = 1000
    LABEL_WIDTH = 10

    # (pixel offset, label positions set to 1), in generation order:
    #   group 1 -> 0,0,0,0,...
    #   group 2 -> 0,1,0,1,0,...
    #   group 3 -> 1,0,1,0,0,...
    #   group 4 -> 1,1,1,1,...
    group_specs = (
        (0, ()),
        (10, (1, 3)),
        (20, (0, 2)),
        (30, tuple(range(LABEL_WIDTH))),
    )

    sample_parts = []
    label_parts = []
    for offset, active_positions in group_specs:
        pixels = np.random.randint(1, 10, (SAMPLES_PER_GROUP, 3, 32, 32))
        sample_parts.append(pixels + offset)

        labels = np.zeros((SAMPLES_PER_GROUP, LABEL_WIDTH, 1, 1), dtype=np.int64)
        for position in active_positions:
            labels[:, position, :, :] = 1
        label_parts.append(labels)

    X = np.vstack(sample_parts)
    y = np.vstack(label_parts)

    # Shuffle once and reuse the permutation so samples and labels stay aligned.
    permutation = np.arange(len(y))
    np.random.shuffle(permutation)

    split_at = int(4 * len(y) / 5)
    train_rows = permutation[:split_at]
    test_rows = permutation[split_at:]

    # Samples and labels go into separate databases for each split.
    write_lmdb_data("lmdb_train_data", X[train_rows])
    write_lmdb_data("lmdb_train_label", y[train_rows])
    write_lmdb_data("lmdb_test_data", X[test_rows])
    write_lmdb_data("lmdb_test_label", y[test_rows])
    print (np.mean(X))


在这里插入图片描述

说明:读取多标签数据时报错,该问题尚未解决。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

佐倉

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值