import os
import pickle
import numpy as np
def get_train_val_list(work_path):
    """Collect the image/mask path pairs for every case under *work_path*.

    Each immediate sub-directory of ``work_path`` is treated as one case and
    is assumed to contain ``origin.nii.gz`` (the image) and ``label.nii.gz``
    (the mask). Cases are visited in sorted name order so the result is
    deterministic regardless of filesystem listing order.

    Non-directory entries (stray files such as ``.DS_Store`` or a README) are
    skipped, because appending the file names to them would yield paths that
    cannot exist.

    Args:
        work_path: Root directory whose sub-directories are cases.

    Returns:
        A tuple ``(image_list, mask_list)`` of equal-length lists of paths,
        where ``image_list[i]`` and ``mask_list[i]`` belong to the same case.
        The existence of the ``.nii.gz`` files themselves is not checked.
    """
    image_list = []
    mask_list = []
    for case in sorted(os.listdir(work_path)):
        case_path = os.path.join(work_path, case)
        if not os.path.isdir(case_path):
            continue
        image_list.append(os.path.join(case_path, 'origin.nii.gz'))
        mask_list.append(os.path.join(case_path, 'label.nii.gz'))
    return image_list, mask_list
def write_pkl(save_path, image_list, mask_list, train_ratio=0.8, seed=None):
    """Randomly split image/mask pairs into ``train.pkl`` and ``val.pkl``.

    Both files are written into ``save_path``. Each file is a *stream* of
    pickled records (one ``pickle.dump`` call per sample), and every record
    is a dict ``{'image_path': ..., 'mask_path': ...}``. Read them back by
    calling ``pickle.load`` repeatedly until ``EOFError``.

    Exactly ``round(len(image_list) * train_ratio)`` samples go to the
    training file and the rest go to the validation file, so the default
    gives a true 8:2 split. Within each file, records keep their input order.

    Args:
        save_path: Existing directory to write ``train.pkl``/``val.pkl`` into.
        image_list: Image paths.
        mask_list: Mask paths; ``mask_list[i]`` pairs with ``image_list[i]``.
        train_ratio: Fraction of samples assigned to the training set.
        seed: Optional seed for a reproducible split. When ``None``, NumPy's
            global RNG is used, so ``np.random.seed(...)`` still controls it.

    Raises:
        ValueError: If the two lists differ in length or ``train_ratio`` is
            outside ``[0, 1]``.
    """
    if len(image_list) != len(mask_list):
        raise ValueError(
            f"image_list and mask_list differ in length: "
            f"{len(image_list)} != {len(mask_list)}"
        )
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")

    n_samples = len(image_list)
    n_train = int(round(n_samples * train_ratio))
    rng = np.random if seed is None else np.random.default_rng(seed)
    # Choose the training indices up front, then test membership while
    # iterating in input order so both files preserve the original ordering.
    train_idx = set(rng.permutation(n_samples)[:n_train].tolist())

    train_file = os.path.join(save_path, 'train.pkl')
    val_file = os.path.join(save_path, 'val.pkl')
    with open(train_file, "wb") as train_pkl, open(val_file, "wb") as val_pkl:
        for i, (image_path, mask_path) in enumerate(zip(image_list, mask_list)):
            # Build a fresh dict per record instead of mutating a shared one.
            record = {'image_path': image_path, 'mask_path': mask_path}
            pickle.dump(record, train_pkl if i in train_idx else val_pkl)
def main(root_path='/data/xx/nii', save_path='/home/yy'):
    """Build the train/val pickle lists for the dataset under *root_path*.

    The default paths look like site-specific placeholders (TODO confirm);
    pass real locations as arguments instead of editing them here.

    Args:
        root_path: Directory passed to ``get_train_val_list``; each of its
            sub-directories is treated as one case.
        save_path: Directory that receives ``train.pkl`` and ``val.pkl``.
            It is created if it does not already exist.
    """
    os.makedirs(save_path, exist_ok=True)
    image_list, mask_list = get_train_val_list(root_path)
    write_pkl(save_path, image_list, mask_list)


if __name__ == "__main__":
    main()
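# Python 随机划分训练和验证集 保存成pkl
# (Python: randomly split data into training and validation sets and save them as pkl)
# 最新推荐文章于 2024-05-15 14:32:39 发布
# (Latest recommended article published 2024-05-15 14:32:39)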