子类API实现wide&deep模型——网络层构建
以下代码展示如何用子类API实现wide&deep模型的网络层的构建,数据是使用的sklearn中的加州房价预测数据集。wide以及deep模型使用同样的输出。
代码展示:
class WideDeepModel(keras.models.Model):
def __init__(self):
super(WideDeepModel, self).__init__()
#定义模型层次
self.hidden1_layer = keras.layers.Dense(30, activation= 'relu')
self.hidden2_layer = keras.layers.Dense(30, activation= 'relu')
self.output_layer = keras.layers.Dense(1)
def call(self, input):
#完成模型的正向计算
hidden1 = self.hidden1_layer(input)
hidden2 = self.hidden2_layer(hidden1)
concat = keras.layers.concatenate([input, hidden2])
output = self.output_layer(concat)
return output
#第一种写法:
model = WideDeepModel()
model.build(input_shape = (None, 8))
"""
#第二种写法
model = kears.model.Sequentail(
WideDeepModel(),
)
"""