论文:Spatial Transformer Networks (Jaderberg et al., NIPS 2015)
在基于Theano的Lasagne库中,STN已经被封装成现成的层(lasagne.layers.TransformerLayer),可以直接调用。TensorFlow实现见文章最后。
1、空间变换器的结构:
空间变换器(Spatial Transformer)是一个可微分的模块,它在一次前向传播中对特征图施加空间变换,变换参数由具体的输入决定,最终输出一张特征图。对于多通道输入,每个通道施加相同的形变(warp)。为简单起见,本节只考虑每个变换器只有一个变换、一个输出的情况,但如论文实验所示,它可以推广到多个变换。空间变换器由三部分组成(见论文图2)。按计算顺序,首先是定位网络(localisation network):它以输入特征图为输入,经过若干隐藏层,回归出应施加于该特征图的空间变换参数θ,因此变换是以输入为条件的。然后,网格生成器(grid generator)利用预测出的变换参数生成采样网格(见论文第3.2节),即为了得到变换后的输出,需要在输入特征图上进行采样的一组点。最后,采样器(sampler)以特征图和采样网格为输入,在网格点处对输入进行采样,得到输出特征图。这三个部件组合在一起,就构成了空间变换器。
原理上,定位网络从输入特征图U中学习出一组变换参数θ;用θ对输出特征图上的规则网格G做变换,得到采样网格T_θ(G);再由采样器按这个采样网格在U上采样(通常用双线性插值),就得到了输出V。也就是说,V上的每个点都是由U上对应位置的点采样得到的。
2、一个层加进来之后,应该用多少个这样的层呢?这个层和其他层用什么样的连接方式呢?
1. 每个channel可以有自己单独对应的一个stn参数,这样可以用不同的spatial transform来描述feature的空间变换
2. 一个channel可以同时连多个stn,用来处理图片中有多个目标时的情况
这个第2点看上去是个比较糟糕的情况,一个画面中如果有多个目标,每个目标的形变可能都不一样,那么用同样的stn对全图做变化是不太合理的,但是在并不知道图里有多少个目标的情况下,只能设置一个固定的值。
3、连接方式
实验里每隔几层conv就插入一个stn,这样就是在feature map上做spatial transform了。可视化之后可以发现,stn不只是做空间变换,还有crop(裁剪)的效果,类似attention;由于裁剪后可以使用更小的输出尺寸,计算上也会有一定的加速。
4、stn并行
多个stn并行作用在同一个feature map上时,从结果上看,stn数量增多对结果是有帮助的。原因可以解释为:更多的stn能够分别对不同的part做spatial transform,并关注不同区域的特征。
5、定位网络函数f_loc()可以采用任何形式,例如全连接网络或卷积网络,但最后应包含一个回归层,用来输出变换参数θ。
6、STN的TensorFlow实现(基于TensorFlow 1.x API)
import tensorflow as tf
from scipy import ndimage
import numpy as np
import matplotlib.pyplot as plt
import cv2
def transformer(U, theta, out_size, name='SpatialTransformer', **kwargs):
    """Spatial Transformer layer (affine variant), Jaderberg et al., 2015.

    Builds TF1 graph ops that warp ``U`` with a per-example 2x3 affine matrix
    ``theta`` and bilinearly sample the result onto an ``out_size`` grid.

    Args:
        U: tensor of shape [num_batch, height, width, num_channels].
        theta: tensor reshapeable to [num_batch, 2, 3]. Each 2x3 matrix maps
            normalized target coordinates (x_t, y_t, 1), with x_t, y_t in
            [-1, 1], to normalized source coordinates (x_s, y_s).
        out_size: (out_height, out_width) of the sampled output.
        name: variable-scope name for the created ops.
        **kwargs: accepted for API compatibility; unused.

    Returns:
        float32 tensor of shape [num_batch, out_height, out_width, num_channels].
    """

    def _repeat(x, n_repeats):
        """Repeat every element of the 1-D tensor ``x`` ``n_repeats`` times.

        E.g. [a, b] with n_repeats=2 -> [a, a, b, b]. Used so every output
        pixel of example i carries the flat-index offset of example i.
        """
        with tf.variable_scope('_repeat'):
            # [len(x), 1] tiled to [len(x), n_repeats], flattened row-major.
            tiled = tf.tile(tf.expand_dims(x, 1), tf.stack([1, n_repeats]))
            return tf.reshape(tiled, [-1])

    def _interpolate(im, x, y, out_size):
        """Bilinearly sample ``im`` at flat normalized coordinates ``x``, ``y``.

        ``x`` and ``y`` hold one source coordinate in [-1, 1] per output pixel
        of every example (length num_batch * out_height * out_width).
        Returns a float32 tensor of shape
        [num_batch * out_height * out_width, channels].
        """
        with tf.variable_scope('_interpolate'):
            num_batch = tf.shape(im)[0]
            height = tf.shape(im)[1]
            width = tf.shape(im)[2]
            channels = tf.shape(im)[3]

            x = tf.cast(x, 'float32')
            y = tf.cast(y, 'float32')
            height_f = tf.cast(height, 'float32')
            width_f = tf.cast(width, 'float32')  # previously cast from height
            out_height = out_size[0]
            out_width = out_size[1]
            zero = tf.zeros([], dtype='int32')
            max_y = height - 1
            max_x = width - 1

            # Scale normalized coords to pixel coords: x over the width, y
            # over the height (previously x was scaled twice and y never).
            # NOTE(review): this keeps the reference implementation's scaling
            # (x_s = 1 maps to `width`, not `width - 1`). Where both clipped
            # neighbours along an axis coincide, the weights below sum to
            # zero, so samples outside [0, size - 1) produce 0.
            x = (x + 1.0) * width_f / 2.0
            y = (y + 1.0) * height_f / 2.0

            # Integer corners of the cell around each sample, clipped so the
            # gathers below stay inside the image.
            x0 = tf.cast(tf.floor(x), 'int32')
            x1 = x0 + 1
            y0 = tf.cast(tf.floor(y), 'int32')
            y1 = y0 + 1
            x0 = tf.clip_by_value(x0, zero, max_x)
            x1 = tf.clip_by_value(x1, zero, max_x)
            y0 = tf.clip_by_value(y0, zero, max_y)
            y1 = tf.clip_by_value(y1, zero, max_y)

            # Flat row index into im viewed as [num_batch*height*width, C]:
            # example offset + row * width + column.
            dim2 = width
            dim1 = width * height
            base = _repeat(tf.range(num_batch) * dim1, out_height * out_width)
            base_y0 = base + y0 * dim2
            base_y1 = base + y1 * dim2
            idx_a = base_y0 + x0  # (x0, y0)
            idx_b = base_y1 + x0  # (x0, y1)
            idx_c = base_y0 + x1  # (x1, y0)
            idx_d = base_y1 + x1  # (x1, y1)

            # Look up the four neighbours of every sample, keeping channels.
            im_flat = tf.cast(tf.reshape(im, tf.stack([-1, channels])), 'float32')
            Ia = tf.gather(im_flat, idx_a)
            Ib = tf.gather(im_flat, idx_b)
            Ic = tf.gather(im_flat, idx_c)
            Id = tf.gather(im_flat, idx_d)

            # Bilinear weights from the (clipped) corner coordinates.
            x0_f = tf.cast(x0, 'float32')
            x1_f = tf.cast(x1, 'float32')
            y0_f = tf.cast(y0, 'float32')
            y1_f = tf.cast(y1, 'float32')
            wa = tf.expand_dims((x1_f - x) * (y1_f - y), 1)
            wb = tf.expand_dims((x1_f - x) * (y - y0_f), 1)
            wc = tf.expand_dims((x - x0_f) * (y1_f - y), 1)
            wd = tf.expand_dims((x - x0_f) * (y - y0_f), 1)
            return tf.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id])

    def _meshgrid(height, width):
        """Return the [3, height*width] homogeneous target grid (x_t, y_t, 1).

        x_t varies along the width and y_t along the height, both evenly
        spaced over [-1, 1] (endpoints included); points are row-major.
        """
        with tf.variable_scope('_meshgrid'):
            # Default 'xy' indexing: both outputs are [height, width] with
            # x_t[i, j] = xs[j] and y_t[i, j] = ys[i].
            x_t, y_t = tf.meshgrid(tf.linspace(-1.0, 1.0, width),
                                   tf.linspace(-1.0, 1.0, height))
            x_t_flat = tf.reshape(x_t, (1, -1))
            y_t_flat = tf.reshape(y_t, (1, -1))
            ones = tf.ones_like(x_t_flat)
            return tf.concat([x_t_flat, y_t_flat, ones], 0)

    def _transform(theta, input_dim, out_size):
        """Warp ``input_dim`` by the per-example affine ``theta`` onto out_size."""
        with tf.variable_scope('_transform'):
            num_batch = tf.shape(input_dim)[0]
            num_channels = tf.shape(input_dim)[3]
            theta = tf.cast(tf.reshape(theta, (-1, 2, 3)), 'float32')

            out_height = out_size[0]
            out_width = out_size[1]
            # Grid of (x_t, y_t, 1), eq. (1) of the paper; one copy per example.
            grid = _meshgrid(out_height, out_width)
            grid = tf.tile(tf.expand_dims(grid, 0), tf.stack([num_batch, 1, 1]))

            # (x_s, y_s)^T = theta @ (x_t, y_t, 1)^T -> [num_batch, 2, out_h*out_w]
            T_g = tf.matmul(theta, grid)
            x_s_flat = tf.reshape(T_g[:, 0, :], [-1])
            y_s_flat = tf.reshape(T_g[:, 1, :], [-1])

            input_transformed = _interpolate(
                input_dim, x_s_flat, y_s_flat, out_size)
            return tf.reshape(
                input_transformed,
                tf.stack([num_batch, out_height, out_width, num_channels]))

    with tf.variable_scope(name):
        output = _transform(theta, U, out_size)
        return output
