STN(Spatial Transformer Network)中 Grid generator 与 Sampler 的实现

  • Localisation net 以 batch_images(shape: [b,h,w,c])为输入,输出 batch_theta(shape: [b,2,3]);
    其具体结构取决于所用网络,这里不再赘述。
  • Grid generator 与 Sampler 的实现如下(注意:这里网格坐标的顺序为 (h, w),与 PyTorch affine_grid/grid_sample 使用的 (x, y) 即 (w, h) 顺序相反):
import tensorflow as tf
import numpy as np
H:\anaconda3\envs\tf\lib\site-packages\scipy\__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.24.3
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
def affine_grid(batch_theta,batch_output_shape):
    """
    Build the sampling grid for a batch of 2x3 affine matrices.

    Every output pixel (i, j) gets a normalized base coordinate (h, w): h is
    evenly spaced over [-(1-1/oh), 1-1/oh] and w over [-(1-1/ow), 1-1/ow],
    i.e. the pixel centres of a [-1, 1] span. It is then mapped to
    theta @ [h, w, 1]^T.

    batch_theta:
        shape: [b,2,3]; row 0 produces the h coordinate, row 1 the w
        coordinate. Cast to float32.
    batch_output_shape:
        value: [b,oh,ow,c]. Only oh and ow are read; the batch size of the
        result comes from batch_theta and c is ignored.

    return:
        batch_affine_grid:
            shape: [b,oh,ow,2], float32, last axis = (h, w)
    """
    oh = batch_output_shape[1]
    ow = batch_output_shape[2]
    # Half-pixel margin: the first/last samples sit on pixel centres, not on +-1.
    oh_max = 1 - 1/oh  # value_range: [0,1)
    ow_max = 1 - 1/ow
    # [oh,]
    oh_lim = tf.cast(tf.linspace(-oh_max, oh_max, oh), dtype=tf.float32)
    # [ow,]
    ow_lim = tf.cast(tf.linspace(-ow_max, ow_max, ow), dtype=tf.float32)
    # [oh,ow] [oh,ow]
    h_mt, w_mt = tf.meshgrid(oh_lim, ow_lim, indexing='ij')
    # [oh,ow,3]: homogeneous coordinates (h, w, 1)
    position_hw1 = tf.stack([h_mt, w_mt, tf.ones_like(h_mt)], axis=-1)
    # Cast so float64 (or other float) thetas don't break the float32 einsum.
    batch_theta = tf.cast(batch_theta, dtype=tf.float32)
    # grid[b,h,w,n] = sum_x position[h,w,x] * theta[b,n,x].
    # Contracting against theta directly avoids tiling the base grid b times
    # and transposing theta.
    batch_affine_grid = tf.einsum('hwx,bnx->bhwn', position_hw1, batch_theta)

    return batch_affine_grid
