Keras中的多输入ImageDataGenerator图片生成器

参考keras官网以及https://github.com/Deep-Learning-Person-Re-Identification/
通过继承 Keras 自带的 ImageDataGenerator 并重写其 flow 方法（配合自定义的 Iterator），实现一次同时输入多张图片。该代码用于在 Keras 中实现 TripletNet 的 Triplet Loss 训练。
更多内容请参考：
http://blog.csdn.net/yjy728/article/details/79570554
http://blog.csdn.net/yjy728/article/details/79569807
代码环境:

  • keras:2.1.2
  • tensorflow:1.4.0
  • python3.6
  • win7
from keras import backend as K
import numpy as np
from PIL import Image
from keras.preprocessing.image import ImageDataGenerator, Iterator
from keras.utils import np_utils
class ImageDataGenerator_Triplet(ImageDataGenerator):
    """``ImageDataGenerator`` whose ``flow`` yields triplet batches read from disk."""

    def flow(self, basepath, batch_size=32, class_num=731, input_size=299,
             train_vali_flag='train',
             shuffle=False, seed=None,
             save_to_dir=None, save_prefix='', save_format='png'):
        """Build a ``NumpyArrayIterator_Triplet`` that uses this generator for augmentation.

        Args:
            basepath: string prefix of the dataset root (concatenated directly
                with ``train_vali_flag``).
            batch_size: number of triplets per batch.
            class_num: number of classes to sample from.
            input_size: side length the images are resized to.
            train_vali_flag: split name, forwarded to the iterator.
            shuffle: forwarded to the iterator.
            seed: forwarded to the iterator.
            save_to_dir: accepted but not used.
            save_prefix: accepted but not used.
            save_format: accepted but not used.

        Returns:
            A ``NumpyArrayIterator_Triplet`` bound to ``self``.
        """
        iterator_kwargs = dict(
            class_num=class_num,
            input_size=input_size,
            train_vali_flag=train_vali_flag,
            basepath=basepath,
            batch_size=batch_size,
            shuffle=shuffle,
            seed=seed,
        )
        return NumpyArrayIterator_Triplet(self, **iterator_kwargs)

