在使用Keras的时候,因为需要考虑到效率问题,需要修改fit_generator来适应多输出
# create model
model = Model(inputs=x_inp, outputs=[main_pred, aux_pred])
# complie model
model.compile(
optimizer=optimizers.Adam(lr=learning_rate),
loss={"main": weighted_binary_crossentropy(weights), "auxiliary":weighted_binary_crossentropy(weights)},
loss_weights={"main": 0.5, "auxiliary": 0.5},
metrics=[metrics.binary_accuracy],
)
# Train model
model.fit_generator(
train_gen, epochs=num_epochs, verbose=0, shuffle=True
)
看Keras官方文档:
generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either
- a tuple (inputs, targets)
- a tuple (inputs, targets, sample_weights).
Keras设计多输出(多任务)使用fit_generator的步骤如下:
根据官方文档,定义一个generator或者一个class继承Sequence
class Batch_generator(Sequence):
"""
用于产生batch_1, batch_2(记住是numpy.array格式转换)
"""
y_batch = {'main':batch_1,'auxiliary':batch_2}
return X_batch, y_batch
# or in another way
def batch_generator():
"""
用于产生batch_1, batch_2(记住是numpy.array格式转换)
"""
yield X_batch, {'main': batch_1,'auxiliary':batch_2}
重要的事情说三遍(亲自采坑,搜了一大圈才发现滴):
如果是多输出(多任务)的时候,这里的target是字典类型
如果是多输出(多任务)的时候,这里的target是字典类型
如果是多输出(多任务)的时候,这里的target是字典类型
Reference:
[1] How to use fit_generator with multiple outputs in Keras
[2] keras:怎样使用 fit_generator 来训练多个不同类型的输出