def grid_sample(batch_input,batch_affine_grid):
    """
    Bilinearly sample batch_input at the positions given by batch_affine_grid.

    method: bilinear with zero padding. A neighbouring input pixel that lies
    outside [0, ih-1] x [0, iw-1] contributes 0.

    batch_input:
        shape: [b,ih,iw,c] (any dtype castable to float32)
    batch_affine_grid:
        shape: [b,oh,ow,2], last axis = (h, w). Values are expected in the
        normalized range produced by affine_grid for an [oh, ow] output,
        i.e. [-(1-1/oh), 1-1/oh] x [-(1-1/ow), 1-1/ow].

    return:
        batch_result_image:
            shape: [b,oh,ow,c], float32

    The normalized grid is mapped linearly onto input pixel indices, so the
    ends of the range land on pixel 0 and pixel ih-1 (resp. iw-1). Each result
    value equals
        sum over input pixels (p, q) of
            max(0, 1-|h-p|) * max(0, 1-|w-q|) * batch_input[p, q],
    evaluated using only the 4 integer neighbours of (h, w).
    """
    batch_input_shape = tf.shape(batch_input)
    ih, iw = batch_input_shape[1], batch_input_shape[2]
    ih_max = tf.cast(ih - 1, dtype=tf.float32)
    iw_max = tf.cast(iw - 1, dtype=tf.float32)

    # Step 1: normalized grid -> [0, 1] -> input pixel coordinates.
    batch_affine_grid = tf.cast(batch_affine_grid, dtype=tf.float32)
    batch_affine_grid_shape = tf.shape(batch_affine_grid)
    oh, ow = batch_affine_grid_shape[1], batch_affine_grid_shape[2]
    oh_max, ow_max = tf.cast(1 - 1/oh, dtype=tf.float32), tf.cast(1 - 1/ow, dtype=tf.float32)
    oh_min, ow_min = -oh_max, -ow_max
    grid_min = tf.stack([oh_min, ow_min])  # [2,]
    grid_range = tf.stack([oh_max - oh_min, ow_max - ow_min])  # [2,]
    # With oh == 1 (or ow == 1) the range is 0. divide_no_nan yields 0 for
    # that axis (pixel row/column 0) instead of NaN.
    grid = tf.math.divide_no_nan(batch_affine_grid - grid_min, grid_range)
    grid = grid * tf.stack([ih_max, iw_max])  # [b,oh,ow,2]

    # Step 2: the 4 integer neighbours of each sample point and their weights.
    h = grid[..., 0:1]  # [b,oh,ow,1]
    w = grid[..., 1:2]  # [b,oh,ow,1]
    h1 = tf.floor(h)
    w1 = tf.floor(w)
    h2 = h1 + 1
    w2 = w1 + 1

    def _inside(x, x_max):
        # 1.0 where the neighbour index is inside the image, else 0.0.
        return tf.cast((x >= 0) & (x <= x_max), dtype=tf.float32)

    # Weights come from the *unclipped* neighbours. Clipping first would make
    # h1 == h2 on the border and zero both weights, e.g. blacking out the
    # last row/column. Out-of-image neighbours get weight 0 (zero padding).
    wh1 = (h2 - h) * _inside(h1, ih_max)
    wh2 = (h - h1) * _inside(h2, ih_max)
    ww1 = (w2 - w) * _inside(w1, iw_max)
    ww2 = (w - w1) * _inside(w2, iw_max)

    # Step 3: gather neighbour pixels. Indices are clipped only so gather_nd
    # stays in bounds; masked neighbours already carry weight 0.
    def _to_index(x, x_max):
        return tf.cast(tf.clip_by_value(x, 0., x_max), dtype=tf.int32)

    h1i, h2i = _to_index(h1, ih_max), _to_index(h2, ih_max)
    w1i, w2i = _to_index(w1, iw_max), _to_index(w2, iw_max)

    def _pixels(hi, wi):
        # [b,ih,iw,c] gathered at [b,oh,ow,2] -> [b,oh,ow,c]
        index = tf.concat([hi, wi], axis=-1)
        return tf.cast(tf.gather_nd(batch_input, index, batch_dims=1), dtype=tf.float32)

    # [b,oh,ow,c]
    fP = (wh1 * ww1 * _pixels(h1i, w1i) + wh1 * ww2 * _pixels(h1i, w2i)
          + wh2 * ww1 * _pixels(h2i, w1i) + wh2 * ww2 * _pixels(h2i, w2i))

    return fP
# Load the sample image from disk and convert it to a float array [h, w, c].
image = tf.keras.utils.img_to_array(tf.keras.utils.load_img('000005.jpg'))
# Optionally downscale first: image = tf.image.resize(image, [80, 80])
batch_nums = 64
# Replicate the single image along a new leading batch axis -> [batch_nums, h, w, c].
batch_image = tf.tile(tf.expand_dims(image, axis=0), multiples=[batch_nums, 1, 1, 1])
# Render the original image (shown as the notebook cell output).
tf.keras.utils.array_to_img(image)

(图:原始输入图像 000005.jpg)

# Inspect the batched input shape: [batch_nums, height, width, channels].
tf.shape(batch_image)
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([ 64, 375, 500,   3])>
# One 2x3 affine matrix: both sampling coordinates are halved about the centre,
# so the output is a 2x zoom into the middle of the image.
thetas = tf.constant([[0.5, 0., 0.], [0., 0.5, 0.]], dtype=tf.float32)
# The output grid size does not have to match the input image's height/width.
out_h, out_w = 375, 500
r = affine_grid(
    tf.tile(tf.expand_dims(thetas, axis=0), multiples=[batch_nums, 1, 1]),
    [batch_nums, out_h, out_w, 3],
)
r2 = grid_sample(batch_image, r)
img = tf.keras.utils.array_to_img(r2[0])
img

(图:theta 缩放系数为 0.5 时 r2[0] 的采样结果)

# Confirm the sampled batch has shape [batch_nums, out_h, out_w, channels].
r2.shape
TensorShape([64, 375, 500, 3])
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

起名大废废

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值