在keras中批量训练tfrecord数据(使用深度学习训练DMRI/DTI数据)

2 篇文章 0 订阅
2 篇文章 0 订阅

环境:pycharm
框架:tensorflow、keras
首先,我们都知道keras官网中有给出一个如何在keras使用tfrecord格式的mnist数据的实例(https://keras.io/examples/mnist_tfrecord/)。但是上面的代码并没有给出具体如何读取tfrecord文件,仅用一行read_data_sets()带过。其实,如果数据本身就是较小的图片格式,并不需要写进tfrecord中,那么就可以参考这位朋友的方法:

https://baijiahao.baidu.com/s?id=1628460932421002169&wfr=spider&for=pc 

笔者的数据是来自于弥散核磁共振成像预处理后生成的FA文件,是一个三维的以.nii.gz格式存储的脑图像,大小是145 * 170 * 145 * 1,如果用上面的方法,在读取数据中都将花费大量的时间。于是,笔者稍微对此修改了一下。

读写tfrecord部分,如果是其他数据,可忽略下面两段代码。

以下代码是在python中读取FA数据,并将其写入tfrecord。

savePath = '/home/wenjingxi/MRI/tfrecord'
dataPathList = glob('/media//MyFA/*dti_FA.nii.gz')
#文件缩小比例,可以选择不缩小,缩小原因是数据太大难以训练。
zoom_rate = 0.4		
if not os.path.isdir(savePath):
	os.makedirs(savePath)
#读取量表
def read_csv(filePath):
	csv_file = csv.reader(open(filePath))
	l = {}
	for r in csv_file:
    	l[r[0]] = list(map(int, r[6:12]))
	return l
t1 = time.time()
print('zoom rate: {}'.format(zoom_rate))

random.shuffle(dataPathList)
dataPathList_1 = dataPathList[0:len(dataPathList)]

labels = read_csv('label.csv')
for i in range(len(dataPathList_1)):
	savePath_t = os.path.join(savePath, 'dataset_{}.tfrecord'.format(i))
	writer = tf.python_io.TFRecordWriter(savePath_t)
	p_fa = dataPathList_1[i]
	data_fa, affine_fa = load_nifti(p_fa)
	data = data_fa
	print('data shape:{}'.format(data.shape))
	data = nd.interpolation.zoom(data, zoom_rate, prefilter=False)
	print("data shape:{}".format(data.shape))
	m = re.search('[0-9]{6}', p_fa)
	seq = m.group()
	print('seq:' + seq)
	print(labels[seq])
	img_raw = data.tobytes()
	print('cut last 4 num:', labels[seq])
	example = tf.train.Example(features=tf.train.Features(feature={"label": tf.train.Feature(
		int64_list=tf.train.Int64List(value=labels[seq])),'img_raw': tf.train.Feature(
		bytes_list=tf.train.BytesList(value=[img_raw]))
	}))
	n = writer.write(example.SerializeToString())
	print('No {} finish time cost:{} min'.format(i, (time.time() - t1) // 60))
	writer.close()

以下为读取tfrecord部分,如果是其他数据,可以采用其他简便的方式,

def _parse_function_60(example_proto):
	features = {"label": tf.FixedLenFeature([], tf.int64),"img_raw": tf.FixedLenFeature([], tf.string)}
	parsed_features = tf.parse_single_example(example_proto, features)
	img = tf.decode_raw(parsed_features['img_raw'], tf.float32)
	img = tf.reshape(img, [58, 70, 58, 1])
	img = tf.cast(img, tf.float32)
	print(parsed_features['label'])
	print('img shape~~~~~~~~~~~~~~~~:{}'.format(img.get_shape()))
	label = tf.cast(parsed_features['label'], tf.int64)
	print(label)
	label=tf.reshape(label,[1,2])
	return img, label

def load_data(sess,filename,batch_size,zoom_rate,shuffle_buffer=None):
	dataset = tf.data.TFRecordDataset(filename)
	if zoom_rate==60:
    	_parse_function=_parse_function_60
	dataset = dataset.map(_parse_function)
	dataset = dataset.repeat()
	dataset = dataset.batch(batch_size)
	iterator = dataset.make_initializable_iterator()
	next_batch = iterator.get_next()
	sess.run(iterator.initializer)
	return next_batch

def load_data_with_val(sess,batch_size,zoom_rate,shuffle_buffer=None,cross=0,brain_area=None,modal='flirt'):
	data_dir = '/media/tfrecords/3_flirt/0.4zoom_rate'
	file_list = []
	for i in range(3):
    	p = os.path.join(data_dir, 'dataset_{}.tfrecord'.format(i))
    	file_list.append(p)
	print(file_list)
	val_file=file_list[1]
	test_file=file_list[0]
	train_files=[file for file in file_list if file!=val_file and file!=test_file]
	next_batch_t=load_data(sess, filename=train_files, batch_size=batch_size, zoom_rate=zoom_rate, shuffle_buffer=shuffle_buffer)
	next_batch_v = load_data(sess, filename=val_file, batch_size=batch_size, zoom_rate=zoom_rate, shuffle_buffer=shuffle_buffer)
	next_batch_test = load_data(sess, filename=test_file, batch_size=batch_size, zoom_rate=zoom_rate, shuffle_buffer=shuffle_buffer)
	return next_batch_t,next_batch_v,next_batch_test

从这里开始的sess和读取tfrecord的sess是同一个,只不过为了批量读取,将dataset的tensor放到My_Costom_Generator中run了。

模型训练,将training_set和val_set两个tensor传给train()函数,用fit_generator()函数从重载函数My_Custom_Generator()中获取数据,要记得把session传过去。

def train(sess, training_set, val_set):
   	my_training_batch_generator = My_Custom_Generator(sess, training_set, self.batch_size, 855)
   	my_validation_batch_generator = My_Custom_Generator(sess, val_set, self.batch_size, 100)
   	model.fit_generator(generator=my_training_batch_generator,
                            steps_per_epoch=15,
                            epochs=10,
                            verbose=1,
                            validation_data=my_validation_batch_generator,
                            validation_steps = 3)

重载My_Custom_Generator(),用sess.run()将上面的training_set/val_set读出来。

class My_Custom_Generator(keras.utils.Sequence):
def __init__(self, sess, data_set, batch_size, dataset_size):
    self.sess = sess
    self.data_set = data_set
    self.batch_size = batch_size
    self.dataset_size = dataset_size
def __len__(self):
    return (np.ceil(self.dataset_size / float(self.batch_size))).astype(np.int)
def __getitem__(self, idx):
    data, label= self.sess.run(self.data_set)
    label = np.squeeze(label)
    return data, label

本文采用的是3D数据,在缩小到0.4倍后再放入模型的,模型部分与本文无关,读者可根据自己的需求自行编写。

  • 4
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 24
    评论
评论 24
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值