空间变换网络STN

28 篇文章 9 订阅
7 篇文章 1 订阅

出自论文Spatial Transformer Networks 

Insight:

文章提出的STN的作用类似于传统的矫正的作用。比如人脸识别中,需要先对检测的图片进行关键点检测,然后使用关键点来进行对齐操作。但是这样的一个过程是需要额外进行处理的。但是有了STN后,检测完的人脸,直接就可以做对齐操作。关键的一点就是这个矫正过程是可以进行梯度传导的。想象一下,人脸检测完了,直接使用ROI pooling取出人脸的feature map,输入STN就可以进行矫正,输出矫正后的人脸。后面还可以再接点卷积操作,直接就可以进行分类,人脸识别的训练。整个流程从理论上来说,都有梯度传导,理论上可以将检测+对齐+识别使用一个网络实现。当然实际操作中可能会有各种trick。

空间变换基础:

图像的几何变换包括透视变换和仿射变换,透视变换又称为投影变换、投射变换、投影映射,透视变换是将图片投影到一个新的视平面,它是二维(x,y)到三维(X,Y,Z)、再到另一个二维(x’,y’)空间的映射。

仿射变换又称为图像仿射映射,可以认为是透视变换的一种特殊情况,是透视变换的子集,仿射变换是从二维空间到自身的映射,是指在几何中,一个向量空间进行一次线性变换并接上一个平移,变换为另一个向量空间,也就是图像仿射变换等于图像线性变换和平移的组合。

仿射变换包括平移(translation)、旋转(rotation)、缩放(scaling)、错切(shear )四种类型:

平移和旋转两者的组合不改变图像的大小和形状,只有图像的位置(平移变换)和朝向(旋转变换)发生改变,称之为欧式变换(Euclidean transformation)或刚体变换(rigid transformation),刚性变换是最一般的变换

缩放又分为等比例缩放(uniform scaling)和非等比例缩放(non-uniform scaling),如果缩放系数为负数,则会叠加翻转(reflection,又翻译为反射、镜像),因此翻转可以看成是特殊的缩放

欧式变换和等比例缩放保持了图像外观没有变形,因此二者的组合称为相似变换(similarity transformation)

错切是保持图形上各点的某一坐标值不变,而另一坐标值关于该保持不变坐标值进行线性变换。坐标不变的轴称为依赖轴,其余坐标轴称为方向轴。错切分为水平错切和垂直错切。

2D仿射变换(affine):

平移:

旋转:

缩放:

整体:

3D透视变换(projection):

平移:

旋转:

缩放:

Opencv函数

1、getAffineTransform函数

getAffineTransform通过确认源图像中不在同一直线的三个点对应的目标图像的位置,来获取对应仿射变换矩阵,从而用该仿射变换矩阵对图像进行统一的仿射变换。

特别注意getAffineTransform只支持非等比例缩放也就是说xy的缩放是不一样的但是有些场合可能需要等比例的缩放比如换脸任务需要保证长宽比不变比如人脸识别任务如果保证长宽比,效果会更佳

调用语法

retval = cv.getAffineTransform(src, dst)

语法说明

src:源图像中三角形顶点的坐标,也就是在源图像中任找不在同一直线上的三个点,将三个点的坐标作为三个元素放到src对应列表中

dst:目标图像中相应三角形顶点的坐标,也就是三个点在变换后图像中的坐标列表,要求与源图像三个点一一对应

retval返回值:从三对对应的点计算出来的仿射变换矩阵

2、warpAffine函数

在OpenCV中,仿射变换可以通过函数warpAffine来支持,当然部分单独的函数也可以进行某个特定的变换,如缩放和旋转就有单独的变换函数。

调用语法

warpAffine(src, M, dsize, dst=None, flags=None, borderMode=None, borderValue=None)

语法说明

src:输入图像矩阵

M:2*3的仿射变换矩阵,可以自己构建,也可以通过getAffineTransform、getRotationMatrix2D等函数获取

dsize:结果图像大小,为宽和高的二元组

dst:输出结果图像,可以省略,结果图像会作为函数处理结果输出

flags:可选参数,插值方法的组合(int 类型),默认值 INTER_LINEAR

