创建数据集的代码:createmydata
1.第一段
# Silence every warning for the rest of the run: this inserts a catch-all
# "ignore" entry at the front of the warnings filter list.
# NOTE(review): this also hides NumPy/SimpleITK warnings; consider narrowing it.
warnings.filterwarnings("ignore")
# Filename suffixes of the four MRI modalities and of the segmentation label.
# Each case folder holds files named "<case_id><suffix>".
flair_name, t1_name, t1ce_name, t2_name, mask_name = (
    "_flair.nii.gz",
    "_t1.nii.gz",
    "_t1ce.nii.gz",
    "_t2.nii.gz",
    "_seg.nii.gz",
)
# Input: one sub-folder per HGG case. Output: one folder per modality / mask type.
root_dir = "./BraTs_2019/BraTS2019_Training/HGG/"
out_root = "./data/BraTS_slice/"
(
    outputFlair_path,
    outputT1_path,
    outputT2_path,
    outputT1ce_path,
    outputMaskWT_path,
    outputMaskAll_path,
) = (out_root + sub_dir for sub_dir in (
    "imgs_flair", "imgs_t1", "imgs_t2", "imgs_t1ce", "masks", "masks_all"))
# Create every output folder up front (missing parent folders included).
# exist_ok=True makes this idempotent and avoids the check-then-create race
# of os.path.exists() followed by os.makedirs().
for _out_dir in (
    outputFlair_path,
    outputT1_path,
    outputT2_path,
    outputT1ce_path,
    outputMaskWT_path,
    outputMaskAll_path,
):
    os.makedirs(_out_dir, exist_ok=True)
第一段解析
# Silence every warning for the rest of the run: this inserts a catch-all
# "ignore" entry at the front of the warnings filter list.
warnings.filterwarnings("ignore")
这行代码的作用是在整个脚本运行期间屏蔽所有警告信息的输出
接下来简单剖析一下 warnings.filterwarnings 是做什么的,它的函数签名如下:
# Signature of the standard-library warnings.filterwarnings, quoted for
# reference (this script only calls it; it does not define it).
def filterwarnings(action, message="", category=Warning, module="", lineno=0, append=False):
它的主要作用是向警告过滤器列表中插入一个条目(默认插入到列表的最前面)。
其中:
“action”——取值为 "error"、"ignore"、"always"、"default"、"module" 或 "once" 之一
“message”——一个正则表达式,警告消息的开头必须与之匹配(不区分大小写)
“category”——一个类,警告的类别必须是它的子类
“module”——一个正则表达式,发出警告的模块名必须与之匹配
“lineno”——整数行号,为 0 时匹配所有行号
“append”——为 True 时把条目追加到过滤器列表末尾,而不是插到最前面
# Filename suffixes of the four MRI modalities and of the segmentation label.
# Each case folder holds files named "<case_id><suffix>".
flair_name, t1_name, t1ce_name, t2_name, mask_name = (
    "_flair.nii.gz",
    "_t1.nii.gz",
    "_t1ce.nii.gz",
    "_t2.nii.gz",
    "_seg.nii.gz",
)
xxx_name 是各模态文件名的后缀(形如 _<模态>.nii.gz,mask_name 对应分割标签文件 _seg.nii.gz),与病例文件夹名拼接后得到每个输入文件的路径
# Input: one sub-folder per HGG case. Output: one folder per modality / mask type.
root_dir = "./BraTs_2019/BraTS2019_Training/HGG/"
out_root = "./data/BraTS_slice/"
(
    outputFlair_path,
    outputT1_path,
    outputT2_path,
    outputT1ce_path,
    outputMaskWT_path,
    outputMaskAll_path,
) = (out_root + sub_dir for sub_dir in (
    "imgs_flair", "imgs_t1", "imgs_t2", "imgs_t1ce", "masks", "masks_all"))
root_dir 是原始数据(HGG 病例)所在的目录
out_root 是输出数据的根目录,后面几行则是各模态切片以及两种 mask 各自的输出目录
# Create every output folder up front (missing parent folders included).
# exist_ok=True makes this idempotent and avoids the check-then-create race
# of os.path.exists() followed by os.makedirs().
for _out_dir in (
    outputFlair_path,
    outputT1_path,
    outputT2_path,
    outputT1ce_path,
    outputMaskWT_path,
    outputMaskAll_path,
):
    os.makedirs(_out_dir, exist_ok=True)
如果输出目录不存在则创建(os.makedirs 会连同缺失的上级目录一起创建)
2.第二段
def file_name_path(file_dir):
    """
    Collect the immediate sub-folders of ``file_dir`` and every file beneath it.

    :param file_dir: root folder to scan (the script passes ``root_dir``).
    :return: ``(name_list, path_list)``. ``name_list`` holds the names of the
        folders directly inside ``file_dir`` (one per case here), in the order
        ``os.walk`` lists them. ``path_list`` holds the joined path of every
        file found at any depth. Both are empty if ``file_dir`` does not exist.
    """
    path_list = []
    name_list = []
    for step, (root, dirs, files) in enumerate(os.walk(file_dir)):
        # os.walk is top-down by default, so the first tuple describes
        # file_dir itself. The original test ``len(dirs) and dir`` used the
        # builtin ``dir`` (always truthy), which let any nested folder
        # overwrite the case list; only the top level is taken now.
        if step == 0:
            name_list = list(dirs)
        path_list.extend(os.path.join(root, f) for f in files)
    return name_list, path_list
# Collect the case folders under root_dir. Only HGG cases are used, so the
# full case list is the HGG list itself.
train_hgg_list, train_hgg_path_list = file_name_path(root_dir)
all_list = train_hgg_list
print("train_hgg_list:", len(train_hgg_list), train_hgg_list, '\n')
print("\n all_list:", len(all_list), all_list)
第二段解析
# Header of file_name_path; its body is quoted piece by piece below.
def file_name_path(file_dir):
定义 file_name_path 函数,参数 file_dir 是要遍历的根目录(调用时传入 root_dir)
path_list = []  # full path of every file found under file_dir
name_list = []  # names of the sub-folders (the case folders)
定义两个空列表:path_list 用来存放文件的完整路径,name_list 用来存放子文件夹名
# Walk the tree top-down; each step yields (current folder, its sub-folder names, its file names).
for root, dirs, files in os.walk(file_dir):
    # NOTE(review): ``dir`` is the builtin function and is always truthy, so
    # only ``len(dirs)`` matters. name_list is replaced by the sub-folders of
    # every directory that has any, not only of file_dir itself.
    if len(dirs) and dir:
        name_list = dirs
    # Record the full path of every file at this level.
    for f in files:
        path = os.path.join(root, f)
        path_list.append(path)
return name_list,path_list
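os.walk 会自顶向下遍历目录树,每一步返回一个三元组:(当前目录路径, 该目录下的子文件夹列表, 该目录下的文件列表)
因此这个函数的作用是获取原始数据根目录下的子文件夹(病例)列表,以及所有文件完整路径组成的列表(后者在后续代码中没有用到)。注意:if 条件里的 dir 是 Python 内置函数,恒为真,所以实际只判断了 len(dirs)。对 BraTS 这种“根目录/病例文件夹/文件”的两层结构,得到的正是根目录下的病例文件夹;但如果病例文件夹里还有子文件夹,name_list 就会被覆盖。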
# Collect the case folders under root_dir. Only HGG cases are used, so the
# full case list is the HGG list itself.
train_hgg_list, train_hgg_path_list = file_name_path(root_dir)
all_list = train_hgg_list
print("train_hgg_list:", len(train_hgg_list), train_hgg_list, '\n')
print("\n all_list:", len(all_list), all_list)
train_hgg_list 就是原始数据根目录下的所有子文件夹(每个病例一个)
all_list 和它是同一个列表(这里只使用了 HGG 数据)
举个例子:第一行打印的输出形如(文件夹名仅为示意)
train_hgg_list: 10 ['121416', '120717', '120515', '120414', '120212', '120111', '120010', '119833', '119732', '119126']
3.第三段
def show_max_min(img_data):
    """Print the minimum, the maximum and the dtype of ``img_data``.

    :param img_data: NumPy array (it must have a ``dtype`` attribute).
    """
    # Non-builtin names: the original bound ``min``/``max`` locally and
    # shadowed the builtins inside this function.
    min_v = np.min(img_data)
    max_v = np.max(img_data)
    print('min is {}'.format(min_v))
    print('max is {}'.format(max_v))
    print('dtype is {}'.format(img_data.dtype))
def show_info(np_data=None, name="nparry"):
    """Print the minimum and maximum of ``np_data`` with four decimals.

    :param np_data: array-like of numbers.
    :param name: label printed at the start of each line.
    :raises ValueError: if ``np_data`` is omitted, ``None`` or empty.
    """
    # ``None`` replaces the mutable ``[]`` default. Calling without data was
    # already an error (np.min of an empty list raises ValueError); now the
    # message says which array was empty.
    if np_data is None or np.size(np_data) == 0:
        raise ValueError("show_info: '{}' is empty".format(name))
    min_v = np.min(np_data)
    max_v = np.max(np_data)
    print('{} min value is {:.4f}'.format(name, min_v))
    print('{} max value is {:.4f}'.format(name, max_v))
def normalize(slice):
    """Clip to the 1st..99th percentile range, then z-score with non-zero statistics.

    :param slice: numeric array (the script passes whole 3-D volumes).
    :return: ``(clipped - mean) / std``, where mean and std come from the
        non-zero entries of the clipped array. Returns the clipped array
        itself when the clipped array, or its non-zero entries, have zero spread.
    """
    low = np.percentile(slice, 1)
    high = np.percentile(slice, 99)
    clipped = np.clip(slice, low, high)
    foreground = clipped[clipped != 0]
    # The short-circuit skips np.std on the foreground when the whole array is constant.
    flat = np.std(clipped) == 0 or np.std(foreground) == 0
    if flat:
        return clipped
    center, spread = np.mean(foreground), np.std(foreground)
    return (clipped - center) / spread
def move(mask, croph):
    """Return a row offset for the centred crop window of height ``croph``.

    ``mask`` is indexed ``mask[:, row, :]`` and ``mask[0].shape`` is
    ``(height, width)``; presumably a ``(slices, H, W)`` label volume — TODO confirm.
    The result is passed to ``crop_ceter`` as ``move_value``: positive moves
    the window towards larger row indices, negative moves it towards row 0.
    """
    probe = 0
    height, width = mask[0].shape
    # Downward search: stop at the first shift for which row
    # height//2 + croph//2 + probe has no label in any slice or column.
    # For even croph this is the row just below the shifted window.
    # If no row qualifies, probe keeps the last value of the range.
    for probe in range(height // 2 - (croph // 2)):
        bottom = height // 2 + (croph // 2) + probe
        if np.max(mask[:, bottom, :]) == 0:
            break
    # Upward search, only when no downward shift was chosen.
    if probe == 0:
        for probe in range(height // 2 - (croph // 2)):
            # Top row of the window when it is shifted up by probe rows.
            up = height // 2 - (croph // 2) - probe
            # Stop when that top row has no label, or when row up + croph
            # (just below the shifted window) has a maximum label of exactly 1.
            # NOTE(review): "== 1" does not fire for rows whose maximum label is
            # 2 or 4; "!= 0" may have been intended — confirm before changing.
            if np.max(mask[:, up, :]) == 0 or np.max(mask[:, up + croph, :]) == 1:
                probe = 0 - probe
                break
        # NOTE(review): if the upward loop never breaks, probe keeps its last
        # (positive) value, so the window is shifted DOWN, not up.
    return probe
def crop_ceter(img, croph, cropw, move_value=0):
    """Crop a ``croph`` x ``cropw`` window around the centre of every 2-D slice.

    :param img: 3-D array indexed ``img[k, row, col]``; ``img[0].shape`` gives
        ``(height, width)``.
    :param croph: window height in rows.
    :param cropw: window width in columns.
    :param move_value: extra row offset of the window. Positive moves it
        towards larger row indices and negative towards row 0 (e.g. the
        output of ``move``).
    :return: ``img[:, starth:starth + croph, startw:startw + cropw]`` (a view).
    :raises ValueError: if the window does not fit inside the image. Without
        this check, NumPy slicing would silently return a smaller crop, or
        count a negative start from the end of the axis.
    """
    height, width = img[0].shape
    starth = height // 2 - (croph // 2) + move_value
    startw = width // 2 - (cropw // 2)
    if starth < 0 or startw < 0 or starth + croph > height or startw + cropw > width:
        raise ValueError(
            "crop {}x{} (row offset {}) does not fit in image of size {}x{}".format(
                croph, cropw, move_value, height, width))
    return img[:, starth:starth + croph, startw:startw + cropw]
第三段解析
def show_max_min(img_data):
    """Print the minimum, the maximum and the dtype of ``img_data``.

    :param img_data: NumPy array (it must have a ``dtype`` attribute).
    """
    # Non-builtin names: the original bound ``min``/``max`` locally and
    # shadowed the builtins inside this function.
    min_v = np.min(img_data)
    max_v = np.max(img_data)
    print('min is {}'.format(min_v))
    print('max is {}'.format(max_v))
    print('dtype is {}'.format(img_data.dtype))
获得数据的最大值、最小值和数据类型并打印
def show_info(np_data=None, name="nparry"):
    """Print the minimum and maximum of ``np_data`` with four decimals.

    :param np_data: array-like of numbers.
    :param name: label printed at the start of each line.
    :raises ValueError: if ``np_data`` is omitted, ``None`` or empty.
    """
    # ``None`` replaces the mutable ``[]`` default. Calling without data was
    # already an error (np.min of an empty list raises ValueError); now the
    # message says which array was empty.
    if np_data is None or np.size(np_data) == 0:
        raise ValueError("show_info: '{}' is empty".format(name))
    min_v = np.min(np_data)
    max_v = np.max(np_data)
    print('{} min value is {:.4f}'.format(name, min_v))
    print('{} max value is {:.4f}'.format(name, max_v))
获得数据中的最大值和最小值并打印
def normalize(slice):
    """Clip to the 1st..99th percentile range, then z-score with non-zero statistics.

    :param slice: numeric array (the script passes whole 3-D volumes).
    :return: ``(clipped - mean) / std``, where mean and std come from the
        non-zero entries of the clipped array. Returns the clipped array
        itself when the clipped array, or its non-zero entries, have zero spread.
    """
    low = np.percentile(slice, 1)
    high = np.percentile(slice, 99)
    clipped = np.clip(slice, low, high)
    foreground = clipped[clipped != 0]
    # The short-circuit skips np.std on the foreground when the whole array is constant.
    flat = np.std(clipped) == 0 or np.std(foreground) == 0
    if flat:
        return clipped
    center, spread = np.mean(foreground), np.std(foreground)
    return (clipped - center) / spread
np.percentile(a, q) 计算数组 a 的第 q 百分位数(q 取值 0~100):q=0 时就是最小值,q=100 时就是最大值。这里取第 1 和第 99 百分位数,用来去掉两端约 1% 的极端值
np.clip 把数组的值限制在给定的最小值和最大值之间,举个例子:
# Example: np.clip replaces values below 2 with 2 and values above 4 with 4.
a = np.array([1, 2, 3, 4, 5])
print(np.clip(a, 2, 4))
输出结果为:
[2 2 3 4 4]
上述代码首先使用 np.percentile 计算输入数组的第 99 和第 1 百分位数,分别保存在变量 b 和 t 中,然后使用 np.clip 将数组中的值限制在 t 和 b 之间。注意:在第四段中这个函数是对整个三维体数据调用的,而不是对单张切片。
使用 np.nonzero 获取非零元素的索引,并把这些非零元素取出保存在 image_nonzero 中。
接着用 np.std 检查裁剪后数组的标准差是否为零,以及 image_nonzero 的标准差是否为零。只要其中任意一个为零,就直接返回经过 clip 后的数组(避免除以零)。否则做 z-score 标准化后再返回:用数组减去 image_nonzero 的均值,再除以 image_nonzero 的标准差。由于均值和标准差只用非零(前景)体素计算,背景的 0 值在标准化后会变成同一个常数(通常为负数)。
def move(mask, croph):
    """Return a row offset for the centred crop window of height ``croph``.

    ``mask`` is indexed ``mask[:, row, :]`` and ``mask[0].shape`` is
    ``(height, width)``; presumably a ``(slices, H, W)`` label volume — TODO confirm.
    The result is passed to ``crop_ceter`` as ``move_value``: positive moves
    the window towards larger row indices, negative moves it towards row 0.
    """
    probe = 0
    height, width = mask[0].shape
    # Downward search: stop at the first shift for which row
    # height//2 + croph//2 + probe has no label in any slice or column.
    # For even croph this is the row just below the shifted window.
    # If no row qualifies, probe keeps the last value of the range.
    for probe in range(height // 2 - (croph // 2)):
        bottom = height // 2 + (croph // 2) + probe
        if np.max(mask[:, bottom, :]) == 0:
            break
    # Upward search, only when no downward shift was chosen.
    if probe == 0:
        for probe in range(height // 2 - (croph // 2)):
            # Top row of the window when it is shifted up by probe rows.
            up = height // 2 - (croph // 2) - probe
            # Stop when that top row has no label, or when row up + croph
            # (just below the shifted window) has a maximum label of exactly 1.
            # NOTE(review): "== 1" does not fire for rows whose maximum label is
            # 2 or 4; "!= 0" may have been intended — confirm before changing.
            if np.max(mask[:, up, :]) == 0 or np.max(mask[:, up + croph, :]) == 1:
                probe = 0 - probe
                break
        # NOTE(review): if the upward loop never breaks, probe keeps its last
        # (positive) value, so the window is shifted DOWN, not up.
    return probe
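'//' 是地板除(整除),即相除后向下取整
move 用来计算裁剪窗口在高度方向(行方向)上相对居中位置的偏移量 probe,结果作为 crop_ceter 的 move_value 使用:正数表示窗口往行号大的方向(向下)移,负数表示往行号小的方向(向上)移。

第一个循环让窗口逐行向下移。当第 height//2 + croph//2 + probe 行(croph 为偶数时,就是窗口下边界外紧挨着的那一行)在所有层、所有列上都没有标签时停止;如果一直找不到,probe 停在循环的最后一个值。

只有当 probe 为 0(不需要下移)时,才执行第二个循环:让窗口逐行向上移。当窗口最上面一行没有标签,或者窗口下边界外紧挨着的那一行的最大标签值恰好等于 1 时停止,并返回负的上移量。

需要注意两点:一是条件写的是 == 1,最大标签值为 2 或 4 的行不会触发停止;二是如果第二个循环一直没有 break,probe 会保持为正数,窗口反而会向下移。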
def crop_ceter(img, croph, cropw, move_value=0):
    """Crop a ``croph`` x ``cropw`` window around the centre of every 2-D slice.

    :param img: 3-D array indexed ``img[k, row, col]``; ``img[0].shape`` gives
        ``(height, width)``.
    :param croph: window height in rows.
    :param cropw: window width in columns.
    :param move_value: extra row offset of the window. Positive moves it
        towards larger row indices and negative towards row 0 (e.g. the
        output of ``move``).
    :return: ``img[:, starth:starth + croph, startw:startw + cropw]`` (a view).
    :raises ValueError: if the window does not fit inside the image. Without
        this check, NumPy slicing would silently return a smaller crop, or
        count a negative start from the end of the axis.
    """
    height, width = img[0].shape
    starth = height // 2 - (croph // 2) + move_value
    startw = width // 2 - (cropw // 2)
    if starth < 0 or startw < 0 or starth + croph > height or startw + cropw > width:
        raise ValueError(
            "crop {}x{} (row offset {}) does not fit in image of size {}x{}".format(
                croph, cropw, move_value, height, width))
    return img[:, starth:starth + croph, startw:startw + cropw]
crop_ceter 用于裁剪:在每一层上以图像中心为基准取出 croph×cropw 的区域。高度方向的起始行再加上 move_value 的偏移,宽度方向始终居中。
4.第四段
# Slice every case into 2-D .npy files: the four normalised, cropped
# modalities, a binary mask, and a 3-channel (WT, TC, ET) mask.
crop_size = 160  # height and width of every saved slice

# (key, input filename suffix, output folder) for each MRI modality.
modalities = (
    ("flair", flair_name, outputFlair_path),
    ("t1", t1_name, outputT1_path),
    ("t1ce", t1ce_name, outputT1ce_path),
    ("t2", t2_name, outputT2_path),
)


def _three_region_label(mask_np):
    """Stack the WT, TC and ET channels for one 2-D slice of label values.

    Label 1 maps to (1, 1, 0), label 2 to (1, 0, 0) and label 4 to (1, 1, 1).
    Every other value (including background 0) is copied unchanged into all
    three channels. The result is ``uint8`` with shape ``(3,) + mask_np.shape``,
    so it follows the crop size instead of a hard-coded 160x160 buffer.
    """
    wt = mask_np.copy()
    wt[np.isin(mask_np, (1, 2, 4))] = 1
    tc = mask_np.copy()
    tc[np.isin(mask_np, (1, 4))] = 1
    tc[mask_np == 2] = 0
    et = mask_np.copy()
    et[np.isin(mask_np, (1, 2))] = 0
    et[mask_np == 4] = 1
    return np.stack((wt, tc, et)).astype(np.uint8)


for subsetindex, case_id in enumerate(all_list):
    # Input files live at <root_dir>/<case_id>/<case_id><suffix>.
    case_prefix = os.path.join(root_dir, case_id, case_id)
    print(subsetindex)
    # Modalities are read as int16 and normalised per volume; labels are read as uint8.
    volumes = {
        key: normalize(sitk.GetArrayFromImage(sitk.ReadImage(case_prefix + suffix, sitk.sitkInt16)))
        for key, suffix, _ in modalities
    }
    mask_array = sitk.GetArrayFromImage(sitk.ReadImage(case_prefix + mask_name, sitk.sitkUInt8))

    # Row offset of the crop window, chosen from the label volume.
    move_value = move(mask_array, crop_size)
    if move_value != 0:
        print(case_id)
        print("move value: ", move_value)
    crops = {
        key: crop_ceter(volume, crop_size, crop_size, move_value)
        for key, volume in volumes.items()
    }
    mask_crop = crop_ceter(mask_array, crop_size, crop_size, move_value)

    for n_slice, mask_np in enumerate(mask_crop):
        # Binary mask: every label greater than 1 becomes 1.
        mask_np_wt = mask_np.copy()
        mask_np_wt[mask_np_wt > 1] = 1
        all_label3 = _three_region_label(mask_np)
        # 1-based slice number zero-padded to three digits, e.g. "<case>_007.npy".
        file_name = "{}_{:03d}.npy".format(case_id, n_slice + 1)
        for key, _, out_dir in modalities:
            np.save(os.path.join(out_dir, file_name), crops[key][n_slice].astype(np.float32))
        np.save(os.path.join(outputMaskWT_path, file_name), mask_np_wt)
        np.save(os.path.join(outputMaskAll_path, file_name), all_label3)
print("Done!")
第四段解析
# One iteration per case folder listed in all_list.
for subsetindex in range(len(all_list)):
遍历原始数据中的每个病例文件夹(subsetindex 是病例在 all_list 中的序号)
# Full paths of this case's four modality files and its segmentation label,
# built as <root_dir><case>/<case><suffix>.
flair_image = root_dir + all_list[subsetindex] + '/' + all_list[subsetindex] + flair_name
t1_image = root_dir + all_list[subsetindex] + '/' + all_list[subsetindex] + t1_name
t2_image = root_dir + all_list[subsetindex] + '/' + all_list[subsetindex] + t2_name
t1ce_image = root_dir + all_list[subsetindex] + '/' + all_list[subsetindex] + t1ce_name
mask_image = root_dir + all_list[subsetindex] + '/' + all_list[subsetindex] + mask_name
print(subsetindex)  # progress: index of the case being processed
xxx_image 是各模态以及分割标签(mask)原始文件的路径,格式为 root_dir/病例名/病例名+后缀;print(subsetindex) 用来打印处理进度
# Read each file with an explicit pixel type: modalities as int16, labels as uint8.
flair_src = sitk.ReadImage(flair_image, sitk.sitkInt16)
t1_src = sitk.ReadImage(t1_image, sitk.sitkInt16)
t1ce_src = sitk.ReadImage(t1ce_image, sitk.sitkInt16)
t2_src = sitk.ReadImage(t2_image, sitk.sitkInt16)
mask = sitk.ReadImage(mask_image, sitk.sitkUInt8)
# Convert the SimpleITK images to NumPy arrays (index order z, y, x).
flair_array = sitk.GetArrayFromImage(flair_src)
t1_array = sitk.GetArrayFromImage(t1_src)
t1ce_array = sitk.GetArrayFromImage(t1ce_src)
t2_array = sitk.GetArrayFromImage(t2_src)
mask_array = sitk.GetArrayFromImage(mask)
xxx_src 是 SimpleITK 读入的图像对象(四个模态按 int16 读取,mask 按 uint8 读取),xxx_array 是转换得到的 NumPy 数组,轴的顺序为 (z, y, x)
# 3. normalization
# Percentile clipping + z-score, applied to each whole 3-D volume (see normalize).
flair_array_nor = normalize(flair_array)
t1_array_nor = normalize(t1_array)
t1ce_array_nor = normalize(t1ce_array)
t2_array_nor = normalize(t2_array)
xxx_array_nor 是经 normalize 处理(百分位裁剪 + z-score 标准化)后的数组,注意这里是对整个三维体数据一起计算的
# 4. cropping
# Row offset of the 160-row crop window, chosen from the label volume (see move).
move_value = move(mask_array, 160)
# NOTE(review): will_up is assigned but never read in the visible code.
will_up = False
if move_value != 0:
    # Report the cases whose crop window had to be shifted.
    print((all_list[subsetindex]))
    print("move value: ", move_value)
    will_up = True
# Crop every modality and the mask with the same 160x160 window.
flair_crop = crop_ceter(flair_array_nor, 160, 160, move_value)
t1_crop = crop_ceter(t1_array_nor, 160, 160, move_value)
t1ce_crop = crop_ceter(t1ce_array_nor, 160, 160, move_value)
t2_crop = crop_ceter(t2_array_nor, 160, 160, move_value)
mask_crop = crop_ceter(mask_array, 160, 160, move_value)
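先用 move 根据 mask 计算高度方向的偏移量 move_value(不为 0 时打印病例名和偏移量),再用 crop_ceter 把各模态和 mask 都裁剪为每层 160×160 的区域,xxx_crop 即裁剪结果。will_up 只被赋值,后续没有用到。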
# Export every 2-D slice of the cropped volumes.
for n_slice in range(mask_crop.shape[0]):
    mask_np = mask_crop[n_slice, :, :]
    mask_np_wt = mask_np.copy()   # single-channel (binary) mask
    mask_np_all = mask_np.copy()  # source for the 3-channel mask
    # NOTE(review): the 160x160 size is hard-coded and must match the crop size.
    all_label3 = np.empty((3, 160, 160), np.uint8)
    if np.max(mask_np) != 0:
        # for one class
        # Every label greater than 1 becomes 1, so all labelled pixels are 1.
        mask_np_wt[mask_np_wt > 1] = 1
        # for three classes
        # WT channel: labels 1, 2, 4 -> 1.
        WT_Label = mask_np_all.copy()
        WT_Label[mask_np_all == 1] = 1.
        WT_Label[mask_np_all == 2] = 1.
        WT_Label[mask_np_all == 4] = 1.
        # TC channel: labels 1, 4 -> 1; label 2 -> 0.
        TC_Label = mask_np_all.copy()
        TC_Label[mask_np_all == 1] = 1.
        TC_Label[mask_np_all == 2] = 0.
        TC_Label[mask_np_all == 4] = 1.
        # ET channel: label 4 -> 1; labels 1, 2 -> 0.
        ET_Label = mask_np_all.copy()
        ET_Label[mask_np_all == 1] = 0.
        ET_Label[mask_np_all == 2] = 0.
        ET_Label[mask_np_all == 4] = 1.
        all_label3[0, :, :] = WT_Label
        all_label3[1, :, :] = TC_Label
        all_label3[2, :, :] = ET_Label
    else:
        # Label-free slice: all three channels are copies of the all-zero mask.
        all_label3[0, :, :] = mask_np_all
        all_label3[1, :, :] = mask_np_all
        all_label3[2, :, :] = mask_np_all
    # flair
    flair_np = flair_crop[n_slice, :, :]
    flair_np = flair_np.astype(np.float32)
    # t1
    t1_np = t1_crop[n_slice, :, :]
    t1_np = t1_np.astype(np.float32)
    # t2
    t2_np = t2_crop[n_slice, :, :]
    t2_np = t2_np.astype(np.float32)
    # t1ce
    t1ce_np = t1ce_crop[n_slice, :, :]
    t1ce_np = t1ce_np.astype(np.float32)
    # 1-based slice number, zero-padded to three digits for the file names.
    slice_num = n_slice + 1
    if len(str(slice_num)) == 1:
        new_slice_num = '00' + str(slice_num)
    elif len(str(slice_num)) == 2:
        new_slice_num = '0' + str(slice_num)
    else:
        new_slice_num = str(slice_num)
    # Output files: <output folder>/<case>_<slice>.npy
    flair_imagepath = outputFlair_path + "/" + (all_list[subsetindex]) + "_" + str(new_slice_num) + ".npy"
    t1_imagepath = outputT1_path + "/" + (all_list[subsetindex]) + "_" + str(new_slice_num) + ".npy"
    t2_imagepath = outputT2_path + "/" + (all_list[subsetindex]) + "_" + str(new_slice_num) + ".npy"
    t1ce_imagepath = outputT1ce_path + "/" + (all_list[subsetindex]) + "_" + str(new_slice_num) + ".npy"
    maskpath_wt = outputMaskWT_path + "/" + (all_list[subsetindex]) + "_" + str(new_slice_num) + ".npy"
    maskpath_all = outputMaskAll_path + "/" + (all_list[subsetindex]) + "_" + str(new_slice_num) + ".npy"
    np.save(flair_imagepath, flair_np)
    np.save(t1_imagepath, t1_np)
    np.save(t2_imagepath, t2_np)
    np.save(t1ce_imagepath, t1ce_np)
    np.save(maskpath_wt, mask_np_wt)
    np.save(maskpath_all, all_label3)
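遍历 mask 中的每个切片
对于单通道(二分类)标签 mask_np_wt:mask 中所有非零标签(1、2、4)都置为 1
对于三通道标签 all_label3,三个通道依次为 WT、TC、ET:标签 1 → (1,1,0),标签 2 → (1,0,0),标签 4 → (1,1,1),背景 0 → (0,0,0)
xxx_np 是每个模态在当前层的切片数组(已转为 float32)
new_slice_num 是从 001 开始、补零到三位的切片编号,用于文件命名
xxx_imagepath 是每个模态切片的输出路径,maskpath_wt 是单通道 mask 的路径,maskpath_all 是三通道 mask 的路径
最后通过 np.save 将它们分别保存为 .npy 文件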
# All cases have been processed.
print("Done!")
全部执行完毕打印“Done!”