import tensorflow as tf
print(tf.__version__)

# Demo 1: a static int64 -> int64 hash table mapping each key in `list_arr`
# to the value at the same position in `value_arr`.
list_arr = [9, 8, 6, 5]
value_arr = [0, 1, 2, 3]
tf_look_up = tf.constant(list_arr, dtype=tf.int64)
tf_value_arr = tf.constant(value_arr, dtype=tf.int64)
# Missing keys return -1. A default of 0 would be indistinguishable from a
# successful lookup of key 9, whose value is also 0.
table = tf.contrib.lookup.HashTable(
    tf.contrib.lookup.KeyValueTensorInitializer(tf_look_up, tf_value_arr), -1)
ph_vals = tf.constant([8, 5], dtype=tf.int64)
ph_idx = table.lookup(ph_vals)
with tf.compat.v1.Session() as sess:
    # Lookup tables must be initialized before the first lookup.
    sess.run(tf.tables_initializer())
    # `tf.initialize_all_variables` is deprecated in favor of this call.
    sess.run(tf.global_variables_initializer())
    res = sess.run(ph_idx)
    print(res)
# Demo batch: each row pairs a "|"-separated tag string with the
# ","-separated weights for those tags (one weight per tag). Row 2 holds
# "hardenx", which is not in TAG_SET; row 4 is empty on both sides.
_ROWS = [
    ("harden|james|curry", "0.4,0.3,0.1"),
    ("wrestbrook|harden|durant|hardenx", "0.4,0.3,0.1,0.1"),
    ("paul|towns", "0.4,0.3"),
    ("", ""),
]
input = [tags for tags, _ in _ROWS]
weight = [weights for _, weights in _ROWS]
del _ROWS

# Known category vocabulary for the embedding demo.
TAG_SET = [
    "harden", "james", "curry", "durant", "paul", "towns", "wrestbrook",
]
class IndexValueEmbedding(object):
    """Weighted-mean embedding lookup for delimited multi-valued category fields.

    Each input row is a string of category tags (e.g. ``"a|b|c"``) paired with
    a string of per-tag weights (e.g. ``"0.4,0.3,0.1"``). Tags are mapped to
    rows of a trainable embedding matrix. Row 0 is reserved for
    ``default_value``, and any tag not in ``category_list`` also maps to
    row 0.
    """

    # Reserved tag occupying row 0 of the embedding matrix. Blank input rows
    # are replaced with it.
    default_value = "xxx-default"

    # Weight assigned to default/unknown (index 0) tags so they barely affect
    # the weighted mean. It is kept non-zero, presumably so that a row made
    # only of such tags does not get a zero weight sum (NaN mean).
    _UNKNOWN_TAG_WEIGHT = 1e-6

    def __init__(self, field_name, category_list, embedding_size):
        """Build the tag->index table and the embedding variable.

        Args:
          field_name: name of the embedding variable, created under the
            ``index_value`` variable scope.
          category_list: known category tags; ``category_list[i]`` maps to
            embedding row ``i + 1``.
          embedding_size: width of each embedding vector.
        """
        self.field_name = field_name
        self.category_list = category_list
        self.embedding_size = embedding_size
        self._init_dict()
        self._init_embedding()

    def _init_embedding(self):
        """Create the ``[len(self._tags), embedding_size]`` embedding variable."""
        # tf.get_variable is called without reuse, so a second instance with
        # the same field_name in the same graph will conflict.
        with tf.variable_scope("index_value"):
            self.embedding_params = tf.get_variable(
                name=self.field_name,
                initializer=tf.truncated_normal(
                    [len(self._tags), self.embedding_size]))

    def _init_dict(self):
        """Build the tag vocabulary and the string -> index lookup table."""
        self._tags = [self.default_value]
        self._tags.extend(self.category_list)
        # Tags missing from the vocabulary fall back to index 0 (default_value).
        self.table = tf.contrib.lookup.index_table_from_tensor(
            mapping=self._tags, default_value=0)

    def get_avg_embedding(self, input_indexes, input_weights, sep1=",", sep2=","):
        """Return the weighted-mean embedding of each row's tags.

        Args:
          input_indexes: 1-D batch (list or string tensor) of tag strings whose
            tags are separated by ``sep1``.
          input_weights: 1-D batch of weight strings separated by ``sep2``.
            Row ``i`` must supply one weight per tag in ``input_indexes[i]``.
          sep1: tag delimiter.
          sep2: weight delimiter.

        Returns:
          A float32 tensor with one embedding vector per input row.
        """
        tag_rows, weight_rows = self._preprocess(input_indexes, input_weights)
        tags = self._sparse_from_string_array(tag_rows, sep1)
        wgt = tf.string_split(tf.string_strip(weight_rows), sep2, skip_empty=True)
        raw_weights = tf.string_to_number(wgt.values, tf.float32)
        # Known tags keep the caller's weight. Default/unknown tags (index 0)
        # are shrunk to a tiny epsilon so they do not dominate the mean. The
        # previous version had these two branches swapped.
        is_unknown = tf.equal(tags.values, 0)
        unknown_weights = tf.ones_like(raw_weights) * self._UNKNOWN_TAG_WEIGHT
        effective_weights = tf.where(is_unknown, unknown_weights, raw_weights)
        sparse_wgt = tf.SparseTensor(wgt.indices, effective_weights, wgt.dense_shape)
        return tf.nn.embedding_lookup_sparse(
            self.embedding_params, sp_ids=tags, sp_weights=sparse_wgt,
            combiner="mean")

    @staticmethod
    def _fill_blank(values, replacement):
        """Replace entries of a 1-D string tensor that are blank after stripping."""
        values = tf.convert_to_tensor(values, dtype=tf.string)
        is_blank = tf.equal(tf.string_strip(values), "")
        return tf.where(is_blank, tf.fill(tf.shape(values), replacement), values)

    def _preprocess(self, input, weight):
        """Fill blank rows: tags with ``default_value``, weights with ``"1"``.

        Accepts Python lists or existing string tensors.
        """
        return (self._fill_blank(input, self.default_value),
                self._fill_blank(weight, "1"))

    def _sparse_from_string_array(self, input_keys, sep):
        """Split tag strings on ``sep`` and map each tag to its table index.

        ``skip_empty=False`` keeps empty tokens; they are looked up as ``""``
        and, being absent from the vocabulary, map to index 0.
        """
        split_tags = tf.string_split(tf.string_strip(input_keys), sep, skip_empty=False)
        return tf.SparseTensor(indices=split_tags.indices,
                               values=self.table.lookup(split_tags.values),
                               dense_shape=split_tags.dense_shape)
# Demo 2: weighted-mean embeddings (size 4) for the demo batch; tags in each
# row are "|"-separated, weights use the default "," separator.
index_value_emb = IndexValueEmbedding("cate", TAG_SET, 4)
avg = index_value_emb.get_avg_embedding(input, weight, "|")
with tf.compat.v1.Session() as sess:
    # Initialize the tag lookup table and the embedding variable before use.
    sess.run(tf.tables_initializer())
    # `tf.initialize_all_variables` is deprecated in favor of this call.
    sess.run(tf.global_variables_initializer())
    print("avg_res\n", sess.run(avg))