class NumpyArrayIterator_Triplet(Iterator):
    """Iterator that yields batches of randomly sampled image triplets.

    A batch built for an ``index_array`` of length ``3 * T`` is laid out as
    three consecutive blocks of ``T`` rows each: anchors, then positives
    (same class as the anchor, a different image), then negatives (a
    different class). Images are read from
    ``basepath + train_vali_flag + '/<class_id>/<image_id>.bmp'``; the parts
    are joined by plain string concatenation, so ``basepath`` must end with
    a path separator.

    The indices produced by the base ``Iterator`` only determine the batch
    length; the triplets themselves are drawn with ``np.random``.
    """

    # Image ids that exist inside every class directory, per split.
    _IMAGE_IDS = {'train': (1, 2, 4, 5), 'test': (3, 6)}

    def __init__(self, image_data_generator, class_num, input_size,
                 train_vali_flag, basepath,
                 batch_size=32, shuffle=False, seed=None):
        """Create the iterator.

        Args:
            image_data_generator: object providing ``random_transform`` and
                ``standardize`` (applied to every loaded image).
            class_num: number of classes; class directories are sampled from
                ``0 .. class_num - 1``.
            input_size: side length every image is resized to.
            train_vali_flag: ``'train'`` or ``'test'``; selects the
                sub-directory and the set of image ids to sample from.
            basepath: dataset root prefix (string-concatenated).
            batch_size: number of triplets per batch; each batch therefore
                holds ``3 * batch_size`` rows.
            shuffle: forwarded to ``Iterator``.
            seed: forwarded to ``Iterator``.

        Raises:
            ValueError: if ``train_vali_flag`` is not ``'train'`` or ``'test'``.
        """
        # Validate up front instead of failing on the first batch.
        if train_vali_flag not in self._IMAGE_IDS:
            raise ValueError(
                "param train_vali_flag must be chosen from 'train' and 'test', "
                "got %r" % (train_vali_flag,))
        self.image_data_generator = image_data_generator
        self.class_num = class_num
        self.input_size = input_size
        self.train_vali_flag = train_vali_flag
        # NOTE(review): stored but unused; batches are always channels-last.
        self.data_format = K.image_data_format()
        self.basepath = basepath
        # NOTE(review): 8848 is a hard-coded sample count `n` for Iterator,
        # presumably the dataset size -- confirm. The base batch size is
        # 3x so that each index batch covers anchor + positive + negative.
        super(NumpyArrayIterator_Triplet, self).__init__(
            8848, batch_size * 3, shuffle, seed)

    def _load_image(self, class_id, image_id):
        """Load one image, convert it to RGB, resize it, then augment and standardize it."""
        path = (self.basepath + self.train_vali_flag + '/' + str(class_id)
                + '/' + str(image_id) + '.bmp')
        # `with` closes the file handle that Image.open keeps open lazily.
        with Image.open(path) as img:
            # RGB conversion guarantees 3 channels even for greyscale/palette BMPs.
            rgb = img.convert('RGB').resize((self.input_size, self.input_size))
            arr = np.array(rgb)
        arr = self.image_data_generator.random_transform(arr.astype(K.floatx()))
        return self.image_data_generator.standardize(arr)

    def _get_batches_of_transformed_samples(self, index_array):
        """Build one batch of random triplets.

        Args:
            index_array: indices from the base iterator; only its length is used.

        Returns:
            ``(batch_x, [batch_y, batch_z])`` where ``batch_x`` has shape
            ``(len(index_array), input_size, input_size, 3)``, ``batch_y`` is
            the one-hot class label of every row, and ``batch_z`` is all
            zeros with shape ``(len(index_array), 1)`` (presumably a dummy
            target for the triplet-loss output -- confirm against the model).
            If ``len(index_array)`` is not a multiple of 3, the trailing rows
            stay zero.
        """
        total = len(index_array)
        n_triplets = total // 3
        size = self.input_size
        batch_x = np.zeros((total, size, size, 3), dtype=K.floatx())
        batch_y = np.zeros([total, 1])
        batch_z = np.zeros([total, 1])
        image_ids = self._IMAGE_IDS[self.train_vali_flag]
        for i in range(n_triplets):
            # Two distinct classes: ka for anchor/positive, kb for the negative.
            ka, kb = np.random.choice(self.class_num, 2, replace=False)
            # Two distinct image ids, so anchor and positive are different files.
            kc, kd = np.random.choice(image_ids, 2, replace=False)

            batch_x[i] = self._load_image(ka, kc)
            batch_x[i + n_triplets] = self._load_image(ka, kd)
            batch_x[i + 2 * n_triplets] = self._load_image(kb, kd)

            batch_y[i] = ka
            batch_y[i + n_triplets] = ka
            batch_y[i + 2 * n_triplets] = kb
        batch_y = np_utils.to_categorical(batch_y, self.class_num)
        return batch_x, [batch_y, batch_z]

    def next(self):
        """Return the next batch (Python 2 style iterator protocol).

        # Returns
            The next batch.
        """
        # Only advancing the index generator is done under the lock;
        # image loading and transformation run outside it so they can overlap.
        with self.lock:
            index_array = next(self.index_generator)
        return self._get_batches_of_transformed_samples(index_array)

    def __getitem__(self, idx):
        """Return a batch; ``idx`` is ignored because triplets are sampled randomly."""
        if self.index_array is None:
            self._set_index_array()
        index_array = self.index_array[0: self.batch_size]
        return self._get_batches_of_transformed_samples(index_array)

    def _flow_index(self):
        """Yield the same leading slice of indices forever; only its length matters."""
        # Ensure self.batch_index is 0.
        self.reset()
        # index_array starts as None; build it before the first slice.
        if self.index_array is None:
            self._set_index_array()
        while True:
            yield self.index_array[0: self.batch_size]
  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 6
    评论
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值