STN 代码(Spatial Transformer Networks)

# 定义localization网络
def get_localization_network(self):
    localization = tf.keras.Sequential([  # 卷积序列
        Conv2D(8, kernel_size=7, input_shape=self.img_shape,
                      activation="relu", kernel_initializer="he_normal"),
        MaxPool2D(strides=2),
        Conv2D(10, kernel_size=5, activation="relu", kernel_initializer="he_normal"),
        MaxPool2D(strides=2),
    ])
    return localization

# 定义3 * 2仿射矩阵的回归器
def get_affine_params(self):
    # 定义一个初始值为 [1, 0, 0, 0, 1, 0] 的偏置项.
    # | a  b  tx |
    # | c  d  ty |
    # a 和 d 是缩放因子,b 和 c 是旋转因子,tx 和 ty 是平移因子
    # 这个向量的前两个元素是a和b,中间两个元素是c和d,最后两个元素是tx和ty
    # 在神经网络中,偏置项是模型的可学习参数之一,用于在激活函数之前添加一个常数偏移
    # 在这个特定的矩阵中,它被初始化为一个单位矩阵,即无缩放、无旋转、无平移
    output_bias = tf.keras.initializers.Constant([1, 0, 0, 0, 1, 0])
    fc_loc = tf.keras.Sequential([
        layers.Dense(32, activation="relu", kernel_initializer="he_normal"),
        layers.Dense(3 * 2, kernel_initializer="zeros", bias_initializer=output_bias)
    ])

    return fc_loc

# 获取图像中某个坐标的像素值
def get_pixel_value(self, img, x, y):
    """
    Utility function to get pixel value for coordinate
    vectors x and y from a  4D tensor image.
    从4D张量图像中获取坐标向量x和y的像素值的效用函数
    Input
    -----
    - img: tensor of shape (B, H, W, C)
    - x: flattened tensor of shape (B*H*W,)
    - y: flattened tensor of shape (B*H*W,)
    Returns
    -------
    - output: tensor of shape (B, H, W, C)
    """
    shape = tf.shape(x)
    batch_size = shape[0]
    height = shape[1]
    width = shape[2]

    batch_idx = tf.range(0, batch_size)
    batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1))
    b = tf.tile(batch_idx, (1, height, width))

    indices = tf.stack([b, y, x], 3)

    return tf.gather_nd(img, indices)

# 仿射网格生成器
def affine_grid_generator(self, height, width, theta):
    """
    This function returns a sampling grid, which when
    used with the bilinear sampler on the input feature
    map, will create an output feature map that is an
    affine transformation [1] of the input feature map.
    该函数返回一个采样网格,当与双线性采样器一起在输入特征映射上使用时,
    将创建一个输出特征映射,该输出特征映射是输入特征映射的仿射变换。
    Input
    -----
    - height: desired height of grid/output. Used
      to downsample or upsample.
    - width: desired width of grid/output. Used
      to downsample or upsample.
    - theta: affine transform matrices of shape (num_batch, 2, 3).
      For each image in the batch, we have 6 theta parameters of
      the form (2x3) that define the affine transformation T.
    Returns
    -------
    - normalized grid (-1, 1) of shape (num_batch, 2, H, W).
      The 2nd dimension has 2 components: (x, y) which are the
      sampling points of the original image for each point in the
      target image.
    Note
    ----
    [1]: the affine transformation allows cropping, translation,
         and isotropic scaling.
    """
    num_batch = tf.shape(theta)[0]

    # create normalized 2D grid
    x = tf.linspace(-1.0, 1.0, width)
    y = tf.linspace(-1.0, 1.0, height)
    x_t, y_t = tf.meshgrid(x, y)

    # flatten
    x_t_flat = tf.reshape(x_t, [-1])
    y_t_flat = tf.reshape(y_t, [-1])

    # reshape to [x_t, y_t , 1] - (homogeneous form)
    ones = tf.ones_like(x_t_flat)
    sampling_grid = tf.stack([x_t_flat, y_t_flat, ones])

    # repeat grid num_batch times
    sampling_grid = tf.expand_dims(sampling_grid, axis=0)
    sampling_grid = tf.tile(sampling_grid, tf.stack([num_batch, 1, 1]))

    # cast to float32 (required for matmul)
    theta = tf.cast(theta, 'float32')
    sampling_grid = tf.cast(sampling_grid, 'float32')

    # transform the sampling grid - batch multiply
    batch_grids = tf.matmul(theta, sampling_grid)
    # batch grid has shape (num_batch, 2, H*W)

    # reshape to (num_batch, H, W, 2)
    batch_grids = tf.reshape(batch_grids, [num_batch, 2, height, width])

    return batch_grids

