本代码实现影像数据的批量训练:fit_generator() 将原始影像和标签数据打包成一个 tuple,然后再喂给模型训练。
import os
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import random
import PIL.Image as img
from skimage.transform import resize
from skimage.io import imread
from sklearn.model_selection import train_test_split
import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from keras.optimizers import SGD, Adam
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from keras.preprocessing import image
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
import tensorflow as tf
# NOTE(review): hard-coded local data root — should come from config/CLI args.
path = "G:/data/"
path_1 = path + "landset-1-3/"  # Landsat bands 1-3 images
path_2 = path + "landset-4-6/"  # Landsat bands 4-6 images
bounds = 6 # total number of spectral bands (channels) to load
classes = 2  # binary classification
batch_size = 8
epoch=100  # number of training epochs
def trainGenerator(batch_size=None, train_image_path=None, shape=None, bounds=None):
    """Infinite batch generator yielding (images, one-hot labels) for Keras.

    Pairs two Landsat band stacks of the same scene ("landset-1-3/" and
    "landset-4-6/", identically named files) into one multi-channel tensor
    per sample.

    Args:
        batch_size: number of samples per yielded batch.
        train_image_path: root directory containing the "landset-1-3/" and
            "landset-4-6/" sub-directories.
        shape: target height/width each image is resized to.
        bounds: total channel count; the first half comes from landset-1-3,
            the second half from landset-4-6.

    Yields:
        (images, labels): float32 array of shape
        (batch_size, shape, shape, bounds) and a (batch_size, 2) one-hot
        label array.
    """
    image_names = os.listdir(train_image_path + "landset-1-3/")
    half = bounds // 2  # channels contributed by each of the two band stacks
    while True:
        # Allocate fresh arrays per batch: the original reused one buffer,
        # so every yielded batch aliased (and was overwritten by) the next.
        images = np.zeros((batch_size, shape, shape, bounds), np.float32)
        labels = np.zeros((batch_size, 2))
        # Random contiguous window of file names for this batch.
        start = random.randint(0, len(image_names) - batch_size)
        for j in range(batch_size):
            name = image_names[start + j]
            img1 = imread(train_image_path + "landset-1-3/" + name)
            img2 = imread(train_image_path + "landset-4-6/" + name)
            # Use the caller-supplied target size/channels instead of the
            # original hard-coded (224, 224, 3).
            images[j, :, :, :half] = resize(
                img1, (shape, shape, half), mode='constant', preserve_range=True)
            images[j, :, :, half:bounds] = resize(
                img2, (shape, shape, half), mode='constant', preserve_range=True)
            # The file stem encodes the class: indices below 2500 are class 1.
            index = int(os.path.splitext(name)[0])
            labels[j] = (0, 1) if index < 2500 else (1, 0)
        yield (images, labels)
# Training data generator.
train_set = trainGenerator(batch_size=batch_size, train_image_path=path, shape=224, bounds=bounds)
# Validation data generator.
# NOTE(review): this draws from the same directory as the training set, so
# validation is not independent of training — a real train/val split is needed.
val_set = trainGenerator(batch_size=batch_size, train_image_path=path, shape=224, bounds=bounds)
train_num = len(os.listdir(path_1))  # total number of training images
# Number of batches per epoch.
steps_per_epoch = int(train_num / batch_size)
# NOTE(review): `model1` is not defined in this chunk — presumably built
# elsewhere; confirm. fit_generator is deprecated in newer Keras/TF;
# model.fit() accepts generators directly.
history = model1.fit_generator(train_set,
steps_per_epoch = steps_per_epoch,
epochs = epoch,
validation_data = val_set,
validation_steps = 5)
现在检查这段代码,发现写得太原始了:至少应该把数据集的路径单独配置后再访问,而不是直接硬编码本地文件路径。