keras使用fit_generator批量训练二分类模型

代码实现了影像数据的批量训练:fit_generator() 会将原始影像和标签数据打包成一个 tuple,然后再喂给模型训练。

import os
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import random
import PIL.Image as img
from skimage.transform import resize
from skimage.io import imread
from sklearn.model_selection import train_test_split
import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from keras.optimizers import SGD, Adam
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from keras.preprocessing import image
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
import tensorflow as tf


# --- Training configuration (module-level constants) ---
path = "G:/data/"  # dataset root directory
path_1 = path + "landset-1-3/"  # folder holding band 1-3 images
path_2 = path + "landset-4-6/"  # folder holding band 4-6 images
bounds = 6  # total number of spectral bands to load per sample (3 + 3)
classes = 2  # binary classification
batch_size = 8  # samples per training batch
epoch=100  # number of training epochs


def trainGenerator(batch_size=None, train_image_path=None, shape=None, bounds=None):
    """Infinite generator yielding (images, labels) batches for Keras fit_generator.

    Each sample is assembled from two identically named files: one under
    ``landset-1-3/`` (bands 1-3) and one under ``landset-4-6/`` (bands 4-6).
    The two 3-channel images are resized and stacked along the channel axis
    into a ``(shape, shape, bounds)`` array.  The class label is derived from
    the numeric file stem: an index below 2500 is one-hot ``(0, 1)``,
    otherwise ``(1, 0)``.

    Args:
        batch_size: number of samples per yielded batch.
        train_image_path: root directory containing the two band folders.
        shape: spatial side length images are resized to (e.g. 224).
        bounds: total channel count of the stacked image (expected 6).

    Yields:
        ``(img_batch, label_batch)`` — freshly allocated float32 arrays of
        shapes ``(batch_size, shape, shape, bounds)`` and ``(batch_size, 2)``.
    """
    dir_1 = os.path.join(train_image_path, "landset-1-3")
    dir_2 = os.path.join(train_image_path, "landset-4-6")
    imageList = os.listdir(dir_1)
    while True:
        # Allocate fresh buffers for every batch.  The original yielded the
        # same mutable arrays repeatedly, which is unsafe: Keras may still be
        # reading a queued batch while the next iteration overwrites it.
        img_batch = np.zeros((batch_size, shape, shape, bounds), np.float32)
        label_batch = np.zeros((batch_size, 2), np.float32)
        # Pick a random contiguous window of batch_size files.
        start = random.randint(0, len(imageList) - batch_size)
        for j in range(batch_size):
            name = imageList[start + j]
            img1 = imread(os.path.join(dir_1, name))
            img2 = imread(os.path.join(dir_2, name))
            # Bug fix: the original hard-coded (224, 224, 3) here and ignored
            # the `shape` parameter entirely.
            img_batch[j, :, :, :3] = resize(img1, (shape, shape, 3),
                                            mode='constant', preserve_range=True)
            img_batch[j, :, :, 3:6] = resize(img2, (shape, shape, 3),
                                             mode='constant', preserve_range=True)
            # The numeric file stem encodes the class boundary at 2500.
            if int(os.path.splitext(name)[0]) < 2500:
                label_batch[j] = (0, 1)
            else:
                label_batch[j] = (1, 0)
        yield (img_batch, label_batch)


# Training data generator (infinite; see trainGenerator).
train_set = trainGenerator(batch_size=batch_size, train_image_path=path, shape=224, bounds=bounds)
# Validation data generator.
# NOTE(review): drawn from the same directory as the training data, so this
# does not measure generalization on held-out samples — confirm intent.
val_set = trainGenerator(batch_size=batch_size, train_image_path=path, shape=224, bounds=bounds)
train_num = len(os.listdir(path_1))  # total number of training images on disk
# How many batches constitute one epoch.
steps_per_epoch = int(train_num / batch_size)

# NOTE(review): `model1` is never defined in this file — a compiled Keras
# model must be built elsewhere before this line can run.
history = model1.fit_generator(train_set,
                    steps_per_epoch = steps_per_epoch,
                    epochs = epoch,
                    validation_data = val_set,
                    validation_steps = 5)

现在看来这段代码写得实在太原始了,至少应该把数据集路径单独配置出来,而不是直接在本地文件路径里硬编码访问。

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值