最近在复现一篇论文,其中需要条件embedding。具体的要求为:如果某个词汇为主题词,那么该词的embedding是在主题词汇表,否则就在总的词汇表中。
代码如下:
import tensorflow as tf
import numpy as np
#总的词汇表与主题词汇表的映射关系,如[11,1],11表示词汇在总的词汇表的index,1表示在主题词汇表中的index。
a=tf.constant([[11,1],[12,2],[13,3],[14,4],[15,5],[16,6]],dtype=tf.int32)
# 文章的词汇,有3个batch。
# b=tf.constant([[1,11,13],[3,16,6],[5,6,2]],dtype=tf.int32)
b=tf.constant([[0,13],[1,16],[0,8]],dtype=tf.int32)
emb=tf.get_variable(name='embb',shape=[20,5],dtype=tf.float32)
topic_emb=tf.get_variable(name='topic',shape=[10,5],dtype=tf.float32)
a_shape=a.get_shape()[0]
b_shape=b.get_shape()
size=b_shape[0]*b_shape[1]
def look_step(g,index,single,k1):
t=tf.cond(tf.equal(single,a[k1,0]),
lambda :(True,k1),
lambda :(False,-1))
return tf.logical_or(g,t[0]),tf.maximum(index,t[1]),single,k1+1
def _look(single):
g=False
k1=0
index=-1
g,index,*_=tf.while_loop(
cond=lambda g,index,single,k1:k1<a_shape,
body=look_step,
loop_vars=[g,index,single,k1]
)
return g,index
k=tf.constant(1)
def loop_step(k,topic,matrix):
g,index=_look(topic[k])
matr=tf.cond(g,
lambda :tf.gather(topic_emb,[a[index,1]]),
lambda :tf.gather(emb,[topic[k]]))
matr=tf.reshape(matr,[1,1,-1])
matrix=tf.concat([matrix,matr],0)
return k+1,topic,matrix
g=tf.reshape(b,[-1])
# matrix=tf.Variable([[]])
matrix=tf.nn.embedding_lookup(emb,[0])
matrix=tf.reshape(matrix,[1,1,-1])
_,_,matrix=tf.while_loop(
cond=lambda k,b,*_:k<size,
body=loop_step,
loop_vars=[k,g,matrix],
shape_invariants=[k.get_shape(),tf.TensorShape([None]),tf.TensorShape([None,1,5])]
)
matrix=tf.reshape(matrix,shape=[-1,b_shape[1],5])
print(matrix)
init=tf.global_variables_initializer()
sess=tf.Session()
sess.run(init)
matrix,topics,emm=sess.run([matrix,topic_emb,emb])
print(np.array(matrix).shape)
print(matrix)
print('------')
print(topics)
print('------')
print(emm)
如果需要对其求梯度,上面的代码需要稍作优化。
import tensorflow as tf
import numpy as np
#总的词汇表与主题词汇表的映射关系,如[11,1],11表示词汇在总的词汇表的index,1表示在主题词汇表中的index。
a=tf.constant([[11,1],[12,2],[13,3],[14,4],[15,5],[16,6]],dtype=tf.int32)
# 文章的词汇,有3个batch。
# b=tf.constant([[1,11,13],[3,16,6],[5,6,2]],dtype=tf.int32)
b=tf.constant([[0,13],[1,16],[0,8]],dtype=tf.int32)
emb=tf.get_variable(name='embb',shape=[20,5],dtype=tf.float32)
topic_emb=tf.get_variable(name='topic',shape=[10,5],dtype=tf.float32)
a_shape=a.get_shape()[0]
b_shape=b.get_shape()
size=b_shape[0]*b_shape[1]
def look_step(g,index,single,k1):
t=tf.cond(tf.equal(single,a[k1,0]),
lambda :(True,k1),
lambda :(False,-1))
return tf.logical_or(g,t[0]),tf.maximum(index,t[1]),single,k1+1
def _look(single):
g=False
k1=0
index=-1
g,index,*_=tf.while_loop(
cond=lambda g,index,single,k1:k1<a_shape,
body=look_step,
loop_vars=[g,index,single,k1]
)
return g,index
def loop_step(k,topic,matrix):
g,index=_look(topic[k])
matr=tf.cond(g,
lambda :tf.gather(topic_emb,[a[index,1]]),
lambda :tf.gather(emb,[topic[k]]))
matr=tf.reshape(matr,[1,1,-1])
matrix=matrix.write(k,matr)
return k+1,topic,matrix
# matrix=tf.Variable([[]])
matrix=tf.nn.embedding_lookup(emb,[0])
matrix=tf.reshape(matrix,[1,1,-1])
loop_vars=[tf.constant(0),tf.reshape(b,[-1]),tf.TensorArray(tf.float32,size=size)]
_,_,matrix=tf.while_loop(
cond=lambda k,b,*_:k<size,
body=loop_step,
loop_vars=loop_vars
)
matrix=tf.reshape(matrix.stack(),shape=[-1,b_shape[1],5])
print(matrix)
init=tf.global_variables_initializer()
sess=tf.Session()
sess.run(init)
matrix,topics,emm=sess.run([matrix,topic_emb,emb])
print(np.array(matrix).shape)
print(matrix)
print('------')
print(topics)
print('------')
print(emm)
如果大家有更好的方法实现,欢迎在下方评论区评论。。