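# This script gathers the raw-image intensities that fall inside the labelled
# part of each segmentation image and reports their range, for use in later
# data processing.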
import os
import SimpleITK as sitk
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm
import numpy as np
from collections import Counter
def read_medical_image(filename):
    """Load an image file with SimpleITK and return its pixel data as float32.

    Args:
        filename: Path of the image file handed to ``sitk.ReadImage``.

    Returns:
        The array produced by ``sitk.GetArrayFromImage``, converted with
        ``np.float32``.
    """
    # Read -> convert to NumPy -> cast, expressed as a single pipeline.
    return np.float32(sitk.GetArrayFromImage(sitk.ReadImage(filename)))
def extract_interesting_regions(root, origin, label):
    """Return the intensities of one case that lie under label value 1.

    Both files are looked up inside ``root`` and are named after the
    directory itself: ``<basename(root)><origin>`` for the intensity image
    and ``<basename(root)><label>`` for the segmentation.

    Args:
        root: Directory of a single case.
        origin: File-name suffix of the intensity image.
        label: File-name suffix of the segmentation image.

    Returns:
        The intensity values at every position where the segmentation
        equals 1.
    """
    stem = os.path.basename(root)
    intensity_path, mask_path = (
        os.path.join(root, stem + suffix) for suffix in (origin, label)
    )
    intensities = read_medical_image(intensity_path)
    labels = read_medical_image(mask_path)
    # Index tuple of the positions labelled exactly 1.
    # NOTE(review): assumes both images share the same shape - confirm.
    selected = np.nonzero(labels == 1)
    return intensities[selected]
def main(baseroot, origin, label):
    """Collect the labelled intensities of every directory below ``baseroot``.

    ``os.walk`` is used, so every sub-directory at every depth is treated as
    a case directory and passed to ``extract_interesting_regions``.

    Args:
        baseroot: Top-level directory to walk.
        origin: File-name suffix of the intensity image in each case.
        label: File-name suffix of the segmentation image in each case.

    Returns:
        All per-directory results joined with ``np.concatenate`` along
        axis 0.
    """
    # tqdm wraps the walk generator purely for progress reporting.
    walk_progress = tqdm(os.walk(baseroot))
    per_case = [
        extract_interesting_regions(os.path.join(parent, child), origin, label)
        for parent, children, _ in walk_progress
        for child in children
    ]
    return np.concatenate(per_case, axis=0)
if __name__ == '__main__':
    # Dataset root, the file-name suffixes of the image/label pair inside
    # each case directory, and the output file for the plot.
    file_root = 'D:/Datasets/'
    origin_postfix = 'origin.png'
    label_postfix = 'label.png'
    save_name = 'test.png'

    distribution = main(file_root, origin_postfix, label_postfix)

    # Summary statistics of the intensities found under the label.
    print("max:", distribution.max())
    print("min: ", distribution.min())
    print("var: ", distribution.var())

    # One 14x8 figure holding a single axes for the distribution plot.
    fig, ax = plt.subplots(figsize=(14, 8))
    # NOTE(review): sns.distplot is reported as deprecated in newer seaborn
    # releases - confirm against the installed version before upgrading.
    sns.distplot(distribution, color='red', label="label", ax=ax)
    # 21 evenly spaced ticks from -500 to 1500, i.e. one every 100 units.
    ax.set_xticks(np.linspace(-500, 1500, 21))
    ax.legend()
    fig.savefig(save_name)
    plt.show()