def batch_transformer(U, thetas, out_size, name='BatchSpatialTransformer'):
    """Apply several affine transforms to every example of a batch.

    Args:
        U: tensor of shape [num_batch, height, width, num_channels].
        thetas: tensor whose first two static dimensions are
            (num_batch, num_transforms) and which holds 6 affine parameters
            per transform (presumably [num_batch, num_transforms, 6] —
            confirm against callers). Both leading dimensions must be known
            statically because they are converted with ``int``.
        out_size: (out_height, out_width) of each transformed output.
        name: variable-scope name for the created ops.

    Returns:
        Tensor of shape [num_batch * num_transforms, out_height, out_width,
        num_channels]; example i appears num_transforms times consecutively,
        once per transform.
    """
    with tf.variable_scope(name):
        num_batch, num_transforms = map(int, thetas.get_shape().as_list()[:2])
        # [0, 0, ..., 1, 1, ...]: each example repeated once per transform.
        # (`range` instead of Python-2-only `xrange`.)
        indices = [i for i in range(num_batch) for _ in range(num_transforms)]
        input_repeated = tf.gather(U, indices)
        return transformer(input_repeated, thetas, out_size)
if __name__ == '__main__':
    # Demo: run three copies of one image through the transformer. The
    # localisation layer has zero weights and a bias equal to the affine
    # matrix [[0.5, 0, 0], [0, 0.5, 0]], so every example samples the central
    # region of the image (zoom in). The first result is displayed.

    # scipy.ndimage.imread is deprecated/removed in current SciPy; read with
    # OpenCV (already imported) and convert its BGR output to RGB.
    im = cv2.imread('cat1.jpg')
    if im is None:
        raise FileNotFoundError("could not read image 'cat1.jpg'")
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    im = (im / 255.).astype('float32')
    img_h, img_w, img_c = im.shape
    im = im.reshape(1, img_h, img_w, img_c)

    out_size = (600, 800)
    num_batch = 3
    batch = np.concatenate([im] * num_batch, axis=0)

    # Images are fed through this placeholder. Previously it was overwritten
    # by a constant, so feed_dict had no effect.
    x = tf.placeholder(tf.float32, [None, img_h, img_w, img_c])
    n_in = img_h * img_w * img_c

    with tf.variable_scope('spatial_transformer_0'):
        n_fc = 6
        # Zero weights: the layer's output equals its bias for any input.
        w_fc1 = tf.Variable(tf.zeros([n_in, n_fc]), name='W_fc1')
        initial = np.array([[0.5, 0, 0], [0, 0.5, 0]], dtype='float32').flatten()
        b_fc1 = tf.Variable(initial_value=initial, name='b_fc1')
        h_fc1 = tf.matmul(tf.reshape(x, [-1, n_in]), w_fc1) + b_fc1
        h_trans = transformer(x, h_fc1, out_size)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        y = sess.run(h_trans, feed_dict={x: batch})

    plt.imshow(y[0])
    plt.show()