# 双线性采样器
def bilinear_sampler(self, img, x, y):
    """
    Performs bilinear sampling of the input images according to the
    normalized coordinates provided by the sampling grid. Note that
    the sampling is done identically for each channel of the input.
    To test if the function works properly, output image should be
    identical to input image when theta is initialized to identity
    transform.
    根据采样网格提供的归一化坐标对输入图像进行双线性采样。
    请注意,对输入的每个通道进行相同的采样。
    为了测试函数是否正常工作,当theta初始化为恒等变换时,
    输出图像应该与输入图像相同。
    Input
    -----
    - img: batch of images in (B, H, W, C) layout.
    - grid: x, y which is the output of affine_grid_generator.
    Returns
    -------
    - out: interpolated images according to grids. Same size as grid.
    """
    H = tf.shape(img)[1]
    W = tf.shape(img)[2]
    max_y = tf.cast(H - 1, 'int32')
    max_x = tf.cast(W - 1, 'int32')
    zero = tf.zeros([], dtype='int32')

    # rescale x and y to [0, W-1/H-1]
    x = tf.cast(x, 'float32')
    y = tf.cast(y, 'float32')
    x = 0.5 * ((x + 1.0) * tf.cast(max_x - 1, 'float32'))
    y = 0.5 * ((y + 1.0) * tf.cast(max_y - 1, 'float32'))

    # grab 4 nearest corner points for each (x_i, y_i)
    x0 = tf.cast(tf.floor(x), 'int32')
    x1 = x0 + 1
    y0 = tf.cast(tf.floor(y), 'int32')
    y1 = y0 + 1

    # clip to range [0, H-1/W-1] to not violate img boundaries
    x0 = tf.clip_by_value(x0, zero, max_x)
    x1 = tf.clip_by_value(x1, zero, max_x)
    y0 = tf.clip_by_value(y0, zero, max_y)
    y1 = tf.clip_by_value(y1, zero, max_y)

    # get pixel value at corner coords
    Ia = self.get_pixel_value(img, x0, y0)
    Ib = self.get_pixel_value(img, x0, y1)
    Ic = self.get_pixel_value(img, x1, y0)
    Id = self.get_pixel_value(img, x1, y1)

    # recast as float for delta calculation
    x0 = tf.cast(x0, 'float32')
    x1 = tf.cast(x1, 'float32')
    y0 = tf.cast(y0, 'float32')
    y1 = tf.cast(y1, 'float32')

    # calculate deltas
    wa = (x1 - x) * (y1 - y)
    wb = (x1 - x) * (y - y0)
    wc = (x - x0) * (y1 - y)
    wd = (x - x0) * (y - y0)

    # add dimension for addition
    wa = tf.expand_dims(wa, axis=3)
    wb = tf.expand_dims(wb, axis=3)
    wc = tf.expand_dims(wc, axis=3)
    wd = tf.expand_dims(wd, axis=3)

    # compute output
    out = tf.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id])

    return out

# 空间变换网络
def stn(self, x):
    localization = self.get_localization_network()
    fc_loc = self.get_affine_params()

    xs = localization(x)
    xs = tf.reshape(xs, (-1, 10 * 3 * 3))
    theta = fc_loc(xs)
    theta = tf.reshape(theta, (-1, 2, 3))

    grid = self.affine_grid_generator(28, 28, theta)
    x_s = grid[:, 0, :, :]
    y_s = grid[:, 1, :, :]
    x = self.bilinear_sampler(x, x_s, y_s)

    return x

输入输出都是尺寸相同的图片

  • 10
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值