tensorflow实现(indices)上采样

tensorflow实现segnet 和 learn deconv网络中的上采样。

def unpool_with_argmax(bottom, argmax, output_shape=None, name='max_unpool_with_argmax'):
    '''
    upsampling according argmax
    :param bottom: the output feature maps needed to be upsampled
    :param argmax: the indice made by tf.nn.max_pool_with_argmax()
    :param output_shape: 
    :param name:
    :return:
    '''
    with tf.name_scope(name):
        ksize = [1, 2, 2, 1]
        input_shape = bottom.get_shape().as_list()
        #  calculation new shape
        if output_shape is None:
            output_shape = (input_shape[0],
                            input_shape[1] * ksize[1],
                            input_shape[2] * ksize[2],
                            input_shape[3])
        flat_input_size = np.prod(input_shape)
        flat_output_size = np.prod(output_shape)
        bottom_ = tf.reshape(bottom, [flat_input_size])
        argmax_ = tf.reshape(argmax, [flat_input_size, 1])

        ret = tf.scatter_nd(argmax_, bottom_, [flat_output_size])

        ret = tf.reshape(ret, output_shape)
        return ret

测试

import numpy as np
import tensorflow as tf
input_data = tf.constant(np.random.rand(16, 4, 4, 3), dtype=np.float32)

x, arg = tf.nn.max_pool_with_argmax(input_data, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

def unpool_with_argmax(bottom, argmax, output_shape=None, name='max_unpool_with_argmax'):
    '''
    upsampling according argmax
    :param bottom: the output feature maps needed to be upsampled
    :param argmax: the indice made by tf.nn.max_pool_with_argmax()
    :param output_shape:
    :param name:
    :return:
    '''
    with tf.name_scope(name):
        ksize = [1, 2, 2, 1]
        input_shape = bottom.get_shape().as_list()
        #  calculation new shape
        if output_shape is None:
            output_shape = (input_shape[0],
                            input_shape[1] * ksize[1],
                            input_shape[2] * ksize[2],
                            input_shape[3])
        flat_input_size = np.prod(input_shape)
        flat_output_size = np.prod(output_shape)
        bottom_ = tf.reshape(bottom, [flat_input_size])
        argmax_ = tf.reshape(argmax, [flat_input_size, 1])

        ret = tf.scatter_nd(argmax_, bottom_, [flat_output_size])

        ret = tf.reshape(ret, output_shape)
        return ret


ret = unpool_with_argmax(x, arg)

x_2 = tf.nn.max_pool(ret, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

with tf.Session() as sess:
    x_val, arg_val, ret_val, x_2_val = sess.run([x, arg, ret, x_2])
    print x_val[0, :, :, 0]
    print "#######################################"
    print ret_val[0, :, :, 0]
    print "**************************************"
    print arg_val[0, :, :, 0]
    print x_val.shape, arg_val.shape, ret_val.shape


输出结果

[[ 0.92141378  0.83250898]
 [ 0.96589577  0.92536974]]
#######################################
[[ 0.          0.          0.83250898  0.        ]
 [ 0.92141378  0.          0.          0.        ]
 [ 0.          0.96589577  0.          0.        ]
 [ 0.          0.          0.92536974  0.        ]]
**************************************
[[12  6]
 [27 42]]
(16, 2, 2, 3) (16, 2, 2, 3) (16, 4, 4, 3)
  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值