How to set up fit_generator for a multi-output (multi-task) Keras model

When working with Keras, efficiency concerns often push you toward generator-based training, so fit_generator has to be adapted for a model with multiple outputs.

# create the model; the two output layers must be *named* "main" and
# "auxiliary" so that the dict-valued loss, loss_weights and targets
# below can be matched to them
model = Model(inputs=x_inp, outputs=[main_pred, aux_pred])
# compile the model with one (weighted) loss per output
model.compile(
    optimizer=optimizers.Adam(lr=learning_rate),
    loss={"main": weighted_binary_crossentropy(weights),
          "auxiliary": weighted_binary_crossentropy(weights)},
    loss_weights={"main": 0.5, "auxiliary": 0.5},
    metrics=[metrics.binary_accuracy],
)
# train the model; note that shuffle only has an effect when train_gen
# is a Sequence, and a plain generator additionally needs steps_per_epoch
model.fit_generator(
    train_gen, epochs=num_epochs, verbose=0, shuffle=True
)
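The snippet above calls weighted_binary_crossentropy(weights) but never shows its definition. Below is a plain-NumPy reference sketch of one plausible version, assuming weights = (w_negative, w_positive); this is an illustration only, and an actual Keras loss would compute the same thing with backend ops (K.clip, K.log, K.mean) on tensors:

```python
import numpy as np

def weighted_binary_crossentropy(weights):
    """Class-weighted binary cross-entropy, NumPy reference sketch.

    Assumes weights = (w_negative, w_positive); with (1.0, 1.0) this
    reduces to ordinary binary cross-entropy.
    """
    w_neg, w_pos = weights

    def loss(y_true, y_pred, eps=1e-7):
        # clip predictions away from 0/1 to keep log() finite
        y_pred = np.clip(y_pred, eps, 1.0 - eps)
        bce = -(w_pos * y_true * np.log(y_pred)
                + w_neg * (1.0 - y_true) * np.log(1.0 - y_pred))
        return bce.mean(axis=-1)

    return loss

loss_fn = weighted_binary_crossentropy((1.0, 1.0))
print(loss_fn(np.array([[1.0, 0.0]]), np.array([[0.9, 0.1]])))  # → [0.10536052]
```

Raising w_pos above w_neg penalizes missed positives more heavily, which is the usual motivation for such a loss on imbalanced labels.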

From the Keras official documentation:
generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either

  • a tuple (inputs, targets)
  • a tuple (inputs, targets, sample_weights).
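To make that contract concrete, here is a toy generator (the batch size and shapes are made up for illustration) that yields the (inputs, targets) tuple the docs describe, with targets as a dict for the two-output case:

```python
import numpy as np

def toy_generator(batch_size=4):
    """Yields (inputs, targets) tuples as the Keras docs require.

    For a single-output model, targets would be one array; for multiple
    outputs it becomes a dict with one entry per named output layer.
    """
    while True:
        X_batch = np.random.rand(batch_size, 8)
        y_main = np.random.randint(0, 2, size=(batch_size, 1))
        y_aux = np.random.randint(0, 2, size=(batch_size, 1))
        yield X_batch, {"main": y_main, "auxiliary": y_aux}

X, y = next(toy_generator())
print(sorted(y.keys()))  # → ['auxiliary', 'main']
```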

The steps for using fit_generator with a multi-output (multi-task) Keras model are as follows:

Following the official docs, define either a class that inherits from Sequence or a plain generator:

class Batch_generator(Sequence):
    """
    Produces batch_1 and batch_2 (remember to convert them to numpy.array)
    """
    def __len__(self):
        return num_batches  # number of batches per epoch

    def __getitem__(self, idx):
        y_batch = {'main': batch_1, 'auxiliary': batch_2}
        return X_batch, y_batch

# or in another way

def batch_generator():
    """
    Produces batch_1 and batch_2 (remember to convert them to numpy.array)
    """
    while True:  # fit_generator expects a plain generator to loop forever
        yield X_batch, {'main': batch_1, 'auxiliary': batch_2}

I'll say it three times because it matters (I tripped over this myself and only found the answer after a long search):
When the model has multiple outputs (multi-task), the targets here must be a dict
When the model has multiple outputs (multi-task), the targets here must be a dict
When the model has multiple outputs (multi-task), the targets here must be a dict
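The steps above can be sketched as one self-contained Sequence subclass. The class name TwoHeadBatchGenerator and the stored arrays are hypothetical, and the try/except stub is only there so the sketch runs even where Keras is not installed:

```python
import numpy as np

try:
    from keras.utils import Sequence
except ImportError:  # stand-in base class when Keras is unavailable
    class Sequence:
        pass

class TwoHeadBatchGenerator(Sequence):
    """Sequence yielding dict-valued targets for a two-output model.

    X, y_main and y_aux are full numpy datasets; keys of the target
    dict must match the names of the model's output layers.
    """
    def __init__(self, X, y_main, y_aux, batch_size=32):
        self.X, self.y_main, self.y_aux = X, y_main, y_aux
        self.batch_size = batch_size

    def __len__(self):
        # number of batches per epoch (last batch may be smaller)
        return int(np.ceil(len(self.X) / self.batch_size))

    def __getitem__(self, idx):
        sl = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        return self.X[sl], {"main": self.y_main[sl],
                            "auxiliary": self.y_aux[sl]}
```

An instance of this class can be passed directly as the first argument of fit_generator; because it defines __len__, no steps_per_epoch is needed.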


