赢者通吃自编码器(WTA-AE)

参考:
1. 论文:winner-take-all-autoencoders.pdf
2. 代码:
a. full connect WTA-AE
b. Conv-WTA-AE
简单理解:
        spatial sparsity: 对卷积得到的feature-map-tensor(shape=[N,H,W,C]), 沿每个channel, 是一个H*W的张量,仅仅保留这个H*W的张量上的最大值,其余数值元素置零。
        lifetime sparsity:是在同一个批次的多个样本中,跨样本的赢者通吃。在tensorflow中,每个样本对应一个[H,W,C]的张量,lifetime sparsity是指不同样本的特征tensor的同一个channel进行比较,保留最大的K个数值。

   比如:feature_map_tensor.shape=[N,H,W,C],即,batch_size=N, 每次训练输入N个样本,每个样本都有一个[H,W,C]的张量特征,lifetime sparsity 是比较这N个样本的张量特征的同一个channel的数值,并筛选最大的K个保留数值,其余的置为0。


     经过spatial sparsity之后,每个样本特征的每个channel仅保留了一个数值(最大值)。

     lifetime sparsity在多个样本张量的同一个channel中,比较,从N个最大值中再筛选最大的K个保留数值,其余的置为0.


   
主要代码:   

 def _spatial_sparsity(self, h):
    """
    h: 待执行WTA的feature map
    功能: 保留feature-map-tensor的每个channel上的最大的K个最大值,其余的元素置零。
           这里K=1,即,每个channel仅保留最大值,其余数值置零。
           相当于逐个channel执行maxpooling,然后保持原来的形状和元素位置。

    """
        shape = tf.shape(h)
        n = shape[0] # batch-size
        c = shape[3] # channel-num

        # step1: 在feature map的每个channel上寻找该channel上的第k大的数值,作为阈值
        h_t = tf.transpose(h, [0, 3, 1, 2]) # 张量变形1,shape=[n, c, h, w]
        h_r = tf.reshape(h_t, tf.stack([n, c, -1])) # 张量变形2,把每个channel的二维张量拉成一维矢量,
                                                    # 拉直。shape变成[n, c, h*w]

        th, _ = tf.nn.top_k(h_r, k=1) # top_k: 沿张量的最后一个维度寻找K个最大值,
                                      # 返回这K个最大值,及其索引。n, c, 1
                                      # 这里是沿[n,c,h*w]的最后一个维度寻找1个最大值,
                                      # 作为阈值,th.shape=[n,c,1]
        # top_k,Finds values and indices of the k largest entries for the last dimension.
        
        # step2: 
        th_r = tf.reshape(th, tf.stack([n, 1, 1, c])) #把阈值设置成与feature-map一样的维度,shape=[n,1,1,c]
        drop = tf.where(h < th_r,tf.zeros(shape, tf.float32), tf.ones(shape, tf.float32))
        # 求掩码:大于阈值的系数设置为1,否则设置为0,后后面执行feature-map元素处理做准备

        # spatially dropped & winner
        # step3: 执行feature-map操作,小于阈值的置零,大于阈值的保持
        h_wta = h*drop  
        return h_wta, tf.reshape(th, tf.stack([n, c])) 
        #不知道reshape阈值张量th后,后面有什么用途,th.shape=[n, c]

 def _lifetime_sparsity(self, h, winner, rate):
        shape = tf.shape(winner)
        n = shape[0]
        c = shape[1]
        k = tf.cast(rate * tf.cast(n, tf.float32), tf.int32)

        winner = tf.transpose(winner) # c, n
        th_k, _ = tf.nn.top_k(winner, k) # c, k

        shape_t = tf.stack([c, n])
        drop = tf.where(winner < th_k[:,k-1:k], # c, n
            tf.zeros(shape_t, tf.float32), tf.ones(shape_t, tf.float32))
        drop = tf.transpose(drop) # n, c
        return h * tf.reshape(drop, tf.stack([n, 1, 1, c]))


 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值