[Keras] 3D U-Net Source Code Walkthrough: generator.py






def get_training_and_validation_generators(data_file, batch_size, n_labels, training_keys_file, validation_keys_file,
                                           data_split=0.8, overwrite=False, labels=None, augment=False,
                                           augment_flip=True, augment_distortion_factor=0.25, patch_shape=None,
                                           validation_patch_overlap=0, training_patch_start_offset=None,
                                           validation_batch_size=None, skip_blank=True, permute=False):
    """
    Creates the training and validation generators that can be used when training the model.
    :param skip_blank: If True, any blank (all-zero) label images/patches will be skipped by the data generator.
    :param validation_batch_size: Batch size for the validation data.
    :param training_patch_start_offset: Tuple of length 3 containing integer values. Training data will randomly be
    offset by a number of pixels between (0, 0, 0) and the given tuple. (default is None)
    :param validation_patch_overlap: Number of pixels/voxels that will be overlapped in the validation data. (requires
    patch_shape to not be None)
    :param patch_shape: Shape of the data to return with the generator. If None, the whole image will be returned.
    (default is None)
    :param augment_flip: if True and augment is True, then the data will be randomly flipped along the x, y and z axis
    :param augment_distortion_factor: if augment is True, this determines the standard deviation from the original
    that the data will be distorted (in a stretching or shrinking fashion). Set to None, False, or 0 to prevent the
    augmentation from distorting the data in this way.
    :param augment: If True, training data will be distorted on the fly so as to avoid over-fitting.
    :param labels: List or tuple containing the ordered label values in the image files. The length of the list or tuple
    should be equal to the n_labels value.
    Example: (10, 25, 50)
    The data generator would then return binary truth arrays representing the labels 10, 25, and 50 in that order.
    :param data_file: hdf5 file to load the data from.
    :param batch_size: Size of the batches that the training generator will provide.
    :param n_labels: Number of binary labels.
    :param training_keys_file: Pickle file where the index locations of the training data will be stored.
    :param validation_keys_file: Pickle file where the index locations of the validation data will be stored.
    :param data_split: How the training and validation data will be split. 0 means all the data will be used for
    validation and none of it will be used for training. 1 means that all the data will be used for training and none
    will be used for validation. Default is 0.8 or 80%.
    :param overwrite: If set to True, previous files will be overwritten. The default mode is false, so that the
    training and validation splits won't be overwritten when rerunning model training.
    :param permute: will randomly permute the data (data must be 3D cube)
    :return: Training data generator, validation data generator, number of training steps, number of validation steps
    """
    if not validation_batch_size:
        validation_batch_size = batch_size

    training_list, validation_list = get_validation_split(data_file,

    training_generator = data_generator(data_file, training_list,
    validation_generator = data_generator(data_file, validation_list,

    # Set the number of training and testing samples per epoch correctly
    num_training_steps = get_number_of_steps(get_number_of_patches(data_file, training_list, patch_shape,
                                                                   patch_overlap=0), batch_size)
    print("Number of training steps: ", num_training_steps)

    num_validation_steps = get_number_of_steps(get_number_of_patches(data_file, validation_list, patch_shape,
    print("Number of validation steps: ", num_validation_steps)

    return training_generator, validation_generator, num_training_steps, num_validation_steps



  1. skip_blank = config["skip_blank"] = True # if True, then patches without any target will be skipped

  2. validation_batch_size = config["validation_batch_size"] = 12

  3. training_patch_start_offset = config["training_patch_start_offset"] = (16, 16, 16) # randomly offset the first patch index by up to this offset
    A tuple of length 3 containing integer values. The training data is randomly offset by a number of pixels between (0, 0, 0) and the given tuple. (default is None)

  4. validation_patch_overlap = config["validation_patch_overlap"] = 0 # if > 0, during training, validation patches will be overlapping
    The number of pixels/voxels by which the validation patches overlap. (requires patch_shape to not be None)

  5. patch_shape = config["patch_shape"] = (64, 64, 64) # switch to None to train on the whole image

  6. augment_flip = config["flip"] = False # augments the data by randomly flipping an axis during

  7. augment_distortion_factor = config["distort"] = None # switch to None if you want no distortion

  8. augment = config["augment"] = config["flip"] or config["distort"] # falsy with the values above, so no augmentation is applied

  9. labels = config["labels"] = (1, 2, 4) # the label numbers on the input image
    Example: (10, 25, 50)
    As described in the BraTS reference documentation, the labels are the GD-enhancing tumor (ET, label 4), the peritumoral edema (ED, label 2), and the necrotic and non-enhancing tumor core (NCR/NET, label 1).

  10. data_file = data_file_opened

  11. batch_size = config["batch_size"] = 6
    A batch is the portion of the data fed into the network in one training step, and the batch size is the number of training samples in each batch.

  12. n_labels = len(config["labels"]) = 3

  13. training_keys_file = config["training_file"] = os.path.abspath("training_ids.pkl")

  14. validation_keys_file = config["validation_file"] = os.path.abspath("validation_ids.pkl")

  15. data_split = config["validation_split"] = 0.8
    How the training and validation data are split. 0 means all the data is used for validation and none for training. 1 means all the data is used for training and none for validation. The default is 0.8, or 80%.

  16. overwrite = True # If True, will overwrite previous files. If False, will use previously written files.
    If set to True, previous files are overwritten. The default is False, so the training and validation splits are not overwritten when model training is rerun.

  17. permute = config["permute"] = True # data shape must be a cube. Augments the data by permuting in various directions

  18. Returns: the training data generator, the validation data generator, the number of training steps, and the number of validation steps
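
For reference, the settings listed above can be gathered into a plain dictionary. This is only a sketch built from the values quoted in this article: the project's real config contains further keys, and data_file_opened is left out because it needs an opened HDF5 file. The local variable n_labels is used here purely for illustration.

```python
import os

# Values copied from the parameter list above; other keys are omitted.
config = dict()
config["labels"] = (1, 2, 4)
config["patch_shape"] = (64, 64, 64)
config["batch_size"] = 6
config["validation_batch_size"] = 12
config["validation_split"] = 0.8
config["flip"] = False
config["distort"] = None
config["augment"] = config["flip"] or config["distort"]
config["permute"] = True
config["skip_blank"] = True
config["validation_patch_overlap"] = 0
config["training_patch_start_offset"] = (16, 16, 16)
config["training_file"] = os.path.abspath("training_ids.pkl")
config["validation_file"] = os.path.abspath("validation_ids.pkl")

n_labels = len(config["labels"])
print(n_labels)                 # one binary channel per label value
print(bool(config["augment"]))  # False: neither flip nor distort is enabled
```

Note that `False or None` evaluates to `None` rather than `False`; both are falsy, so augmentation stays off either way.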


if not validation_batch_size:
    validation_batch_size = batch_size



training_list, validation_list = get_validation_split(data_file,



training_generator = data_generator(data_file, training_list,
validation_generator = data_generator(data_file, validation_list,



# Set the number of training and testing samples per epoch correctly
num_training_steps = get_number_of_steps(get_number_of_patches(data_file, training_list, patch_shape,
                                                               patch_overlap=0), batch_size)
print("Number of training steps: ", num_training_steps)

num_validation_steps = get_number_of_steps(get_number_of_patches(data_file, validation_list, patch_shape,
print("Number of validation steps: ", num_validation_steps)



return training_generator, validation_generator, num_training_steps, num_validation_steps
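
The step counts computed above depend on get_number_of_steps and get_number_of_patches, neither of which is shown in this section. As a rough sketch, assuming that one step is one batch, that the last batch may be partial, and that no patch is skipped as blank, the arithmetic could look like this (steps_per_epoch is a hypothetical helper, not the project's function):

```python
import math


def steps_per_epoch(n_patches, batch_size):
    """Hypothetical helper: batches needed to see every patch once."""
    return math.ceil(n_patches / batch_size)


# 24 training volumes and 6 validation volumes, 27 patches each
n_training_patches = 24 * 27
n_validation_patches = 6 * 27

print(steps_per_epoch(n_training_patches, 6))     # training batch size 6
print(steps_per_epoch(n_validation_patches, 12))  # validation batch size 12
```

With skip_blank=True the real counts may be lower, since patches without any label are not counted.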




def get_validation_split(data_file, training_file, validation_file, data_split=0.8, overwrite=False):
    """
    Splits the data into the training and validation indices list.
    :param data_file: pytables hdf5 data file
    :param training_file:
    :param validation_file:
    :param data_split:
    :param overwrite:
    """
    if overwrite or not os.path.exists(training_file):
        print("Creating validation split...")
        nb_samples = data_file.root.data.shape[0]
        sample_list = list(range(nb_samples))
        training_list, validation_list = split_list(sample_list, split=data_split)
        pickle_dump(training_list, training_file)
        pickle_dump(validation_list, validation_file)
        return training_list, validation_list
    else:
        print("Loading previous validation split...")
        return pickle_load(training_file), pickle_load(validation_file)
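
The create-or-reuse behaviour can be tried without an HDF5 file. The sketch below is a simplified stand-in that works on a plain list and does not shuffle; load_or_create_split is a hypothetical name, and the pickle files are written to a temporary directory.

```python
import os
import pickle
import tempfile


def load_or_create_split(sample_list, training_file, validation_file, data_split=0.8, overwrite=False):
    """Hypothetical stand-in for get_validation_split operating on a plain list."""
    if overwrite or not os.path.exists(training_file):
        n_training = int(len(sample_list) * data_split)
        training, validation = sample_list[:n_training], sample_list[n_training:]
        for item, path in ((training, training_file), (validation, validation_file)):
            with open(path, "wb") as opened_file:
                pickle.dump(item, opened_file)
        return training, validation
    # The files already exist, so the stored split is returned unchanged.
    with open(training_file, "rb") as t_file, open(validation_file, "rb") as v_file:
        return pickle.load(t_file), pickle.load(v_file)


with tempfile.TemporaryDirectory() as tmp:
    train_path = os.path.join(tmp, "training_ids.pkl")
    val_path = os.path.join(tmp, "validation_ids.pkl")
    first = load_or_create_split(list(range(30)), train_path, val_path)
    # A second call with different input still returns the stored split.
    second = load_or_create_split(list(range(100)), train_path, val_path)
    third = load_or_create_split(list(range(100)), train_path, val_path, overwrite=True)

print(first == second)
print(len(third[0]), len(third[1]))
```

This is why overwrite matters: with overwrite=False a rerun keeps the same training and validation subjects, even if the data file has changed in the meantime.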



def split_list(input_list, split=0.8, shuffle_list=True):
    if shuffle_list:
        shuffle(input_list)  # random.shuffle: reorders the list in place
    n_training = int(len(input_list) * split)
    training = input_list[:n_training]
    testing = input_list[n_training:]
    return training, testing
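
A quick check of the split arithmetic, assuming that the shuffle step is random.shuffle (which reorders the input list in place) and using 30 samples as in the example later in this article:

```python
import random


def split_list(input_list, split=0.8, shuffle_list=True):
    if shuffle_list:
        random.shuffle(input_list)  # in place: the caller's list is reordered too
    n_training = int(len(input_list) * split)
    return input_list[:n_training], input_list[n_training:]


random.seed(0)  # only to make the demonstration repeatable
samples = list(range(30))
training, testing = split_list(samples, split=0.8)

print(len(training), len(testing))
print(sorted(training + testing) == list(range(30)))
```

Every sample ends up in exactly one of the two lists; only the assignment is random.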



def pickle_dump(item, out_file):
    with open(out_file, "wb") as opened_file:
        pickle.dump(item, opened_file)


pickle.dump(obj, file[, protocol]) serializes obj and writes it to the open file object file.



def data_generator(data_file, index_list, batch_size=1, n_labels=1, labels=None, augment=False, augment_flip=True,
                   augment_distortion_factor=0.25, patch_shape=None, patch_overlap=0, patch_start_offset=None,
                   shuffle_index_list=True, skip_blank=True, permute=False):
    orig_index_list = index_list
    while True:
        x_list = list()
        y_list = list()
        if patch_shape:
            index_list = create_patch_index_list(orig_index_list, data_file.root.data.shape[-3:], patch_shape,
                                                 patch_overlap, patch_start_offset)
        else:
            index_list = copy.copy(orig_index_list)

        if shuffle_index_list:
            shuffle(index_list)
        while len(index_list) > 0:
            index = index_list.pop()
            add_data(x_list, y_list, data_file, index, augment=augment, augment_flip=augment_flip,
                     augment_distortion_factor=augment_distortion_factor, patch_shape=patch_shape,
                     skip_blank=skip_blank, permute=permute)
            if len(x_list) == batch_size or (len(index_list) == 0 and len(x_list) > 0):
                yield convert_data(x_list, y_list, n_labels=n_labels, labels=labels)
                x_list = list()
                y_list = list()

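index_list: passing a different index list here yields either the training or the validation data generator.
while True: the standard form of a data generator function; readers who are not familiar with it can refer to my earlier article on fit_generator.
x_list = list() collects the training/validation image data.
y_list = list() collects the label image data.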
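
The `while True` pattern can be seen in isolation in the stripped-down sketch below. index_generator is a hypothetical name; it yields bare indices instead of image batches, but keeps the two ideas that matter here: the outer loop never ends, and every pass works on a fresh copy so the original list is left untouched.

```python
import copy
from itertools import islice


def index_generator(orig_index_list):
    """Hypothetical, stripped-down generator: yields one index at a time, forever."""
    while True:
        index_list = copy.copy(orig_index_list)  # fresh copy for every pass
        while len(index_list) > 0:
            yield index_list.pop()  # pop() takes items from the end of the list


original = [0, 1, 2]
gen = index_generator(original)
print(list(islice(gen, 7)))  # a little more than two full passes
print(original)              # unchanged
```

Because the generator never raises StopIteration, the caller has to say how many batches make up one epoch, which is what the step counts returned alongside the generators are for.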

if patch_shape:
    index_list = create_patch_index_list(orig_index_list, data_file.root.data.shape[-3:], patch_shape,
                                         patch_overlap, patch_start_offset)
else:
    index_list = copy.copy(orig_index_list)


if shuffle_index_list:
    shuffle(index_list)
while len(index_list) > 0:
    index = index_list.pop()
    add_data(x_list, y_list, data_file, index, augment=augment, augment_flip=augment_flip,
             augment_distortion_factor=augment_distortion_factor, patch_shape=patch_shape,
             skip_blank=skip_blank, permute=permute)

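Next, index = index_list.pop() is executed. pop() removes one element from the list (the last one by default) and returns its value.
Because the loop condition is while len(index_list) > 0, the list keeps shrinking until it is empty.
We have 24 training volumes in total, so, leaving shuffling aside, reading starts with the last patch start index [104 104 104] of the last volume [23]: add_data is executed to produce the data and label for that patch, then the second-to-last patch start index of volume [23] is processed, and so on, until all 27*24 patches have been read.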

if len(x_list) == batch_size or (len(index_list) == 0 and len(x_list) > 0):
    yield convert_data(x_list, y_list, n_labels=n_labels, labels=labels)
    x_list = list()
    y_list = list()
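
The yield condition above emits a batch either when it is full or when the index list has run out and something is still waiting in x_list. The sketch below replays just that condition for one pass over the data; batch_sizes is a hypothetical helper, and it assumes every index contributes one sample (in the real generator add_data may skip blank patches).

```python
def batch_sizes(n_items, batch_size):
    """Hypothetical helper: sizes of the batches produced in one pass."""
    index_list = list(range(n_items))
    x_list = list()
    sizes = list()
    while len(index_list) > 0:
        x_list.append(index_list.pop())
        if len(x_list) == batch_size or (len(index_list) == 0 and len(x_list) > 0):
            sizes.append(len(x_list))  # stands in for the yield
            x_list = list()
    return sizes


print(batch_sizes(7, 3))         # the last batch is smaller
print(len(batch_sizes(648, 6)))  # 27 * 24 patches with batch size 6
```

The second half of the condition is what keeps the leftover samples from being dropped at the end of a pass.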



def create_patch_index_list(index_list, image_shape, patch_shape, patch_overlap, patch_start_offset=None):
    patch_index = list()
    for index in index_list:
        if patch_start_offset is not None:
            random_start_offset = np.negative(get_random_nd_index(patch_start_offset))
            patches = compute_patch_indices(image_shape, patch_shape, overlap=patch_overlap, start=random_start_offset)
        else:
            patches = compute_patch_indices(image_shape, patch_shape, overlap=patch_overlap)
        patch_index.extend(itertools.product([index], patches))
    return patch_index

for index in index_list loops over the index list and calls compute_patch_indices for each index.
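
When patch_start_offset is given, each volume gets its own randomly shifted patch grid. get_random_nd_index is not shown in this section, so the sketch below assumes it draws one integer per axis between 0 and the given offset (inclusive); the draw is then negated, as in the code above, and used as the start of the grid.

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded only to make the sketch repeatable
patch_start_offset = (16, 16, 16)

# Assumed behaviour of get_random_nd_index: one integer per axis in 0..offset.
random_index = rng.integers(0, np.asarray(patch_start_offset) + 1)
random_start_offset = np.negative(random_index)

image_shape = np.asarray((144, 144, 144))
patch_size = np.asarray((64, 64, 64))
stop = image_shape + random_start_offset  # same rule as in compute_patch_indices

# Patch start positions along each axis for this particular draw
starts_per_axis = [np.arange(s, e, p) for s, e, p in zip(random_start_offset, stop, patch_size)]
print(random_start_offset)
print([axis.tolist() for axis in starts_per_axis])
```

For a 144-voxel axis and 64-voxel patches every offset in this range still gives three start positions per axis, so the number of patches per volume stays at 27 while their positions change from pass to pass.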

patch_index.extend(itertools.product([index], patches))  

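extend() appends all the values of another sequence to the end of the list at once (it extends the original list with the new one).

itertools.product(*iterables, repeat=1)
Roughly equivalent to nested loops in a generator expression. For example, product(A, B) returns the same result as ((x,y) for x in A for y in B).
To compute the Cartesian product of an iterable with itself, set the optional repeat argument to the number of repetitions. For example, product(A, repeat=4) is the same as product(A, A, A, A).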
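
Both statements about itertools.product are easy to verify on small inputs. The lists A and B below are placeholders chosen for the check:

```python
import itertools

A = [0, 1]
B = ["p", "q"]

# product(A, B) matches the nested-loop generator expression
nested = [(x, y) for x in A for y in B]
print(list(itertools.product(A, B)) == nested)

# repeat=2 is the same as passing the iterable twice
print(list(itertools.product(A, repeat=2)) == list(itertools.product(A, A)))

# With a one-element first iterable, every item of B is tagged with that element
print(list(itertools.product([7], B)))
```

The last call is the form used in create_patch_index_list: each patch start index is paired with the index of the volume it belongs to.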

import itertools

sample_list = list(range(30))
training = sample_list[:24]
testing = sample_list[24:]

# patches holds the 27 patch start indices returned by compute_patch_indices
patch_index = list()
for index in training:
  patch_index.extend(itertools.product([index], patches))
print(patch_index)
# Output
[(0, array([-24, -24, -24])), (0, array([-24, -24,  40])), (0, array([-24, -24, 104])), (0, array([-24,  40, -24])), ...,
 (0, array([104, 104,  40])), (0, array([104, 104, 104])),
 (1, array([-24, -24, -24])), (1, array([-24, -24,  40])), ...,
 (23, array([104, 104,  40])), (23, array([104, 104, 104]))]
(648 entries in total: each of the 24 training indices, 0 to 23, is paired in order with the same 27 patch start indices, from [-24, -24, -24] to [104, 104, 104].)



def compute_patch_indices(image_shape, patch_size, overlap, start=None):
    if isinstance(overlap, int):
        overlap = np.asarray([overlap] * len(image_shape))
    if start is None:
        n_patches = np.ceil(image_shape / (patch_size - overlap))
        overflow = (patch_size - overlap) * n_patches - image_shape + overlap
        start = -np.ceil(overflow/2)
    elif isinstance(start, int):
        start = np.asarray([start] * len(image_shape))
    stop = image_shape + start
    step = patch_size - overlap
    return get_set_of_patch_indices(start, stop, step)

start = None

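if isinstance(overlap, int):
    overlap = np.asarray([overlap] * len(image_shape))
print(overlap)
[0 0 0]
print(patch_size - overlap)
[64 64 64]

isinstance() checks whether an object is of a given type. Here overlap is the int 0, so [overlap] * len(image_shape) expands it to one value per image axis:

[0, 0, 0]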


if start is None:
    n_patches = np.ceil(image_shape / (patch_size - overlap))
    overflow = (patch_size - overlap) * n_patches - image_shape + overlap
    start = -np.ceil(overflow/2)

np.ceil() rounds each value up to the nearest integer.

n_patches = np.ceil(image_shape / (patch_size - overlap))
[3. 3. 3.]  # 144 / 64 = 2.25, rounded up to 3


overflow = (patch_size - overlap) * n_patches - image_shape
[48. 48. 48.]

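Multiplying the n_patches from the previous step by the patch size (3 × 64 = 192) and subtracting the image size (144) shows how far the patches extend beyond the original image: 48 voxels per axis.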

start = -np.ceil(overflow/2)
[-24. -24. -24.]
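
The arithmetic above can be reproduced on its own. The following is a minimal sketch, assuming the values used in this walkthrough (a 144×144×144 image, 64×64×64 patches, zero overlap). The variable names mirror those in compute_patch_indices, but the snippet is standalone and does not import the repository code.

```python
import numpy as np

# Values assumed from the walkthrough: 144^3 volume, 64^3 patches, no overlap.
image_shape = np.asarray((144, 144, 144))
patch_size = np.asarray((64, 64, 64))
overlap = 0

# An integer overlap is expanded to one value per axis.
if isinstance(overlap, int):
    overlap = np.asarray([overlap] * len(image_shape))

step = patch_size - overlap
n_patches = np.ceil(image_shape / step)              # 144 / 64 = 2.25 -> 3
overflow = step * n_patches - image_shape + overlap  # 3 * 64 - 144 = 48
start = -np.ceil(overflow / 2)                       # half the overflow before the image

print(n_patches, overflow, start)
```

Splitting the overflow in half centers the patch grid on the image, so the first patch begins 24 voxels before the image and the last one ends 24 voxels after it.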


stop = image_shape + start  # (120, 120, 120), with image_shape = data_file.root.data.shape[-3:]
step = patch_size - overlap  # (64, 64, 64)

image_shape = data_file.root.data.shape[-3:]
(144, 144, 144)

Finally, get_set_of_patch_indices(start, stop, step) is called to build the final index list; the function is explained below.

def get_set_of_patch_indices(start, stop, step):
    # np.int was removed in NumPy 1.24; the builtin int behaves the same here
    return np.asarray(np.mgrid[start[0]:stop[0]:step[0], start[1]:stop[1]:step[1],
                               start[2]:stop[2]:step[2]].reshape(3, -1).T, dtype=int)


table_inx=np.mgrid[start[0]:stop[0]:step[0], start[1]:stop[1]:step[1],start[2]:stop[2]:step[2]]
[[[[-24. -24. -24.]
   [-24. -24. -24.]
   [-24. -24. -24.]]

  [[ 40.  40.  40.]
   [ 40.  40.  40.]
   [ 40.  40.  40.]]

  [[104. 104. 104.]
   [104. 104. 104.]
   [104. 104. 104.]]]

 [[[-24. -24. -24.]
   [ 40.  40.  40.]
   [104. 104. 104.]]

  [[-24. -24. -24.]
   [ 40.  40.  40.]
   [104. 104. 104.]]

  [[-24. -24. -24.]
   [ 40.  40.  40.]
   [104. 104. 104.]]]

 [[[-24.  40. 104.]
   [-24.  40. 104.]
   [-24.  40. 104.]]

  [[-24.  40. 104.]
   [-24.  40. 104.]
   [-24.  40. 104.]]

  [[-24.  40. 104.]
   [-24.  40. 104.]
   [-24.  40. 104.]]]]
a = np.asarray(np.mgrid[start[0]:stop[0]:step[0], start[1]:stop[1]:step[1],
                        start[2]:stop[2]:step[2]].reshape(3, -1).T, dtype=int)
(3, 3, 3, 3)
[[-24 -24 -24]
 [-24 -24  40]
 [-24 -24 104]
 [-24  40 -24]
 [-24  40  40]
 [-24  40 104]
 [-24 104 -24]
 [-24 104  40]
 [-24 104 104]
 [ 40 -24 -24]
 [ 40 -24  40]
 [ 40 -24 104]
 [ 40  40 -24]
 [ 40  40  40]
 [ 40  40 104]
 [ 40 104 -24]
 [ 40 104  40]
 [ 40 104 104]
 [104 -24 -24]
 [104 -24  40]
 [104 -24 104]
 [104  40 -24]
 [104  40  40]
 [104  40 104]
 [104 104 -24]
 [104 104  40]
 [104 104 104]]
(27, 3)
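
To see how the (3, 3, 3, 3) grid turns into the (27, 3) list of corners, here is a small standalone sketch. It assumes the start, stop and step values from this walkthrough, with stop = image_shape + start = 120 as in compute_patch_indices, and uses integer inputs so that the result is already of integer type.

```python
import numpy as np

# Values assumed from the walkthrough.
start, stop, step = (-24, -24, -24), (120, 120, 120), (64, 64, 64)

# np.mgrid returns one coordinate array per axis, stacked along axis 0.
grid = np.mgrid[start[0]:stop[0]:step[0],
                start[1]:stop[1]:step[1],
                start[2]:stop[2]:step[2]]

# Flatten each coordinate array, then transpose so every row is one (x, y, z) corner.
indices = np.asarray(grid.reshape(3, -1).T, dtype=int)

print(grid.shape)     # (3, 3, 3, 3)
print(indices.shape)  # (27, 3)
print(indices[:2])    # the last axis varies fastest
```

Because the flattening is in C order, the third coordinate changes first, which matches the row order of the printed list above.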



def add_data(x_list, y_list, data_file, index, augment=False, augment_flip=False, augment_distortion_factor=0.25,
             patch_shape=False, skip_blank=True, permute=False):
    """
    Adds data from the data file to the given lists of feature and target data
    :param skip_blank: Data will not be added if the truth vector is all zeros (default is True).
    :param patch_shape: Shape of the patch to add to the data lists. If None, the whole image will be added.
    :param x_list: list of data to which data from the data_file will be appended.
    :param y_list: list of data to which the target data from the data_file will be appended.
    :param data_file: hdf5 data file.
    :param index: index of the data file from which to extract the data.
    :param augment: if True, data will be augmented according to the other augmentation parameters (augment_flip and
    augment_distortion_factor)
    :param augment_flip: if True and augment is True, then the data will be randomly flipped along the x, y and z axis
    :param augment_distortion_factor: if augment is True, this determines the standard deviation from the original
    that the data will be distorted (in a stretching or shrinking fashion). Set to None, False, or 0 to prevent the
    augmentation from distorting the data in this way.
    :param permute: will randomly permute the data (data must be 3D cube)
    """
    data, truth = get_data_from_file(data_file, index, patch_shape=patch_shape)
    if augment:
        if patch_shape is not None:
            # in patch mode, index is an (image index, patch corner) pair
            affine = data_file.root.affine[index[0]]
        else:
            affine = data_file.root.affine[index]
        data, truth = augment_data(data, truth, affine, flip=augment_flip, scale_deviation=augment_distortion_factor)

    if permute:
        if data.shape[-3] != data.shape[-2] or data.shape[-2] != data.shape[-1]:
            raise ValueError("To utilize permutations, data array must be in 3D cube shape with all dimensions having "
                             "the same length.")
        data, truth = random_permutation_x_y(data, truth[np.newaxis])
    else:
        # add a leading channel axis to the label volume
        truth = truth[np.newaxis]

    # keep the sample unless it is blank and blank samples are being skipped
    if not skip_blank or np.any(truth != 0):
        x_list.append(data)
        y_list.append(truth)


data, truth = get_data_from_file(data_file, index, patch_shape=patch_shape)
if not skip_blank or np.any(truth != 0):
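
The condition `not skip_blank or np.any(truth != 0)` decides whether a sample is kept. As a quick illustration, the sketch below wraps it in a hypothetical helper, keep_sample, which is not part of generator.py, and applies it to a blank and a labelled toy volume.

```python
import numpy as np

def keep_sample(truth, skip_blank=True):
    """Hypothetical helper mirroring the condition used in add_data."""
    return (not skip_blank) or bool(np.any(truth != 0))

# A blank label patch and one containing a single labelled voxel.
blank = np.zeros((1, 4, 4, 4), dtype=np.uint8)
labelled = blank.copy()
labelled[0, 1, 2, 3] = 1

print(keep_sample(blank))                    # blank patch is skipped
print(keep_sample(blank, skip_blank=False))  # kept when skipping is disabled
print(keep_sample(labelled))                 # kept because it contains a label
```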



def get_data_from_file(data_file, index, patch_shape=None):
    if patch_shape:
        # index is an (image index, patch corner) pair: load the whole image, then crop the patch
        index, patch_index = index
        data, truth = get_data_from_file(data_file, index, patch_shape=None)
        x = get_patch_from_3d_data(data, patch_shape, patch_index)
        y = get_patch_from_3d_data(truth, patch_shape, patch_index)
    else:
        x, y = data_file.root.data[index], data_file.root.truth[index, 0]
    return x, y

data, truth = get_data_from_file(data_file, index, patch_shape=None)

index1, patch_index = index
[104 104 104]

x, y = data_file.root.data[index], data_file.root.truth[index, 0]

(30, 4, 144, 144, 144)
(30, 1, 144, 144, 144)

x, y = data_file.root.data[index], data_file.root.truth[index, 0]
data_file.root.truth is likewise five-dimensional, but each sample has only one label image, so it is indexed as data_file.root.truth[index, 0].
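
The effect of the two indexing expressions on the array shapes can be checked with plain NumPy arrays. The sketch below uses small hypothetical stand-ins for data_file.root.data and data_file.root.truth (5 samples of 8×8×8 voxels instead of 30 samples of 144×144×144) so that it runs without the HDF5 file.

```python
import numpy as np

# Hypothetical stand-ins for data_file.root.data and data_file.root.truth,
# shrunk from (30, 4, 144, 144, 144) / (30, 1, 144, 144, 144).
data = np.zeros((5, 4, 8, 8, 8), dtype=np.float32)
truth = np.zeros((5, 1, 8, 8, 8), dtype=np.uint8)

index = 2
x, y = data[index], truth[index, 0]

print(x.shape)  # (4, 8, 8, 8): all four modalities of one sample
print(y.shape)  # (8, 8, 8): the single label volume, channel axis dropped
```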

x, y = data_file.root.data[index1], data_file.root.truth[index1, 0]
[[[[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]


  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]]

 [[[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]


  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]]

 [[[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]


  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]]

 [[[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]


  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]

  [[0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]
   [0. 0. 0. ... 0. 0. 0.]]]]
(4, 144, 144, 144)
[[[0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]]

 [[0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]]

 [[0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]]


 [[0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]]

 [[0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]]

 [[0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]
  [0 0 0 ... 0 0 0]]]
(144, 144, 144)


def get_patch_from_3d_data(data, patch_shape, patch_index):
    """
    Returns a patch from a numpy array.
    :param data: numpy array from which to get the patch.
    :param patch_shape: shape/size of the patch.
    :param patch_index: corner index of the patch.
    :return: numpy array take from the data with the patch shape specified.
    """
    patch_index = np.asarray(patch_index, dtype=np.int16)
    patch_shape = np.asarray(patch_shape)
    image_shape = data.shape[-3:]
    # pad the data first if the patch reaches outside the image
    if np.any(patch_index < 0) or np.any((patch_index + patch_shape) > image_shape):
        data, patch_index = fix_out_of_bound_patch_attempt(data, patch_shape, patch_index)
    return data[..., patch_index[0]:patch_index[0]+patch_shape[0], patch_index[1]:patch_index[1]+patch_shape[1],
                patch_index[2]:patch_index[2]+patch_shape[2]]



patch_index = np.asarray(patch_index, dtype=np.int16)
patch_shape = np.asarray((64,64,64))
image_shape = data.shape[-3:]
[104 104 104]
[64 64 64]
(144, 144, 144)


if np.any(patch_index < 0) or np.any((patch_index + patch_shape) > image_shape):
    data, patch_index = fix_out_of_bound_patch_attempt(data, patch_shape, patch_index)


return data[..., patch_index[0]:patch_index[0]+patch_shape[0], patch_index[1]:patch_index[1]+patch_shape[1],
            patch_index[2]:patch_index[2]+patch_shape[2]]


a = data[..., patch_index[0]:patch_index[0]+patch_shape[0], patch_index[1]:patch_index[1]+patch_shape[1],
         patch_index[2]:patch_index[2]+patch_shape[2]]
[[[[ 974.64435  967.2424   959.2515  ...    0.         0.
       0.     ]
   [ 990.2376   982.45245  976.0181  ...    0.         0.
       0.     ]
   [1014.01    1011.8526  1012.1444  ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]]

  [[ 979.8702   969.1078   957.12286 ...    0.         0.
       0.     ]
   [1019.1183  1006.22644  993.1565  ...    0.         0.
       0.     ]
   [1042.1243  1035.1311  1029.3956  ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]]

  [[ 999.7032   988.7418   977.0992  ...    0.         0.
       0.     ]
   [1035.4014  1024.9207  1013.7255  ...    0.         0.
       0.     ]
   [1039.5989  1035.5583  1031.2466  ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]]


  [[   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]]

  [[   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]]

  [[   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]]]

 [[[ 144.65683  165.30124  180.96454 ...    0.         0.
       0.     ]
   [ 153.01872  175.63034  187.93913 ...    0.         0.
       0.     ]
   [ 162.61511  174.65828  179.1287  ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]]

  [[ 157.46268  171.40187  180.69766 ...    0.         0.
       0.     ]
   [ 168.83534  178.6118   181.83922 ...    0.         0.
       0.     ]
   [ 177.60065  184.03499  185.06255 ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]]

  [[ 166.20416  171.03703  177.93298 ...    0.         0.
       0.     ]
   [ 176.08592  174.61232  177.79686 ...    0.         0.
       0.     ]
   [ 177.65994  175.62888  178.42007 ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]]


  [[   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]]

  [[   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]]

  [[   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]]]

 [[[ 458.       463.       467.      ...    0.         0.
       0.     ]
   [ 449.       451.       452.      ...    0.         0.
       0.     ]
   [ 432.       437.       443.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]]

  [[ 467.       471.       476.      ...    0.         0.
       0.     ]
   [ 447.       448.       449.      ...    0.         0.
       0.     ]
   [ 438.       440.       443.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]]

  [[ 450.       453.       455.      ...    0.         0.
       0.     ]
   [ 432.       432.       433.      ...    0.         0.
       0.     ]
   [ 440.       435.       430.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]]


  [[   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]]

  [[   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]]

  [[   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]]]

 [[[ 398.2283   405.59125  413.5954  ...    0.         0.
       0.     ]
   [ 404.69037  413.0081   421.04782 ...    0.         0.
       0.     ]
   [ 392.00458  400.79367  409.79398 ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]]

  [[ 385.11923  391.67032  398.71704 ...    0.         0.
       0.     ]
   [ 393.11557  400.79193  409.13693 ...    0.         0.
       0.     ]
   [ 390.004    398.76068  406.39658 ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]]

  [[ 379.12677  385.17493  390.83636 ...    0.         0.
       0.     ]
   [ 391.53232  397.18793  402.48724 ...    0.         0.
       0.     ]
   [ 393.9321   399.31738  403.58066 ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]]


  [[   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]]

  [[   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]]

  [[   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]
   [   0.         0.         0.      ...    0.         0.
       0.     ]]]]
(4, 64, 64, 64)
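
The `...` (Ellipsis) in the slice keeps every leading axis, such as the modality axis, and crops only the last three spatial axes. The following sketch shows this on a small hypothetical array whose values encode their own position, with a patch that lies fully inside the volume.

```python
import numpy as np

# Hypothetical 2-channel 6x6x6 volume; each voxel stores its flat index.
data = np.arange(2 * 6 * 6 * 6).reshape(2, 6, 6, 6)
patch_index = np.asarray((1, 2, 3))  # corner of the patch
patch_shape = np.asarray((4, 3, 2))  # size of the patch

# Ellipsis leaves the channel axis untouched; only the spatial axes are cropped.
patch = data[..., patch_index[0]:patch_index[0] + patch_shape[0],
             patch_index[1]:patch_index[1] + patch_shape[1],
             patch_index[2]:patch_index[2] + patch_shape[2]]

print(patch.shape)  # (2, 4, 3, 2)
```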
def fix_out_of_bound_patch_attempt(data, patch_shape, patch_index, ndim=3):
    """
    Pads the data and alters the patch index so that a patch will be correct.
    :param data:
    :param patch_shape:
    :param patch_index:
    :return: padded data, fixed patch index
    """
    image_shape = data.shape[-ndim:]
    pad_before = np.abs((patch_index < 0) * patch_index)
    pad_after = np.abs(((patch_index + patch_shape) > image_shape) * ((patch_index + patch_shape) - image_shape))
    pad_args = np.stack([pad_before, pad_after], axis=1)
    if pad_args.shape[0] < len(data.shape):
        pad_args = [[0, 0]] * (len(data.shape) - pad_args.shape[0]) + pad_args.tolist()
    data = np.pad(data, pad_args, mode="edge")
    patch_index += pad_before
    return data, patch_index

image_shape = data.shape[-3:] gives the spatial size of a single image.

pad_before = np.abs((patch_index < 0) * patch_index)
pad_after = np.abs(((patch_index + patch_shape) > image_shape) * ((patch_index + patch_shape) - image_shape))
Then np.stack([pad_before, pad_after], axis=1) pairs the before and after padding amounts for each axis.

pad_args = np.stack([pad_before, pad_after], axis=1)
[[ 0 24]
 [ 0 24]
 [ 0 24]]
if pad_args.shape[0] < len(data.shape):
    pad_args = [[0, 0]] * (len(data.shape) - pad_args.shape[0]) + pad_args.tolist()


pad_args = [[0, 0]] * (len(data.shape) - pad_args.shape[0]) + pad_args.tolist()
[[0, 0], [0, 24], [0, 24], [0, 24]]
data = np.pad(data, pad_args, mode="edge")
patch_index += pad_before
return data, patch_index

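patch_index += pad_before handles a negative start index: once the data has been padded in front, the patch should start at index zero of the padded array along that axis.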
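
The padding steps can be tried on a toy array. The sketch below is standalone and uses hypothetical sizes (a 2-channel 4×4×4 volume and a 3×3×3 patch that sticks out before the first axis and after the second), following the same steps as fix_out_of_bound_patch_attempt.

```python
import numpy as np

# Hypothetical toy volume and an out-of-bounds patch request.
data = np.arange(2 * 4 * 4 * 4, dtype=np.float32).reshape(2, 4, 4, 4)
patch_shape = np.asarray((3, 3, 3))
patch_index = np.asarray((-1, 2, 0), dtype=np.int16)
image_shape = np.asarray(data.shape[-3:])

# Voxels missing before the image (negative corner) and after it (overshoot).
pad_before = np.abs((patch_index < 0) * patch_index)
over = (patch_index + patch_shape) - image_shape
pad_after = np.abs((over > 0) * over)

# The channel axis gets [0, 0]; each spatial axis gets [before, after].
pad_args = [[0, 0]] + np.stack([pad_before, pad_after], axis=1).tolist()
padded = np.pad(data, pad_args, mode="edge")  # repeat the border values

# Shift the corner so that it is valid in the padded array.
fixed_index = patch_index + pad_before
patch = padded[..., fixed_index[0]:fixed_index[0] + 3,
               fixed_index[1]:fixed_index[1] + 3,
               fixed_index[2]:fixed_index[2] + 3]

print(pad_args, padded.shape, fixed_index, patch.shape)
```

With mode="edge" the padded region repeats the nearest border voxel instead of filling with zeros, so patches at the boundary do not gain artificial blank regions.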


def convert_data(x_list, y_list, n_labels=1, labels=None):
    x = np.asarray(x_list)
    y = np.asarray(y_list)
    if n_labels == 1:
        y[y > 0] = 1
    elif n_labels > 1:
        y = get_multi_class_labels(y, n_labels=n_labels, labels=labels)
    return x, y

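With multiple labels, however (our BraTS data, for example, has the three labels 1, 2 and 4), an extra step is needed: get_multi_class_labels converts the label map into a set of binary numpy arrays of shape (n_samples, n_labels, ...).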
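
To make the conversion concrete, here is a minimal sketch of turning a label map into one binary channel per label. The function to_binary_channels is a hypothetical illustration written for this post and is not necessarily how get_multi_class_labels is implemented in the repository; the toy label map uses the BraTS label values 1, 2 and 4.

```python
import numpy as np

def to_binary_channels(y, labels):
    """Hypothetical sketch: one binary channel per label value.

    y has shape (n_samples, 1, ...); the result has shape
    (n_samples, len(labels), ...).
    """
    out = np.zeros((y.shape[0], len(labels)) + y.shape[2:], dtype=np.int8)
    for channel, label in enumerate(labels):
        # mark the voxels that carry this label in the matching channel
        out[:, channel][y[:, 0] == label] = 1
    return out

# Toy label map: one sample, one channel, a 1x2x3 volume.
y = np.array([0, 1, 2, 4, 4, 0]).reshape(1, 1, 1, 2, 3)
binary = to_binary_channels(y, labels=(1, 2, 4))

print(binary.shape)  # (1, 3, 1, 2, 3)
```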


def get_number_of_steps(n_samples, batch_size):
    if n_samples <= batch_size:
        return n_samples
    elif np.remainder(n_samples, batch_size) == 0:
        return n_samples//batch_size
    else:
        # one extra step for the final, smaller batch
        return n_samples//batch_size + 1
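
For sample counts larger than the batch size, the two division branches amount to a ceiling division. The sketch below is a standalone re-statement of that logic under the hypothetical name steps_per_epoch, useful for checking a few values by hand.

```python
import math

def steps_per_epoch(n_samples, batch_size):
    """Standalone re-statement of the step-count logic, for experimentation."""
    if n_samples <= batch_size:
        # same as the original: returns n_samples itself in this case
        return n_samples
    # exact division gives n // b, otherwise n // b + 1: a ceiling division
    return math.ceil(n_samples / batch_size)

for n, b in [(100, 6), (12, 6), (4, 6)]:
    print(n, b, steps_per_epoch(n, b))
```

Note that when n_samples is not larger than batch_size, the function returns n_samples rather than 1, exactly as the first branch of the original does.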



def get_number_of_patches(data_file, index_list, patch_shape=None, patch_overlap=0, patch_start_offset=None,
                          skip_blank=True):
    if patch_shape:
        index_list = create_patch_index_list(index_list, data_file.root.data.shape[-3:], patch_shape, patch_overlap,
                                             patch_start_offset)
        count = 0
        for index in index_list:
            x_list = list()
            y_list = list()
            # add_data appends nothing when the patch is blank and skip_blank is True
            add_data(x_list, y_list, data_file, index, skip_blank=skip_blank, patch_shape=patch_shape)
            if len(x_list) > 0:
                count += 1
        return count
    else:
        return len(index_list)
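
The idea of counting only the patches that actually contain labels can be shown without the HDF5 file. The sketch below uses a hypothetical helper, count_nonblank_patches, and a toy 4×4×4 label volume split into eight 2×2×2 patches; it illustrates the counting idea only and does not reproduce the repository function.

```python
import numpy as np

def count_nonblank_patches(truth, patch_shape, corners):
    """Hypothetical helper: count patches whose label region is not all zero."""
    count = 0
    for i, j, k in corners:
        region = truth[i:i + patch_shape[0],
                       j:j + patch_shape[1],
                       k:k + patch_shape[2]]
        if np.any(region != 0):
            count += 1
    return count

# Toy label volume with a single labelled voxel in one corner.
truth = np.zeros((4, 4, 4), dtype=np.uint8)
truth[3, 3, 3] = 1

# Eight non-overlapping 2x2x2 patches covering the volume.
corners = [(i, j, k) for i in (0, 2) for j in (0, 2) for k in (0, 2)]

print(len(corners), count_nonblank_patches(truth, (2, 2, 2), corners))
```

Only one of the eight patches contains the labelled voxel, so with blank patches skipped the generator would see one sample from this volume instead of eight.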

