tensorflow 实现条件embedding

最近在复现一篇论文,其中需要条件embedding。具体的要求为:如果某个词汇为主题词,那么该词的embedding是在主题词汇表,否则就在总的词汇表中。

代码如下:

import tensorflow as tf
import numpy as np
#总的词汇表与主题词汇表的映射关系,如[11,1],11表示词汇在总的词汇表的index,1表示在主题词汇表中的index。
a=tf.constant([[11,1],[12,2],[13,3],[14,4],[15,5],[16,6]],dtype=tf.int32)
# 文章的词汇,有3个batch。
# b=tf.constant([[1,11,13],[3,16,6],[5,6,2]],dtype=tf.int32)
b=tf.constant([[0,13],[1,16],[0,8]],dtype=tf.int32)

emb=tf.get_variable(name='embb',shape=[20,5],dtype=tf.float32)
topic_emb=tf.get_variable(name='topic',shape=[10,5],dtype=tf.float32)
a_shape=a.get_shape()[0]
b_shape=b.get_shape()
size=b_shape[0]*b_shape[1]

def look_step(g,index,single,k1):
    t=tf.cond(tf.equal(single,a[k1,0]),
            lambda :(True,k1),
            lambda :(False,-1))
    return tf.logical_or(g,t[0]),tf.maximum(index,t[1]),single,k1+1
    
def _look(single):
    g=False
    k1=0
    index=-1
    g,index,*_=tf.while_loop(
        cond=lambda g,index,single,k1:k1<a_shape,
        body=look_step,
        loop_vars=[g,index,single,k1]
    )
    return g,index

k=tf.constant(1)
def loop_step(k,topic,matrix):
    g,index=_look(topic[k])
    matr=tf.cond(g,
                lambda :tf.gather(topic_emb,[a[index,1]]),
                lambda :tf.gather(emb,[topic[k]]))
    matr=tf.reshape(matr,[1,1,-1])
    matrix=tf.concat([matrix,matr],0)

    return k+1,topic,matrix

g=tf.reshape(b,[-1])
# matrix=tf.Variable([[]])
matrix=tf.nn.embedding_lookup(emb,[0])
matrix=tf.reshape(matrix,[1,1,-1])
_,_,matrix=tf.while_loop(
    cond=lambda k,b,*_:k<size,
    body=loop_step,
    loop_vars=[k,g,matrix],
    shape_invariants=[k.get_shape(),tf.TensorShape([None]),tf.TensorShape([None,1,5])]

)

matrix=tf.reshape(matrix,shape=[-1,b_shape[1],5])
print(matrix)
init=tf.global_variables_initializer()
sess=tf.Session()
sess.run(init)

matrix,topics,emm=sess.run([matrix,topic_emb,emb])
print(np.array(matrix).shape)
print(matrix)

print('------')
print(topics)
print('------')
print(emm)

如果需要对其求梯度,上面的代码需要稍作优化。

import tensorflow as tf
import numpy as np
#总的词汇表与主题词汇表的映射关系,如[11,1],11表示词汇在总的词汇表的index,1表示在主题词汇表中的index。
a=tf.constant([[11,1],[12,2],[13,3],[14,4],[15,5],[16,6]],dtype=tf.int32)
# 文章的词汇,有3个batch。
# b=tf.constant([[1,11,13],[3,16,6],[5,6,2]],dtype=tf.int32)
b=tf.constant([[0,13],[1,16],[0,8]],dtype=tf.int32)

emb=tf.get_variable(name='embb',shape=[20,5],dtype=tf.float32)
topic_emb=tf.get_variable(name='topic',shape=[10,5],dtype=tf.float32)
a_shape=a.get_shape()[0]
b_shape=b.get_shape()
size=b_shape[0]*b_shape[1]

def look_step(g,index,single,k1):
    t=tf.cond(tf.equal(single,a[k1,0]),
            lambda :(True,k1),
            lambda :(False,-1))
    return tf.logical_or(g,t[0]),tf.maximum(index,t[1]),single,k1+1
    
def _look(single):
    g=False
    k1=0
    index=-1
    g,index,*_=tf.while_loop(
        cond=lambda g,index,single,k1:k1<a_shape,
        body=look_step,
        loop_vars=[g,index,single,k1]
    )
    return g,index

def loop_step(k,topic,matrix):
    g,index=_look(topic[k])
    matr=tf.cond(g,
                lambda :tf.gather(topic_emb,[a[index,1]]),
                lambda :tf.gather(emb,[topic[k]]))
    matr=tf.reshape(matr,[1,1,-1])
    matrix=matrix.write(k,matr)
    return k+1,topic,matrix

# matrix=tf.Variable([[]])
matrix=tf.nn.embedding_lookup(emb,[0])
matrix=tf.reshape(matrix,[1,1,-1])
loop_vars=[tf.constant(0),tf.reshape(b,[-1]),tf.TensorArray(tf.float32,size=size)]
_,_,matrix=tf.while_loop(
    cond=lambda k,b,*_:k<size,
    body=loop_step,
    loop_vars=loop_vars
)

matrix=tf.reshape(matrix.stack(),shape=[-1,b_shape[1],5])
print(matrix)
init=tf.global_variables_initializer()
sess=tf.Session()
sess.run(init)

matrix,topics,emm=sess.run([matrix,topic_emb,emb])
print(np.array(matrix).shape)
print(matrix)

print('------')
print(topics)
print('------')
print(emm)

如果大家有更好的方法实现,欢迎在下方评论区评论。。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值