# coding: utf-8
from tensorflow.keras.models import Model
from tensorflow.keras import layers
from tensorflow.keras import Input
import numpy as np
from tensorflow.keras.utils import to_categorical
# Vocabulary sizes: the first two bound the token ids accepted by the two
# Embedding layers; the third is the number of classes of the softmax output.
text_vocabulary_size = 10000
question_vocabulary_size = 10000
answer_vocabulary_size = 500

# Text branch: int32 token ids -> 64-dimensional embeddings -> LSTM encoding.
# shape=(None,) leaves the sequence length unspecified (variable-length text).
text_input = Input(shape=(None,), dtype='int32', name='text')
embedded_text = layers.Embedding(
    input_dim=text_vocabulary_size,
    output_dim=64,
)(text_input)
# The LSTM collapses the embedded sequence into a single vector.
encoded_text = layers.LSTM(units=32)(embedded_text)

# Question branch: same structure as the text branch, with smaller layers.
question_input = Input(shape=(None,), dtype='int32', name='question')
embedded_question = layers.Embedding(
    input_dim=question_vocabulary_size,
    output_dim=32,
)(question_input)
encoded_question = layers.LSTM(units=16)(embedded_question)

# Join the encoded text and the encoded question along the last axis.
concatenated = layers.Concatenate(axis=-1)([encoded_text, encoded_question])

# Softmax classifier over the answer vocabulary.
answer = layers.Dense(
    units=answer_vocabulary_size,
    activation='softmax',
)(concatenated)

# Build the model from its two inputs and single output, then compile it.
model = Model(inputs=[text_input, question_input], outputs=answer)
model.compile(
    optimizer='rmsprop',
    loss='categorical_crossentropy',
    metrics=['acc'],
)
model.summary()
"""将数据输入到模型中进行训练"""
# Synthetic training set: random token ids for both inputs, random labels.
num_samples = 1000
max_length = 100
sample_shape = (num_samples, max_length)

# Token ids are drawn from [1, vocabulary_size), so id 0 is never generated.
text = np.random.randint(low=1, high=text_vocabulary_size, size=sample_shape)
question = np.random.randint(
    low=1, high=question_vocabulary_size, size=sample_shape
)

# Integer class labels in [0, answer_vocabulary_size), then one-hot encoded to
# match the categorical_crossentropy loss.
answer_ids = np.random.randint(0, answer_vocabulary_size, size=num_samples)
answers = to_categorical(answer_ids, num_classes=answer_vocabulary_size)

# Fit by passing the inputs as a list ordered like the model's inputs.
model.fit(x=[text, question], y=answers, epochs=10, batch_size=128)

# Alternative: fit by passing a dict keyed by the Input layer names.
# model.fit({'text': text, 'question': question}, answers, epochs=10, batch_size=128)
# Keras functional API: multi-input model (question-answering model)
# Latest recommended article published on 2024-08-04 23:00:28