Tensorflow2.0 keras 子类式多输入多输出

在这里插入图片描述

1.关键代码

在定义好输入层、输出层后使用类 配置inputs outputs参数(数组)

网络模型搭建

class WideDeepModel(tf.keras.models.Model):
    def __init__(self):
        super(WideDeepModel, self).__init__()
        self.hidden1_layer = tf.keras.layers.Dense(30, activation='relu')
        self.hidden2_layer = tf.keras.layers.Dense(30, activation='relu')
        self.output_layer1 = tf.keras.layers.Dense(1)
        self.output_layer2 = tf.keras.layers.Dense(1)

    def call(self, inputs, training=None, mask=None):
        """完成模型的正向计算"""

        input_wide = inputs[0]  # 输入1
        input_deep = inputs[1]  # 输入2

        hidden1 = self.hidden1_layer(input_deep)
        hidden2 = self.hidden2_layer(hidden1)
        concat = tf.keras.layers.concatenate([input_wide, hidden2])
        output1 = self.output_layer1(concat) # 输出1
        output2 = self.output_layer2(hidden2) # 输出2

        return [output1, output2] # 输出组合
# 构建网络
model = WideDeepModel()

model.build(input_shape=[(None, 5), (None, 6)])
print(model.layers)
model.summary()

完整代码:

import pprint
import sys

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn
import tensorflow as tf
from tensorflow import keras

print(tf.__version__)
print(sys.version_info)

for module in mpl, np, pd, sklearn, keras, tf:
    print(module.__name__, module.__version__)

from sklearn.datasets import fetch_california_housing

# 1.加载数据集 波士顿房价预测
housing = fetch_california_housing()
print(housing.DESCR)
print(housing.data.shape)
print(housing.target.shape)

pprint.pprint(housing.data[:5])
pprint.pprint(housing.target[:5])

from sklearn.model_selection import train_test_split

# 2.拆分数据集
#   训练集与测试集拆分
x_train_all, x_test, y_train_all, y_test = train_test_split(housing.data,
                                                            housing.target,
                                                            random_state=7,
                                                            test_size=0.20)
# 训练集与验证集的拆分
x_train, x_valid, y_train, y_valid = train_test_split(
    x_train_all, y_train_all, random_state=11, test_size=0.20)

print(x_train.shape, y_train.shape)
print(x_valid.shape, y_valid.shape)
print(x_test.shape, y_test.shape)

from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()

# 3、数据预处理 数据集的归一化
x_train_scaled = scaler.fit_transform(x_train)
x_valid_scaled = scaler.transform(x_valid)
x_test_scaled = scaler.transform(x_test)


# 4、网络模型的搭建
# 子类API
class WideDeepModel(tf.keras.models.Model):
    def __init__(self):
        super(WideDeepModel, self).__init__()
        self.hidden1_layer = tf.keras.layers.Dense(30, activation='relu')
        self.hidden2_layer = tf.keras.layers.Dense(30, activation='relu')
        self.output_layer1 = tf.keras.layers.Dense(1)
        self.output_layer2 = tf.keras.layers.Dense(1)

    def call(self, inputs, training=None, mask=None):
        """完成模型的正向计算"""

        input_wide = inputs[0]  # 输入1
        input_deep = inputs[1]  # 输入2

        hidden1 = self.hidden1_layer(input_deep)
        hidden2 = self.hidden2_layer(hidden1)
        concat = tf.keras.layers.concatenate([input_wide, hidden2])
        output1 = self.output_layer1(concat)
        output2 = self.output_layer2(hidden2)

        return [output1, output2]


# 构建网络
model = WideDeepModel()

model.build(input_shape=[(None, 5), (None, 6)])
print(model.layers)
model.summary()

# 5、模型的编译  设置损失函数 优化器
model.compile(loss='mean_squared_error',
              optimizer='adam')

# 6、设置回调函数
callbacks = [tf.keras.callbacks.EarlyStopping(patience=5, min_delta=1e-3)]

# 7、训练网络
x_train_scaled_wide = x_train_scaled[:, :5]
x_train_scaled_deep = x_train_scaled[:, 2:]

x_valid_scaled_wide = x_valid_scaled[:, :5]
x_valid_scaled_deep = x_valid_scaled[:, 2:]

x_test_scaled_wide = x_test_scaled[:, :5]
x_test_scaled_deep = x_test_scaled[:, 2:]

history = model.fit([x_train_scaled_wide, x_train_scaled_deep],
                    [y_train, y_train],
                    validation_data=(
                        [x_valid_scaled_wide, x_valid_scaled_deep],
                        [y_valid, y_valid]),
                    epochs=20,
                    callbacks=callbacks)


# 8、绘制训练过程数据
def plot_learning_curves(hst):
    pd.DataFrame(hst.history).plot()
    plt.grid(True)
    plt.gca().set_ylim(0, 1)
    plt.show()


plot_learning_curves(history)

# 9.验证数据
model.evaluate([x_test_scaled_wide, x_test_scaled_deep], [y_test, y_test])

  • 7
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 8
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

廷益--飞鸟

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值