tf2中feature_columns与keras model的结合使用

分为两种情况:

1、在较新版本的 tensorflow 下(例如 tf>=2.4)

2、在 tensorflow==2.0.x(例如 2.0.4)下

在 tensorflow>=2.4 的情况下(使用函数式 API:layers.Input + DenseFeatures + tf.keras.Model(inputs, outputs)):


import tensorflow as tf
from tensorflow.keras import layers



def train_save_model(export_dir="saved_model", epochs=20):
    """Build, train and export a functional Keras model fed by feature columns.

    The model takes a dict of named inputs ("age", "genre", "level").
    ``layers.DenseFeatures`` turns that dict into one dense vector, which is
    then fed through a small MLP with dropout and a single sigmoid output.

    Args:
        export_dir: Directory to write the trained model to, in TensorFlow
            SavedModel format. Defaults to "saved_model".
        epochs: Number of training epochs. Defaults to 20.

    Returns:
        The trained ``tf.keras.Model``.
    """
    genre_vocab_list = ["female", "male"]
    level_vocab_list = ["low", "middle", "up"]

    feature_cols = []
    feature_inputs = {}

    # 1. Numeric feature, passed through as a float column.
    age_col = tf.feature_column.numeric_column(key="age", default_value=0)
    feature_cols.append(age_col)
    feature_inputs["age"] = layers.Input(shape=(1,), name="age", dtype=tf.float32)

    # 2. Categorical string feature -> vocabulary lookup -> 16-d embedding.
    genre_col = tf.feature_column.categorical_column_with_vocabulary_list(
        key="genre", vocabulary_list=genre_vocab_list)
    feature_cols.append(tf.feature_column.embedding_column(genre_col, dimension=16))
    feature_inputs["genre"] = layers.Input(shape=(1,), name="genre", dtype=tf.string)

    # 3. Same treatment for "level".
    level_col = tf.feature_column.categorical_column_with_vocabulary_list(
        key="level", vocabulary_list=level_vocab_list)
    feature_cols.append(tf.feature_column.embedding_column(level_col, dimension=16))
    feature_inputs["level"] = layers.Input(shape=(1,), name="level", dtype=tf.string)

    # DenseFeatures looks each column up by its key, so the keys of
    # `feature_inputs` must match the feature-column keys above.
    features = layers.DenseFeatures(feature_cols)(feature_inputs)
    h = layers.Dense(100, activation="relu")(features)
    h = layers.Dropout(rate=0.2)(h)
    h = layers.Dense(100, activation="relu")(h)
    # The sigmoid output is a probability, not a logit, hence from_logits=False.
    probs = layers.Dense(1, activation="sigmoid")(h)

    model = tf.keras.Model(feature_inputs, probs)
    model.compile("adam",
                  loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
                  metrics=["accuracy"])

    # Toy data. Ages are float32 to match the "age" Input dtype, and every
    # feature is shaped (batch, 1) to match the Input shapes declared above.
    age_data = tf.constant([[12.0], [25.0], [30.0], [50.0]], dtype=tf.float32)
    genre_data = tf.constant([["female"], ["female"], ["male"], ["male"]])
    level_data = tf.constant([["low"], ["middle"], ["middle"], ["up"]])
    data = {"age": age_data, "genre": genre_data, "level": level_data}
    y = tf.constant([0, 1, 1, 1])

    model.fit(data, y, epochs=epochs, batch_size=2)

    model.save(export_dir, save_format="tf")
    # summary() prints by itself and returns None; wrapping it in print()
    # would emit a stray "None".
    model.summary()
    print(model.predict(data))
    return model



def predict(export_dir="saved_model"):
    """Reload the model exported by ``train_save_model`` and run it on toy data.

    Args:
        export_dir: Directory containing the SavedModel. Defaults to
            "saved_model", the default export location of ``train_save_model``.

    Returns:
        The array returned by ``model.predict``. It is also printed.
    """
    model = tf.keras.models.load_model(export_dir)

    # Feature dict keyed like the model's named inputs. Ages are float32 to
    # match the "age" Input dtype, and each feature is shaped (batch, 1).
    data = {
        "age": tf.constant([[12.0], [25.0], [30.0], [50.0]], dtype=tf.float32),
        "genre": tf.constant([["female"], ["female"], ["male"], ["male"]]),
        "level": tf.constant([["low"], ["middle"], ["middle"], ["up"]]),
    }

    predictions = model.predict(data)
    print(predictions)
    return predictions

在 tensorflow==2.0.4 的情况下(使用子类化模型:继承 tf.keras.Model 并在 call 中调用 DenseFeatures):

import tensorflow as tf
from tensorflow.keras import layers