borderMode:可选参数,边界像素模式(int 类型),默认值 BORDER_CONSTANT

borderValue:可选参数,边界填充值,当borderMode为cv2.BORDER_CONSTANT时使用,默认值为None。另外当borderMode取值为cv2.BORDER_TRANSPARENT时,目标图像中与源图像中的离群值相对应的像素不被函数修改(关于离群值老猿暂还未完全弄明白,暂且存疑)

返回值:为仿射变换后的结果图像矩阵,最后的结果矩阵每个像素与原图像像素的对应关系为:

如果flags标记设置了WARP_INVERSE_MAP标记,首先使用invertAffineTransform对变换矩阵进行反转即求其逆矩阵,然后将其放入上面的公式中,而不是将M直接放入。例如如果变换矩阵是顺时针旋转30°,则设置WARP_INVERSE_MAP标记的情况下实际变换效果是逆时针旋转30°。这样做的目的是为了已知目标图像和变换方法的情况下,可以求出原图像。所以这个标记很重要

返回值:仿射变换后的结果图像

3、getRotationMatrix2D

getRotationMatrix2D用于获取一个旋转二维图像的仿射变换矩阵。

调用语法

getRotationMatrix2D(center, angle, scale)

参数语法说明

center:图像旋转的旋转参考中心点坐标二元组

angle:旋转角度,坐标原点为左上角情况下正值表示逆时针旋转

scale:等比例缩放因子

返回值

getRotationMatrix2D返回值为一2*3复合旋转变换仿射矩阵。按照OpenCV官方介绍,getRotationMatrix2D得到的矩阵为:

 

绕指定点旋转进行组合变换时,参考点p(m,n)顺时针旋转θ的组合变换的齐次坐标表示公式为:

上述公式中θ为正表示是顺时针旋转,与getRotationMatrix2D中的angle参数取值方式相反,由于cos(-θ)=cosθ,sin(-θ)=-sinθ,因此实际上getRotationMatrix2D中旋转正值的角度时对应的上述矩阵计算公式中sin函数前面的符号需要取反(正号变副号、副号变正号)。

而缩放的齐次坐标表示公式为:

 

用缩放矩阵左乘平移矩阵则可以得到顺时针旋转同时进行缩放的齐次坐标表示公式: 

 

当等比例缩放且缩放因子等于s时,上述公式中的kx、ky使用s替换。则最后的组合变换矩阵为: 

 可以看到,将getRotationMatrix2D的参数angle(逆时针旋转为正)的角度变为上述组合矩阵变换公式中的-θ(顺时针旋转为正)、getRotationMatrix2D中的center.x、center.y分别使用m、n替换,取组合变换矩阵的前2行,则二者结果等价。因此getRotationMatrix2D函数获得的变换矩阵和上述组合变换矩阵连乘的结果相同。

代码演示

方法1,(基于opencv,非刚性变换):

#src_pts,np.array,shape[-1,2]
#ref_pts,np.array,shape[-1,2]
tfm = cv2.getAffineTransform(src_pts, ref_pts)        #变换矩阵
tfm_inv = np.linalg.inv(np.vstack([tfm, [0, 0, 1]]))[0:2]   #变换矩阵的逆矩阵
warped_image = cv2.warpAffine(image, tfm , (width, height)) #图片变换

five_points_tfmed = []
for value in five_points:     #点变换five_points[[x1,y1],[x2,y2],[x3,y3],[x4,y4],[x5,y5]]
    tmp = np.ones((3,1),np.float32)
    tmp[0,0] = value[0]
    tmp[1,0] = value[1]
aff_tmp = np.dot(tfm , tmp)
five_points_tfmed.append([aff_tmp[0,0], aff_tmp[1,0]])

方法2,(基于最小二乘法,非刚性变换):

def get_affine_transform_matrix(src_pts, dst_pts):
    tfm = np.float32([[1, 0, 0], [0, 1, 0]])
    n_pts = src_pts.shape[0]
    ones = np.ones((n_pts, 1), src_pts.dtype)
    src_pts_ = np.hstack([src_pts, ones])
    dst_pts_ = np.hstack([dst_pts, ones])

    A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_)

    if rank == 3:
        tfm = np.float32([
            [A[0, 0], A[1, 0], A[2, 0]],
            [A[0, 1], A[1, 1], A[2, 1]]
        ])
    elif rank == 2:
        tfm = np.float32([
            [A[0, 0], A[1, 0], 0],
            [A[0, 1], A[1, 1], 0]
        ])

    tfm_inv = np.linalg.inv(np.vstack([tfm, [0, 0, 1]]))[0:2]
    return tfm, tfm_inv

