简介
本文主要介绍使用Keras过程中一些使用频率较高的代码片段。
从表格文件（DataFrame）快速生成数据集
from keras.preprocessing.image import ImageDataGenerator
# Build image generators for training / validation / test from DataFrames.
# NOTE(review): df_train and df_test are not defined in this snippet; they are
# assumed to be pandas DataFrames whose 'file_id' column holds image file names
# relative to `directory`, and (df_train only) an 'accent' label column.
img_size = (224, 224)
batch_size = 32
# Both generators that read df_train must use the same split fraction so that
# subset="training" and subset="validation" select complementary parts of it.
validation_split = 0.2

# Augmentation (shear / horizontal shift) is meant for training images only.
train_gen = ImageDataGenerator(
    rescale=1 / 255.0,
    validation_split=validation_split,
    horizontal_flip=False,
    shear_range=0.2,
    width_shift_range=0.1,
)
# Validation images are only rescaled, so validation metrics are measured on
# unaugmented data; the identical validation_split keeps the same partition.
valid_gen = ImageDataGenerator(rescale=1 / 255.0, validation_split=validation_split)
test_gen = ImageDataGenerator(rescale=1 / 255.0)

train_generator = train_gen.flow_from_dataframe(
    dataframe=df_train,
    directory="data/train",
    x_col="file_id",
    y_col="accent",
    batch_size=batch_size,
    class_mode="categorical",
    target_size=img_size,
    subset="training",
)
valid_generator = valid_gen.flow_from_dataframe(
    dataframe=df_train,
    directory="data/train",
    x_col="file_id",
    y_col="accent",
    batch_size=batch_size,
    class_mode="categorical",
    target_size=img_size,
    subset="validation",
)
# class_mode=None yields images only (no labels). shuffle=False with
# batch_size=1 keeps predictions in the same order as the rows of df_test.
test_generator = test_gen.flow_from_dataframe(
    dataframe=df_test,
    directory="data/test",
    x_col="file_id",
    target_size=img_size,
    batch_size=1,
    shuffle=False,
    class_mode=None,
)
使用预训练模型
from keras.applications.densenet import DenseNet121
from keras.models import Model
from keras.layers import GlobalAveragePooling2D, Input, Dropout, Dense, BatchNormalization
from keras.optimizers import Adam
def build_densenet(input_shape=(224, 224, 3), n_classes=3, dropout_rate=0.5, learning_rate=3e-4):
    """Build and compile a DenseNet121 classifier on an ImageNet-pretrained base.

    The DenseNet121 convolutional base (loaded without its top classification
    layers) is followed by global average pooling, dropout and a softmax
    ``Dense`` layer with ``n_classes`` outputs.

    Args:
        input_shape: Shape of a single input image, ``(height, width, channels)``.
        n_classes: Number of output units of the softmax classification layer.
        dropout_rate: Dropout rate applied before the classification layer.
        learning_rate: Learning rate of the Adam optimizer.

    Returns:
        A ``keras.models.Model`` compiled with categorical cross-entropy loss
        (expects one-hot targets), the Adam optimizer and an accuracy metric.
    """
    input_layer = Input(shape=input_shape)
    # include_top=False drops DenseNet121's ImageNet classification head so a
    # new head with n_classes outputs can be attached below.
    backbone = DenseNet121(include_top=False, weights="imagenet", input_tensor=input_layer)
    x = GlobalAveragePooling2D()(backbone.output)
    x = Dropout(dropout_rate)(x)
    outputs = Dense(n_classes, activation="softmax")(x)
    model = Model(input_layer, outputs)
    # `learning_rate` is the current name of the optimizer argument; `lr` is deprecated.
    model.compile(
        loss="categorical_crossentropy",
        optimizer=Adam(learning_rate=learning_rate),
        metrics=["accuracy"],
    )
    return model
# Instantiate the model with its default configuration: (224, 224, 3) inputs, 3 classes.
densenet = build_densenet(input_shape=(224, 224, 3), n_classes=3)