Tensorflow2.0 使用Keras 迭代器 加载图像分割训练集

8 篇文章 2 订阅
4 篇文章 0 订阅

当遇到大数据时,无法将数据全部加载进内存,需要用到分批次加载,网上的方法很多都是关于分类数据集,记录一下分割数据集使用迭代器进行数据加载的方式。
主要从keras.utils.Sequence 继承后定义一个数据加载器 DataGenerator。
注:本文的代码只展现了关键部分,不是完整代码

定义数据生成器

class DataGenerator(keras.utils.Sequence):

    def __init__(self, data_img, data_mask, batch_size=1, shuffle=True):
        self.batch_size = batch_size
        self.data_img = data_img
        self.data_mask = data_mask
        self.indexes = np.arange(len(self.data_img))
        self.shuffle = shuffle

    def __len__(self):
        # 计算每一个epoch的迭代次数
        return math.ceil(len(self.data_img) / float(self.batch_size))

    def __getitem__(self, index):
        # 生成每个batch数据,这里就根据自己对数据的读取方式进行发挥了
        # 生成batch_size个索引
        batch_indexs = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        # 根据索引获取datas集合中的数据
        batch_data_img = [self.data_img[k] for k in batch_indexs]
        batch_data_mask = [self.data_mask[k] for k in batch_indexs]

        # 生成数据
        X, y = self.data_generation(batch_data_img, batch_data_mask)

        return X, y

    def on_epoch_end(self):
        # 在每一次epoch结束是否需要进行一次随机,重新随机一下index
        if self.shuffle == True:
            np.random.shuffle(self.indexes)

    def data_generation(self, batch_data_img, batch_data_mask):
        images = []
        masks = []

        # 生成数据
        for data_img, data_mask in zip(batch_data_img, batch_data_mask):
            # x_train数据
            image = cv2.imread(data_img,cv2.IMREAD_COLOR)
            image = cv2.resize(image,(256,256))

            image = list(image)
            images.append(image)
            # y_train数据
            mask = cv2.imread(data_mask,cv2.IMREAD_GRAYSCALE)
            mask = cv2.resize(mask, (256,256))
            mask = mask / 255.0
            mask = list(mask)
            masks.append(mask)

        return np.array(images), np.array(masks)


# 读取样本名称,然后根据样本名称去读取数据


train_img = sorted(glob.glob('./trainnsmc/image/*.png'))
train_mask = sorted(glob.glob('./trainnsmc/label/*.png'))
# 数据生成器
training_generator = DataGenerator(train_img, train_mask,batch_size=8)

建立Unet模型,编译模型进行训练

model = unet()
#编译模型
from keras_unet_collection import losses
model.compile(optimizer=tf.keras.optimizers.Adam(lr), loss=losses.dice,
              metrics=[ 'acc',losses.dice_coef])

model.fit(training_generator, epochs=50,  max_queue_size=10)
  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值