# coding:utf-8
# Requires: tensorflow == 2.4
"""
Build, train, save and reload a small Keras binary classifier whose inputs
are described with ``tf.feature_column`` (one numeric column and two
categorical columns with embeddings, merged by ``layers.DenseFeatures``).

@author: liu
@File: tf_feature_col.py
@CreateTime: 2021/7/5
"""
import tensorflow as tf
from tensorflow.keras import layers
# Fixed vocabularies for the two categorical string features.
# NOTE(review): the "genre" key holds gender values ("female"/"male"); the key
# itself is left unchanged because it is also the model's input name.
genre_vocab_list = ["female", "male"]
level_vocab_list = ["low", "middle", "up"]

# Numeric feature "age".
age_col = tf.feature_column.numeric_column(key="age", default_value=0)

# Categorical features: vocabulary lookup, then a trainable 16-d embedding.
genre_col = tf.feature_column.categorical_column_with_vocabulary_list(
    key="genre", vocabulary_list=genre_vocab_list
)
genre_embedding = tf.feature_column.embedding_column(genre_col, dimension=16)

level_col = tf.feature_column.categorical_column_with_vocabulary_list(
    key="level", vocabulary_list=level_vocab_list
)
level_embedding = tf.feature_column.embedding_column(level_col, dimension=16)

# Columns consumed by DenseFeatures, and one symbolic Keras Input per raw
# feature key (each input is a single value per example).
feature_cols = [age_col, genre_embedding, level_embedding]
feature_inputs = {
    "age": layers.Input(shape=(1,), name="age", dtype=tf.float32),
    "genre": layers.Input(shape=(1,), name="genre", dtype=tf.string),
    "level": layers.Input(shape=(1,), name="level", dtype=tf.string),
}
# Turn the raw inputs into one dense tensor by concatenating the output of
# every column in feature_cols, then stack a small MLP on top.
features = layers.DenseFeatures(feature_cols)(feature_inputs)
print(features.shape, features)

# Two ReLU hidden layers with dropout in between.
h = layers.Dense(units=100, activation="relu")(features)
h = layers.Dropout(0.2)(h)
h = layers.Dense(units=100, activation="relu")(h)

# Single sigmoid output unit. NOTE: despite its name, `logits` holds
# probabilities in [0, 1], which is why the loss uses from_logits=False.
logits = layers.Dense(units=1, activation="sigmoid")(h)

model = tf.keras.Model(inputs=feature_inputs, outputs=logits)
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
    metrics=["accuracy"],
)
# Directory the trained model is written to and reloaded from.
SAVED_MODEL_DIR = "saved_model"

# Toy data set of four examples. Every tensor is shaped (batch, 1) to match the
# `layers.Input(shape=(1,))` declarations, and "age" is float32 to match its
# float32 Input, instead of relying on Keras to expand/cast inputs implicitly
# (which is fragile once the model has been saved and reloaded).
age_data = tf.constant([[12.0], [25.0], [30.0], [50.0]], dtype=tf.float32)
genre_data = tf.constant([["female"], ["female"], ["male"], ["male"]])
level_data = tf.constant([["low"], ["middle"], ["middle"], ["up"]])
model_inputs = {"age": age_data, "genre": genre_data, "level": level_data}

# Binary labels as float32 with shape (batch, 1), like the single-unit output.
y = tf.constant([[0.0], [1.0], [1.0], [1.0]], dtype=tf.float32)

model.fit(model_inputs, y, epochs=20, batch_size=2)
model.save(SAVED_MODEL_DIR)

# Load the model back from disk and predict on the same rows to check that
# the save/load round trip works.
model = tf.keras.models.load_model(SAVED_MODEL_DIR)
print(model.predict(model_inputs))