#调用
tfm, tfm_inv = get_affine_transform_matrix(src_pts, ref_pts)

方法3,(刚性变换):

import numpy as np
from numpy.linalg import inv, norm, lstsq
from numpy.linalg import matrix_rank as rank

class MatlabCp2tormException(Exception):
    def __str__(self):
        return 'In File {}:{}'.format(
                __file__, super.__str__(self))

def tformfwd(trans, uv):
    """
    Function:
    ----------
        apply affine transform 'trans' to uv

    Parameters:
    ----------
        @trans: 3x3 np.array
            transform matrix
        @uv: Kx2 np.array
            each row is a pair of coordinates (x, y)

    Returns:
    ----------
        @xy: Kx2 np.array
            each row is a pair of transformed coordinates (x, y)
    """
    uv = np.hstack((
        uv, np.ones((uv.shape[0], 1))
    ))
    xy = np.dot(uv, trans)
    xy = xy[:, 0:-1]
    return xy


def tforminv(trans, uv):
    """
    Function:
    ----------
        apply the inverse of affine transform 'trans' to uv

    Parameters:
    ----------
        @trans: 3x3 np.array
            transform matrix
        @uv: Kx2 np.array
            each row is a pair of coordinates (x, y)

    Returns:
    ----------
        @xy: Kx2 np.array
            each row is a pair of inverse-transformed coordinates (x, y)
    """
    Tinv = inv(trans)
    xy = tformfwd(Tinv, uv)
    return xy


def findNonreflectiveSimilarity(uv, xy, options=None):

    options = {'K': 2}

    K = options['K']
    M = xy.shape[0]
    x = xy[:, 0].reshape((-1, 1))  # use reshape to keep a column vector
    y = xy[:, 1].reshape((-1, 1))  # use reshape to keep a column vector
    # print('--->x, y:\n', x, y

    tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
    tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
    X = np.vstack((tmp1, tmp2))
    # print('--->X.shape: ', X.shape
    # print('X:\n', X

    u = uv[:, 0].reshape((-1, 1))  # use reshape to keep a column vector
    v = uv[:, 1].reshape((-1, 1))  # use reshape to keep a column vector
    U = np.vstack((u, v))
    # print('--->U.shape: ', U.shape
    # print('U:\n', U

    # We know that X * r = U
    if rank(X) >= 2 * K:
        r, _, _, _ = lstsq(X, U)
        r = np.squeeze(r)
    else:
        raise Exception('cp2tform:twoUniquePointsReq')

    # print('--->r:\n', r

    sc = r[0]
    ss = r[1]
    tx = r[2]
    ty = r[3]

    Tinv = np.array([
        [sc, -ss, 0],
        [ss,  sc, 0],
        [tx,  ty, 1]
    ])

    # print('--->Tinv:\n', Tinv

    T = inv(Tinv)
    # print('--->T:\n', T

    T[:, 2] = np.array([0, 0, 1])

    return T, Tinv


def findSimilarity(uv, xy, options=None):

    options = {'K': 2}

#    uv = np.array(uv)
#    xy = np.array(xy)

    # Solve for trans1
    trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options)

    # Solve for trans2

    # manually reflect the xy data across the Y-axis
    xyR = xy
    xyR[:, 0] = -1 * xyR[:, 0]

    trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options)

    # manually reflect the tform to undo the reflection done on xyR
    TreflectY = np.array([
        [-1, 0, 0],
        [0, 1, 0],
        [0, 0, 1]
    ])

    trans2 = np.dot(trans2r, TreflectY)

    # Figure out if trans1 or trans2 is better
    xy1 = tformfwd(trans1, uv)
    norm1 = norm(xy1 - xy)

    xy2 = tformfwd(trans2, uv)
    norm2 = norm(xy2 - xy)

    if norm1 <= norm2:
        return trans1, trans1_inv
    else:
        trans2_inv = inv(trans2)
        return trans2, trans2_inv


