Keras-常用代码

简介

本文主要介绍一些使用Keras过程中使用频率较高的常用代码段。

表格文件快速生成数据集

from keras.preprocessing.image import ImageDataGenerator

train_gen = ImageDataGenerator(rescale=1/255., validation_split=0.2, horizontal_flip=False, shear_range=0.2, width_shift_range=0.1)
test_gen = ImageDataGenerator(rescale=1/255.) 
img_size = (224, 224)
batch_size = 32
train_generator = train_gen.flow_from_dataframe(dataframe=df_train,
                                               directory='data/train',
                                               x_col='file_id',
                                               y_col='accent',
                                               batch_size=batch_size,
                                               class_mode='categorical',
                                               target_size=img_size, 
                                               subset='training')
valid_generator = train_gen.flow_from_dataframe(dataframe=df_train,
                                                  directory="data/train",
                                                  x_col="file_id",
                                                  y_col="accent",
                                                  batch_size=batch_size,
                                                  class_mode="categorical",    
                                                  target_size=img_size,
                                                  subset='validation')
test_generator = test_gen.flow_from_dataframe(dataframe=df_test,
                                                  directory = "data/test",
                                                  x_col="file_id",
                                                  target_size=img_size,
                                                  batch_size=1,
                                                  shuffle=False,
                                                  class_mode=None)

预训练模型使用

from keras.applications.densenet import DenseNet121
from keras.models import Model
from keras.layers import GlobalAveragePooling2D, Input, Dropout, Dense, BatchNormalization
from keras.optimizers import Adam

def build_densenet(input_shape=(224, 224, 3), n_classes=3):
    input_layer = Input(shape=input_shape)
    densenet121 = DenseNet121(include_top=False, weights='imagenet', input_tensor=input_layer)
    x = GlobalAveragePooling2D()(densenet121.output)
    x = Dropout(0.5)(x)
    x = Dense(n_classes, activation='softmax')(x)
    
    model = Model(input_layer, x)
    model.compile(loss='categorical_crossentropy', optimizer=Adam(lr=3e-4), metrics=['accuracy'])
    return model

densenet = build_densenet()
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

周先森爱吃素

你的鼓励是我坚持创作的不懈动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值