TensorFlow 2 Gumbel-Softmax


import tensorflow as tf
import tensorflow_probability as tfp


# Build a OneHotCategorical distribution
def dist_from_h(h, N, K, z_logit_clip, mode):
	"""Turn a tensor of raw logits into a ``OneHotCategorical`` distribution.

	Args:
		h: Tensor of logits; it is reshaped to ``[-1, N, K]``.
		N: Size of the second axis after the reshape.
		K: Size of the last axis after the reshape (the axis the logits
			are centred over).
		z_logit_clip: Clip bound ``c``; centred logits are clipped to
			``[-c, c]``. ``None`` disables clipping.
		mode: Clipping is applied only when this equals ``'train'``.

	Returns:
		A ``tfp.distributions.OneHotCategorical`` built from the processed
		logits. If the leading axis of the logits has size 1 it is
		squeezed away first.
	"""
	grouped = tf.reshape(h, [-1, N, K])
	# Centre each group of K logits so they have zero mean along the last axis.
	centred = grouped - tf.reduce_mean(grouped, axis=-1, keepdims=True)

	clip_active = mode == 'train' and z_logit_clip is not None
	logits = (
		tf.clip_by_value(centred, -z_logit_clip, z_logit_clip)
		if clip_active
		else centred
	)

	# Drop a singleton leading axis.
	if logits.shape[0] == 1:
		logits = tf.squeeze(logits, 0)

	return tfp.distributions.OneHotCategorical(logits=logits)


def sample_q(k, p_dist, temp, z_dim, mode):
	"""Draw ``k`` samples from ``p_dist`` and flatten them to ``(k, -1, z_dim)``.

	Args:
		k: Number of samples to draw.
		p_dist: Distribution to sample from. In ``'train'`` mode its
			``logits`` attribute is used to build a relaxed distribution;
			in ``'eval'`` mode it is sampled directly.
		temp: Temperature passed to ``RelaxedOneHotCategorical``
			(used only in ``'train'`` mode).
		z_dim: Size of the last axis of the returned tensor.
		mode: Either ``'train'`` (relaxed / Gumbel-Softmax samples) or
			``'eval'`` (samples drawn from ``p_dist`` itself).

	Returns:
		The samples reshaped to ``(k, -1, z_dim)``.

	Raises:
		ValueError: If ``mode`` is neither ``'train'`` nor ``'eval'``.
	"""
	# Validate up front: previously an unknown mode fell through both branches
	# and failed later with an unhelpful UnboundLocalError on ``z_NK``.
	if mode not in ('train', 'eval'):
		raise ValueError(
			"mode must be 'train' or 'eval', got {!r}".format(mode)
		)

	if mode == 'train':
		z_dist = tfp.distributions.RelaxedOneHotCategorical(temp, logits=p_dist.logits)
		z_NK = z_dist.sample((k, ))
	else:
		z_NK = p_dist.sample((k, ))
	return tf.reshape(z_NK, (k, -1, z_dim))


# Compute the KL divergence
def kl_q_p(p_dist, q_dist, kl_min):
	"""Return a scalar KL term computed as ``kl_divergence(p_dist, q_dist)``.

	The divergence is averaged over axis 0, optionally floored element-wise
	at ``kl_min``, and then summed over the remaining entries.

	Args:
		p_dist: First argument to ``tfp.distributions.kl_divergence``.
		q_dist: Second argument to ``tfp.distributions.kl_divergence``.
		kl_min: Lower bound applied to each averaged KL entry; the bound is
			applied only when ``kl_min > 0``.

	Returns:
		The summed (and possibly lower-bounded) KL value.
	"""
	per_entry = tfp.distributions.kl_divergence(p_dist, q_dist)

	# Add a leading axis when the result has fewer than two dimensions, so
	# the mean over axis 0 below always has an axis to reduce.
	if len(per_entry.shape) < 2:
		per_entry = tf.expand_dims(per_entry, 0)

	averaged = tf.reduce_mean(per_entry, axis=0, keepdims=True)

	if kl_min > 0:
		averaged = tf.maximum(averaged, kl_min)
	return tf.reduce_sum(averaged)


# Demo configuration.
N = 2
K = 5
k = 3  # number of samples to draw
z_logit_clip = 1
z_dim = N * K  # sample_q flattens the (N, K) axes of each sample into one axis
temp = 0.1
kl_min = 0.07
mode = 'train'

# p(x) is the target distribution; q(x) is the distribution fitted to match it.
# 120 logits reshape to [12, N, K] inside dist_from_h.
logits_p = tf.random.normal([120])
logits_q = tf.random.normal([120])
dist_p = dist_from_h(logits_p, N, K, z_logit_clip, mode)
dist_q = dist_from_h(logits_q, N, K, z_logit_clip, mode)

print(dist_p)
print(dist_q)

# Compute the KL divergence between p and q.
# (The original passed dist_p twice, which always yields a zero divergence
# that is then floored to kl_min, and left dist_q unused.)
kl = kl_q_p(dist_p, dist_q, kl_min)

print(kl)

# Sampling
sample = sample_q(k, dist_p, temp, z_dim, mode)
print(sample)

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值