def get_similarity_transform(src_pts, dst_pts, reflective=True):
    """
    Function:
    ----------
        Find Similarity Transform Matrix 'trans':
            u = src_pts[:, 0]
            v = src_pts[:, 1]
            x = dst_pts[:, 0]
            y = dst_pts[:, 1]
            [x, y, 1] = [u, v, 1] * trans

    Parameters:
    ----------
        @src_pts: Kx2 np.array
            source points, each row is a pair of coordinates (x, y)
        @dst_pts: Kx2 np.array
            destination points, each row is a pair of transformed
            coordinates (x, y)
        @reflective: True or False
            if True:
                use reflective similarity transform
            else:
                use non-reflective similarity transform

    Returns:
    ----------
       @trans: 3x3 np.array
            transform matrix from uv to xy
        trans_inv: 3x3 np.array
            inverse of trans, transform matrix from xy to uv
    """

    if reflective:
        trans, trans_inv = findSimilarity(src_pts, dst_pts)
    else:
        trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts)

    return trans, trans_inv


def cvt_tform_mat_for_cv2(trans):
    """
    Function:
    ----------
        Convert Transform Matrix 'trans' into 'cv2_trans' which could be
        directly used by cv2.warpAffine():
            u = src_pts[:, 0]
            v = src_pts[:, 1]
            x = dst_pts[:, 0]
            y = dst_pts[:, 1]
            [x, y].T = cv_trans * [u, v, 1].T

    Parameters:
    ----------
        @trans: 3x3 np.array
            transform matrix from uv to xy

    Returns:
    ----------
        @cv2_trans: 2x3 np.array
            transform matrix from src_pts to dst_pts, could be directly used
            for cv2.warpAffine()
    """
    cv2_trans = trans[:, 0:2].T

    return cv2_trans


def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):
    """
    Function:
    ----------
        Find Similarity Transform Matrix 'cv2_trans' which could be
        directly used by cv2.warpAffine():
            u = src_pts[:, 0]
            v = src_pts[:, 1]
            x = dst_pts[:, 0]
            y = dst_pts[:, 1]
            [x, y].T = cv_trans * [u, v, 1].T

    Parameters:
    ----------
        @src_pts: Kx2 np.array
            source points, each row is a pair of coordinates (x, y)
        @dst_pts: Kx2 np.array
            destination points, each row is a pair of transformed
            coordinates (x, y)
        reflective: True or False
            if True:
                use reflective similarity transform
            else:
                use non-reflective similarity transform

    Returns:
    ----------
        @cv2_trans: 2x3 np.array
            transform matrix from src_pts to dst_pts, could be directly used
            for cv2.warpAffine()
    """
    trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective)
    cv2_trans = cvt_tform_mat_for_cv2(trans)
    cv2_trans_inv = cvt_tform_mat_for_cv2(trans_inv)

    return cv2_trans, cv2_trans_inv


