参考:
1. 论文:winner-take-all-autoencoders.pdf
2. 代码:
a. full connect WTA-AE
b. Conv-WTA-AE
简单理解:
spatial sparsity: 对卷积得到的feature-map-tensor(shape=[N,H,W,C]), 沿每个channel, 是一个H*W的张量,仅仅保留这个H*W的张量上的最大值,其余数值元素置零。
lifetime sparsity:是在同一个批次的多个样本中,跨样本的赢者通吃。在tensorflow中,每个样本对应一个[H,W,C]的张量,lifetime sparsity是指不同样本的特征tensor的同一个channel进行比较,保留最大的K个数值。
比如:feature_map_tensor.shape=[N,H,W,C],即,batch_size=N, 每次训练输入N个样本,每个样本都有一个[H,W,C]的张量特征,lifetime sparsity 是比较这N个样本的张量特征的同一个channel的数值,并筛选最大的K个保留数值,其余的置为0。
经过spatial sparsity之后,每个样本特征的每个channel仅保留了一个数值(最大值)。
lifetime sparsity在多个样本张量的同一个channel中,比较,从N个最大值中再筛选最大的K个保留数值,其余的置为0.
主要代码:
def _spatial_sparsity(self, h):
"""
h: 待执行WTA的feature map
功能: 保留feature-map-tensor的每个channel上的最大的K个最大值,其余的元素置零。
这里K=1,即,每个channel仅保留最大值,其余数值置零。
相当于逐个channel执行maxpooling,然后保持原来的形状和元素位置。
"""
shape = tf.shape(h)
n = shape[0] # batch-size
c = shape[3] # channel-num
# step1: 在feature map的每个channel上寻找该channel上的第k大的数值,作为阈值
h_t = tf.transpose(h, [0, 3, 1, 2]) # 张量变形1,shape=[n, c, h, w]
h_r = tf.reshape(h_t, tf.stack([n, c, -1])) # 张量变形2,把每个channel的二维张量拉成一维矢量,
# 拉直。shape变成[n, c, h*w]
th, _ = tf.nn.top_k(h_r, k=1) # top_k: 沿张量的最后一个维度寻找K个最大值,
# 返回这K个最大值,及其索引。n, c, 1
# 这里是沿[n,c,h*w]的最后一个维度寻找1个最大值,
# 作为阈值,th.shape=[n,c,1]
# top_k,Finds values and indices of the k largest entries for the last dimension.
# step2:
th_r = tf.reshape(th, tf.stack([n, 1, 1, c])) #把阈值设置成与feature-map一样的维度,shape=[n,1,1,c]
drop = tf.where(h < th_r,tf.zeros(shape, tf.float32), tf.ones(shape, tf.float32))
# 求掩码:大于阈值的系数设置为1,否则设置为0,后后面执行feature-map元素处理做准备
# spatially dropped & winner
# step3: 执行feature-map操作,小于阈值的置零,大于阈值的保持
h_wta = h*drop
return h_wta, tf.reshape(th, tf.stack([n, c]))
#不知道reshape阈值张量th后,后面有什么用途,th.shape=[n, c]
def _lifetime_sparsity(self, h, winner, rate):
shape = tf.shape(winner)
n = shape[0]
c = shape[1]
k = tf.cast(rate * tf.cast(n, tf.float32), tf.int32)
winner = tf.transpose(winner) # c, n
th_k, _ = tf.nn.top_k(winner, k) # c, k
shape_t = tf.stack([c, n])
drop = tf.where(winner < th_k[:,k-1:k], # c, n
tf.zeros(shape_t, tf.float32), tf.ones(shape_t, tf.float32))
drop = tf.transpose(drop) # n, c
return h * tf.reshape(drop, tf.stack([n, 1, 1, c]))