自然场景文本处理论文整理(1)Spatial Transformer Networks

paper:Spatial Transformer Networks
在Theano框架中,STN算法已经被封装成API,可以直接调用。tensorflow实现见文章最后。
1、空间变换器的结构:
这里写图片描述
这是一个可微分的模块,它在单个前向传递期间将空间变换应用于要素图,其中变换以特定输入为条件,从而生成单个输出要素图。对于多通道输入,对每个通道应用相同的扭曲。为简单起见,在本节中我们考虑每个变换器的单个变换和单个输出,但是我们可以推广到多个变换,如实验中所示。空间变换器机制分为三个部分,如上图所示。按计算顺序,首先定位网络采用输入特征映射,并通过若干隐藏层输出空间变换的参数应该应用于要素图 - 这给出了输入的条件转换。然后,使用预测的变换参数来创建采样网格,该采样网格是应该对输入图进行采样以产生变换输出的一组点。这是由网格生成器完成的,如Sect。最后,将特征图和采样网格作为采样器的输入,生成从网格点输入采样的输出图。 这三个部件的组合形成空间变换器。

原理上,一个feature map上学一个变换参数出来,这个参数作用到feature map上得到一个采样器G,然后用G对输入的feature map做采样,就得到了输出V,也就是V上的每个点由U上的点进行采样得到。

2、一个层加进来之后,应该用多少个这样的层呢?这个层和其他层用什么样的连接方式呢?
1. 每个channel可以有自己单独对应的一个stn参数,这样可以用不同的spatial transform来描述feature的空间变换
2. 一个channel可以同时连多个stn,用来处理图片中有多个目标时的情况
这个第2点看上去是个比较糟糕的情况,一个画面中如果有多个目标,每个目标的形变可能都不一样,那么用同样的stn对全图做变化是不太合理的,但是在并不知道图里有多少个目标的情况下,只能设置一个固定的值。

3、连接方式
实验里每隔几层conv就放一个stn,这样就是在feature上做spatial transform了,做了可视化之后可以发现stn不止做空间变化,还有crop的效果,类似attention,所以在运算上也会有些加速。
这里写图片描述

4、stn并行
多个stn并行作用在同一个feature map上的效果,从结果上看stn变多了对结果还是有帮助的,这个原因解释为更多的stn可以更好的对不同part做spatial transform并且关注不同的区域特征

5、定位网络函数f loc()可以采用任何形式,例如全连接网络或卷积网络,但应包括最终回归层以产生 变换参数θ

6、

STN的tensorflow实现
 import tensorflow as tf
from scipy import ndimage
import numpy as np
import matplotlib.pyplot as plt
import cv2