class FeatureColModel(tf.keras.Model):
    """Subclassed Keras model whose input layer is built from feature columns.

    ``call`` expects ``inputs`` to be a dict keyed by the feature-column keys
    built in ``build_feature_columns``: "age", "genre" and "level".
    ``DenseFeatures`` maps that dict to one dense vector, which is fed through
    two ReLU layers with dropout in between and a single sigmoid output unit.
    """

    def __init__(self, embedding_dim=16, hidden_units=100, dropout_rate=0.2):
        """Create the layers.

        Args:
            embedding_dim: Size of each categorical embedding. Defaults to 16.
            hidden_units: Width of each hidden Dense layer. Defaults to 100.
            dropout_rate: Dropout rate between the hidden layers. Defaults to 0.2.
        """
        super(FeatureColModel, self).__init__()
        feature_cols = self.build_feature_columns(embedding_dim)
        self.feature_layer = tf.keras.layers.DenseFeatures(feature_cols)
        self.dense_1 = layers.Dense(hidden_units, activation="relu")
        self.dense_2 = layers.Dense(hidden_units, activation="relu")
        self.dropout = layers.Dropout(dropout_rate)
        self.dense_3 = layers.Dense(1, activation="sigmoid")

    def build_feature_columns(self, embedding_dim=16):
        """Return the feature columns consumed by ``DenseFeatures``.

        Args:
            embedding_dim: Size of each categorical embedding. Defaults to 16.

        Returns:
            A list with a numeric "age" column and embedding columns for the
            vocabulary-list categorical columns "genre" and "level".
        """
        genre_vocab_list = ["female", "male"]
        level_vocab_list = ["low", "middle", "up"]

        age_col = tf.feature_column.numeric_column(key="age", default_value=0)
        genre_col = tf.feature_column.categorical_column_with_vocabulary_list(
            key="genre", vocabulary_list=genre_vocab_list)
        level_col = tf.feature_column.categorical_column_with_vocabulary_list(
            key="level", vocabulary_list=level_vocab_list)

        # A subclassed model needs no layers.Input placeholders: DenseFeatures
        # reads the features directly from the dict passed to call().
        return [
            age_col,
            tf.feature_column.embedding_column(genre_col, dimension=embedding_dim),
            tf.feature_column.embedding_column(level_col, dimension=embedding_dim),
        ]

    def call(self, inputs, training=None, mask=None):
        """Run the forward pass and return sigmoid probabilities."""
        features = self.feature_layer(inputs)
        h = self.dense_1(features)
        # Pass `training` explicitly so dropout is active during fit and
        # disabled at inference, without relying on implicit learning-phase
        # propagation.
        h = self.dropout(h, training=training)
        h = self.dense_2(h)
        return self.dense_3(h)


def train_custom_model(export_dir="custom_saved_model", epochs=20):
    """Train a ``FeatureColModel`` on toy data and export it as a SavedModel.

    Args:
        export_dir: Directory the model is saved to. Defaults to
            "custom_saved_model".
        epochs: Number of training epochs. Defaults to 20.

    Returns:
        The trained ``FeatureColModel``.
    """
    # NOTE: a reloaded SavedModel of a subclassed model can only be called
    # with input signatures traced before saving. Keep these dtypes/shapes
    # (int32 ages, 1-D string features) identical to those used at predict
    # time in predict_custom_model().
    age_data = tf.constant([[12], [25], [30], [50]])
    genre_data = tf.constant(["female", "female", "male", "male"])
    level_data = tf.constant(["low", "middle", "middle", "up"])
    data = {"age": age_data, "genre": genre_data, "level": level_data}

    y = tf.constant([0, 1, 1, 1])

    model = FeatureColModel()
    model.compile("adam", loss=tf.keras.losses.BinaryCrossentropy(), metrics=["accuracy"])
    model.fit(data, y, epochs=epochs, batch_size=2)

    print(model.evaluate(data, y))
    print(model.predict(data))
    model.save(export_dir)
    return model


def predict_custom_model(export_dir="custom_saved_model"):
    """Reload the SavedModel written by ``train_custom_model`` and predict.

    Args:
        export_dir: Directory containing the SavedModel. Defaults to
            "custom_saved_model".

    Returns:
        The array returned by ``model.predict``. It is also printed.
    """
    # NOTE: the restored subclassed model only accepts the input signatures
    # traced before saving, so these tensors deliberately mirror the
    # dtypes/shapes used in train_custom_model().
    age_data = tf.constant([[12], [25], [30], [50]])
    genre_data = tf.constant(["female", "female", "male", "male"])
    level_data = tf.constant(["low", "middle", "middle", "up"])

    model = tf.keras.models.load_model(export_dir)
    predictions = model.predict({"age": age_data, "genre": genre_data, "level": level_data})
    print(predictions)
    return predictions

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值