keras 实现多任务学习

def deep_multi_model(feature_dim, cvr_label_dim, profit_label_dim):
    inputs = Input(shape=(feature_dim,))
    dense_1 = Dense(512, activation='relu')(inputs)
    dense_2 = Dense(384, activation='relu')(dense_1)
    dense_3 = Dense(256, activation='relu')(dense_2)
    drop_1 = Dropout(0.2)(dense_3)
    dense_4 = Dense(128, activation='relu')(drop_1)
    dense_5 = Dense(64, activation='relu')(dense_4)

    output_1 = Dense(32, activation='relu')(dense_5)
    output_cvr = Dense(cvr_label_dim, activation='softmax', name='output_cvr')(output_1)

    output_2 = Dense(16, activation='relu')(dense_5)
    output_profit = Dense(profit_label_dim, activation='softmax', name='output_profit')(output_2)

    # 模型有两个输出 output_cvr, output_profit
    model = Model(inputs=inputs, outputs=[output_cvr, output_profit])
    model.summary()

    # 模型有两个 loss, 都是 categorical_crossentropy
    # loss 的 key 需要和模型的 output 层的 name 保持一致
    model.compile(optimizer='adam',
              loss={'output_cvr': 'categorical_crossentropy', 'output_profit': 'categorical_crossentropy'},
              loss_weights={'output_cvr':1, 'output_profit': 0.3},
              metrics=[categorical_accuracy])
    
    return model


# 产生训练数据的生成器
# 模型只有一个 input 有两个 output,所以 yield 格式为如下
def generate_arrays(X_train, y_train_cvr_label, y_train_profit_label):
    while True:
        for x, y_cvr, y_profit in zip(X_train, y_train_cvr_label, y_train_profit_label):
            yield (x[np.newaxis, :], {'output_cvr': y_cvr[np.newaxis, :], 'output_profit': y_profit[np.newaxis, :]})


# fit_generator 进行 fit 训练
def train_multi(X_train, y_train_cvr_label, y_train_profit_label, X_test, y_test_cvr_label, y_test_profit_label):
    feature_dim = X_train.shape[1]
    cvr_label_dim = y_train_cvr_label.shape[1]
    profit_label_dim = y_train_profit_label.shape[1]
    
    model = deep_multi_model(feature_dim, cvr_label_dim, profit_label_dim)
    
    model.summary()
    early_stopping = EarlyStopping(monitor='val_loss', patience=15, verbose=0)
    
    
    model.fit_generator(generate_arrays(X_train, y_train_cvr_label, y_train_profit_label),
                        steps_per_epoch=1024, 
                        epochs=100, 
                        validation_data=generate_arrays(X_test, y_test_cvr_label, y_test_profit_label), 
                        validation_steps=1024, 
                        callbacks=[early_stopping])

    return model


  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值