def transformer(U,theta,out_size,name='SpatialTransformer',**kwargs):
    print('begin-transformer')
    #tf.stack()矩阵拼接函数
    #得到拼接后的矩阵shape,并全部置为1

    #tf.expand_dims()在第axis位置增加一个维度,这里axis=1
    # 't' is a tensor of shape [2]
    #shape(expand_dims(t, 0)) ==> [1, 2]
    #shape(expand_dims(t, 1)) ==> [2, 1]
    #shape(expand_dims(t, -1)) ==> [2, 1]
    # 't2' is a tensor of shape [2, 3, 5]
    #shape(expand_dims(t2, 0)) ==> [1, 2, 3, 5]
    #shape(expand_dims(t2, 2)) ==> [2, 3, 1, 5]
    #shape(expand_dims(t2, 3)) ==> [2, 3, 5, 1]
    #n_repeats是什么?
    def _repeat(x,n_repeats):
        with tf.variable_scope('_repeat'):
            rep = tf.transpose(tf.expand_dims(tf.ones(shape=tf.stack([n_repeats, ])),1),[1,0])
            rep = tf.cast(rep,'int32')
            #将x reshape为n行1列后,再与rep做计算。
            x = tf.matmul(tf.reshape(x,(-1,1)),rep)
            return tf.reshape(x,[-1])
    #插值函数
    def _interpolate(im,x,y,out_size):
        with tf.variable_scope('_interpolate'):
            num_batch = tf.shape(im)[0]
            height = tf.shape(im)[1]
            width = tf.shape(im)[2]
            channels = tf.shape(im)[3]

            x = tf.cast(x,'float32')
            y = tf.cast(y,'float32')
            height_f = tf.cast(height,'float32')
            width_f = tf.cast(height,'float32')
            out_height = out_size[0]
            out_width = out_size[1]
            zero = tf.zeros([],dtype='int32')
            max_y = tf.cast(tf.shape(im)[1] - 1,'int32')
            max_x = tf.cast(tf.shape(im)[2] - 1,'int32')

            x = (x + 1.0)*(width_f) / 2.0
            x = (x + 1.0)*(height_f) / 2.0

            x0 = tf.cast(tf.floor(x),'int32')
            x1 = x0 + 1
            y0 = tf.cast(tf.floor(y),'int32')
            y1 = y0 + 1

            x0 = tf.clip_by_value(x0, zero, max_x)
            x1 = tf.clip_by_value(x1, zero, max_x)
            y0 = tf.clip_by_value(y0, zero, max_y)
            y1 = tf.clip_by_value(y1, zero, max_y)
            dim2 = width
            dim1 = width*height
            base = _repeat(tf.range(num_batch)*dim1, out_height*out_width)
            base_y0 = base + y0*dim2
            base_y1 = base + y1*dim2
            idx_a = base_y0 + x0
            idx_b = base_y1 + x0
            idx_c = base_y0 + x1
            idx_d = base_y1 + x1

            # use indices to lookup pixels in the flat image and restore
            # channels dim
            im_flat = tf.reshape(im, tf.stack([-1, channels]))
            im_flat = tf.cast(im_flat, 'float32')
            Ia = tf.gather(im_flat, idx_a)
            Ib = tf.gather(im_flat, idx_b)
            Ic = tf.gather(im_flat, idx_c)
            Id = tf.gather(im_flat, idx_d)

            # and finally calculate interpolated values
            x0_f = tf.cast(x0, 'float32')
            x1_f = tf.cast(x1, 'float32')
            y0_f = tf.cast(y0, 'float32')
            y1_f = tf.cast(y1, 'float32')
            wa = tf.expand_dims(((x1_f-x) * (y1_f-y)), 1)
            wb = tf.expand_dims(((x1_f-x) * (y-y0_f)), 1)
            wc = tf.expand_dims(((x-x0_f) * (y1_f-y)), 1)
            wd = tf.expand_dims(((x-x0_f) * (y-y0_f)), 1)
            output = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id])
            return output
    #生成网格矩阵
    def _meshgrid(height,width):
        print('begin--meshgrid')
        with tf.variable_scope('_meshgrid'):

            x_t = tf.matmul(tf.ones(shape=tf.stack([height, 1])),
                            tf.transpose(tf.expand_dims(tf.linspace(-1.0, 1.0, width), 1), [1, 0]))
            print('meshgrid_x_t_ok')
            y_t = tf.matmul(tf.expand_dims(tf.linspace(-1.0, 1.0, height), 1),
                            tf.ones(shape=tf.stack([1, width])))
            print('meshgrid_y_t_ok')
            x_t_flat = tf.reshape(x_t, (1, -1))
            y_t_flat = tf.reshape(y_t, (1, -1))
            print('meshgrid_flat_t_ok')
            ones = tf.ones_like(x_t_flat)
            print('meshgrid_ones_ok')
            print(x_t_flat)
            print(y_t_flat)
            print(ones)

            grid = tf.concat( [x_t_flat, y_t_flat, ones],0)
            print ('over_meshgrid')
            return grid
    #映射回去,即转换
    def _transform(theta,input_dim,out_size):
        print('_transform')

        with tf.variable_scope('_transform'):
            num_batch = tf.shape(input_dim)[0]
            height = tf.shape(input_dim)[1]
            width = tf.shape(input_dim)[2]
            num_channels = tf.shape(input_dim)[3]
            theta = tf.reshape(theta, (-1, 2, 3))
            theta = tf.cast(theta, 'float32')

            # grid of (x_t, y_t, 1), eq (1) in ref [1]
            height_f = tf.cast(height, 'float32')
            width_f = tf.cast(width, 'float32')
            out_height = out_size[0]
            out_width = out_size[1]
            grid = _meshgrid(out_height, out_width)
            grid = tf.expand_dims(grid, 0)
            grid = tf.reshape(grid, [-1])
            grid = tf.tile(grid, tf.stack([num_batch]))
            grid = tf.reshape(grid, tf.stack([num_batch, 3, -1]))
            #tf.batch_matrix_diag
            # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
            print('begin--batch--matmul')
            T_g = tf.matmul(theta, grid)
            x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1])
            y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1])
            x_s_flat = tf.reshape(x_s, [-1])
            y_s_flat = tf.reshape(y_s, [-1])

            input_transformed = _interpolate(
                input_dim, x_s_flat, y_s_flat,
                out_size)

            output = tf.reshape(
                input_transformed, tf.stack([num_batch, out_height, out_width, num_channels]))
            print('over_transformer')
            return output

    with tf.variable_scope(name):
        output = _transform(theta,U,out_size)
        return output

def batch_transformer(U, thetas, out_size, name='BatchSpatialTransformer'):

    with tf.variable_scope(name):
        num_batch, num_transforms = map(int, thetas.get_shape().as_list()[:2])
        indices = [[i]*num_transforms for i in xrange(num_batch)]
        input_repeated = tf.gather(U, tf.reshape(indices, [-1]))
        return transformer(input_repeated, thetas, out_size)



im=ndimage.imread('cat1.jpg')
im=im/255.
im=im.reshape(1,1200,1600,3)
im=im.astype('float32')
print('img-over')


out_size=(600,800)
batch=np.append(im,im,axis=0)
batch=np.append(batch,im,axis=0)
num_batch=3

x=tf.placeholder(tf.float32,[None,1200,1600,3])
x=tf.cast(batch,'float32')
print('begin---')


with tf.variable_scope('spatial_transformer_0'):
    n_fc=6
    w_fc1=tf.Variable(tf.Variable(tf.zeros([1200*1600*3,n_fc]),name='W_fc1'))
    initial=np.array([[0.5,0,0],[0,0.5,0]])
    initial=initial.astype('float32')
    initial=initial.flatten()

    b_fc1=tf.Variable(initial_value=initial,name='b_fc1')

    h_fc1=tf.matmul(tf.zeros([num_batch,1200*1600*3]),w_fc1)+b_fc1
    print(x,h_fc1,out_size)
    h_trans=transformer(x,h_fc1,out_size)


sess=tf.Session()
sess.run(tf.global_variables_initializer())
y=sess.run(h_trans,feed_dict={x:batch})
plt.imshow(y[0])
plt.show()
  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值