背景
最近在构建模型的时候需要用到用户的历史行为特征,这时候要和对应的特征共享embedding信息,这里就使用到了tf.feature_column.shared_embedding_column,但是网上针对这个api如何调用ckpt文件里的embedding信息并没有很详细的说明,所以自己整理了一下话不多说,看代码:
# 然后将上面的变量存入到ckpt_color(checkpoint文件)中
# 在shared_embedding_column中会用到保存的ckpt_color文件
ckpt_path = '' # 你指定的ckpt文件
color_vocab = ['a', 'b', 'c', 'd']
color_emb = [[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]]
color_emb = tf.Variable([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]], name='color_emb')
tf.train.saver()
color_data = {'color_clicked': [['a', 'b'], ['c', 'd']]
'color': [['c'], ['a']]}
color_column = tf.feature_column.categorical_column_with_vocabulary_list('color', vocabulary_list=vocab_list)
color_column2 = tf.feature_column.categorical_column_with_vocabulary_list('color_cliked', vocabulary_list=vocab_list)
color_emb_column = tf.feature_column.shared_embedding_columns([color_column, color_column2], 2,
ckpt_to_load_from=ckpt_path,
tensor_name_in_ckpt='color_emb')
with tf.Session() as sess:
sess.run(tf.feature_column.input_layer(color_data, color_emb_column))
# 输出结果和输入对应:
# 输入: [['c', ['a', 'b']], ['a', ['c', 'd']]]
# 输出: [[3., 4., 1.5, 2.5], [1., 2., 3.5, 4.5]]
说明:
tf.feature_column.categorical_column_with_vocabulary_list
- 很简单,使用vocabulary list的feature_column进行初始化
- 一般业务应用中,使用tf.feature_column.categorical_column_with_vocabulary_file 的比较多,因为对应的vocab可能一个lsit写不下,只好写在file里面,但是这一块没有什么区别
tf.feature_column.shared_embedding_columns
tf.compat.v1.feature_column.shared_embedding_columns( categorical_columns, dimension, combiner='mean', initializer=None, shared_embedding_collection_name=None, ckpt_to_load_from=None, tensor_name_in_ckpt=None, max_norm=None, trainable=True, use_safe_embedding_lookup=True )
-
categorical_columns
: 需要进行共享embedding信息的categorical columns 列表combiner
: 针对数组类型的特征域(多值特征),对应的embedding信息的处理方法,默认支持mean
,sqrtn
,sum
三种initializer
: embedding信息初始化方法,可以通过这个函数对embedding信息进行随机初始化ckpt_to_load_from
: 从外部导入embedding信息的时候使用到的保存embedding信息的ckpt文件tensor_name_in_ckpt
: 对应的embedding信息的变量名- ckpt文件中embedding信息要和categorical_feature_columns中vocab_list/vocab_file中的顺序是element-wise对应的,这样在对数据进行embedding匹配的时候才能正确对应