if __name__ == '__main__':
    """
    u = [0, 6, -2]
    v = [0, 3, 5]
    x = [-1, 0, 4]
    y = [-1, -10, 4]

    # In Matlab, run:
    #
    #   uv = [u'; v'];
    #   xy = [x'; y'];
    #   tform_sim=cp2tform(uv,xy,'similarity');
    #
    #   trans = tform_sim.tdata.T
    #   ans =
    #       -0.0764   -1.6190         0
    #        1.6190   -0.0764         0
    #       -3.2156    0.0290    1.0000
    #   trans_inv = tform_sim.tdata.Tinv
    #    ans =
    #
    #       -0.0291    0.6163         0
    #       -0.6163   -0.0291         0
    #       -0.0756    1.9826    1.0000
    #    xy_m=tformfwd(tform_sim, u,v)
    #
    #    xy_m =
    #
    #       -3.2156    0.0290
    #        1.1833   -9.9143
    #        5.0323    2.8853
    #    uv_m=tforminv(tform_sim, x,y)
    #
    #    uv_m =
    #
    #        0.5698    1.3953
    #        6.0872    2.2733
    #       -2.6570    4.3314
    """
    u = [0, 6, -2]
    v = [0, 3, 5]
    x = [-1, 0, 4]
    y = [-1, -10, 4]

    uv = np.array((u, v)).T
    xy = np.array((x, y)).T

    print('\n--->uv:')
    print(uv)
    print('\n--->xy:')
    print(xy)

    trans, trans_inv = get_similarity_transform(uv, xy)

    print('\n--->trans matrix:')
    print(trans)

    print('\n--->trans_inv matrix:')
    print(trans_inv)

    print('\n---> apply transform to uv')
    print('\nxy_m = uv_augmented * trans')
    uv_aug = np.hstack((
        uv, np.ones((uv.shape[0], 1))
    ))
    xy_m = np.dot(uv_aug, trans)
    print(xy_m)

    print('\nxy_m = tformfwd(trans, uv)')
    xy_m = tformfwd(trans, uv)
    print(xy_m)

    print('\n---> apply inverse transform to xy')
    print('\nuv_m = xy_augmented * trans_inv')
    xy_aug = np.hstack((
        xy, np.ones((xy.shape[0], 1))
    ))
    uv_m = np.dot(xy_aug, trans_inv)
    print(uv_m)

    print('\nuv_m = tformfwd(trans_inv, xy)')
    uv_m = tformfwd(trans_inv, xy)
    print(uv_m)

    uv_m = tforminv(trans, xy)
    print('\nuv_m = tforminv(trans, xy)')
    print(uv_m)


#调用
tfm, tfm_inv = get_similarity_transform_for_cv2(src_pts, ref_pts)

STN网络结构:

STN网络由Localisation Network ,Grid generator,Sampler,3个部分组成。

Localisation Network:

该网络就是一个简单的回归网络。将输入的图片进行几个卷积操作,然后全连接回归出6个角度值(假设是仿射变换),2*3的矩阵。

Grid generator:

网格生成器负责将V中的坐标位置,通过矩阵运算,计算出目标图V中的每个位置对应原图U中的坐标位置。即生成T(G)。

这里的Grid采样过程,对于二维仿射变换(旋转,平移,缩放)来说,就是简单的矩阵运算。

上式中,s代表原始图的坐标,t代表目标图的坐标。A为Localisation Network网络回归出的6个角度值。

整个Grid生成过程就是,首先你需要想象上图中V-FeatureMap中全是白色或者全是黑色,是没有像素信息的。也就是说V-FeatureMap还不存在,有的只是V-FeatureMap的坐标位置信息。然后将目标图V-FeatureMap中的比如(0,0)(0,1)......位置的坐标,与2*3变换矩阵运算。就会生成出在原始图中对应的坐标信息,比如(5,0)(5,1)......。这样所有的目标图的坐标都经过这样的运算就会将每个坐标都产生一个与之对应的原图的坐标,即T(G)。然后通过T(G)和原始图U-FeatureMap的像素,将原始图中的像素复制到V-FeatureMap中,从而生成目标图的像素。

Sampler:

采样器根据T(G)中的坐标信息,在原始图U中进行采样,将U中的像素复制到目标图V中。

实验结果:

作者分别在MNIST,Street View House Numbers ,CUB-200-2011 birds dataset 这3个数据集上做了实验。

MNIST实验:

R:rotation (旋转)

RTS:rotation, scale and translation (旋转,缩放,平移)

P:projective transformation (投影)

E:elastic warping (弹性变形)

从作图可以看出,FCN 错误率为13.2% , CNN 错误率为3.5% , 与之对比的 ST-FCN 错误率为2.0% ,ST-CNN 错误率为 1.7%。可以看出STN的效果还是非常明显的。

Street View House Numbers实验:

可以看出不管是64像素还是128像素,ST-CNN比传统的CNN错误率要低一些。

CUB-200-2011 birds dataset实验:

右图红色框检测头部,绿色框检测身体。

这个数据集是属于细粒度分类的一个数据集。好多做细粒度分类的文章都会在该数据集上做实验。从这个实验可以看出,STN可以有attention的效果,可以训练的更加关注ROI区域。

实验结果有0.8个百分点的提升。

References:

https://github.com/kevinzakka/spatial-transformer-network

评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值