之前在写tf模型的时候,对于category类型的特征,经常是预处理成id,然后才输入到模型中去,category->id的映射通常是独立与tf代码的。tf.lookup模块中提供了使用tensorflow原生api将category特征映射为id的方法,本文将介绍这些方法。
tf.lookup模块中有两类方法:
- Initializer:负责构建 category -> id 映射表
- tf.lookup.KeyValueTensorInitializer: 通过显式指定 category->id构建映射表
- tf.lookup.TextFileInitializer:从文件读取数据构建映射表
- Table:复制执行 category -> id 的映射
- tf.lookup.StaticHashTable:通过给定的映射表,进行映射,如果没有找到,则返回默认值
- tf.lookup.StaticVocabularyTable:通过给定的映射表,进行映射,如果没有找到,则会映射为 hash(<term>) % num_oov_buckets + vocab_size
输入
import tensorflow as tf
keys_tensor = tf.constant(['牛奶', '鸡蛋'])
vals_tensor = tf.constant([3, 4])
input_tensor = tf.constant(['鸡蛋', '白菜'])
table = tf.lookup.StaticHashTable(
tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), -1)
out = table.lookup(input_tensor)
with tf.Session() as sess:
sess.run(tf.tables_initializer())
print(sess.run(out))
输出
输入
""" profile_feats.txt
hello
world
"""
init = tf.lookup.TextFileInitializer(
filename='profile_feats.txt',
key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
table = tf.lookup.StaticHashTable(init, -1)
out = table.lookup(tf.constant('world'))
with tf.Session() as sess:
sess.run(tf.tables_initializer())
print(sess.run(out))
输出
TF中的哈希处理列&哈希冲突处理