Spatial Transformer Networks并不只是对旋转有效,理论上它应该是对所有的变换都有效,如平移旋转缩放。
CNN目前如果希望网络能具有旋转不变性,主流是两个途径,一直接对数据进行数据增强,即各种旋转,让网络见各种各样的数据,二即采用STN。
STN总体的思路很简单,就是相当于增加一个模块,模块的作用是对输入的图像得到一个变换矩阵,然后输出的图像就等于变换后的图像。
但其实从NN的角度看来,这并不是这样一个完美的变换过程,而是说寻找一个最能让图像像标签的方法。比如输入一张MNIST中的3,但是是转了90度的,在训练中也告诉NN这是3,而其实NN是没见过这种东西的,他见过的大都是正常的正立的3,所以这时在整个网络中,NN会对所有参数进行调节,但可能这时STN模块的调节会更加有效。这其实就相当于一个比较弱的目标探测模块。
从代码可以看出,STN并不简单,里面包含了两个卷积层两个池化,和LeNET5很像。
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
# Spatial transformer localization-network
self.localization = nn.Sequential(
nn.Conv2d(1, 8, kernel_size=7),
nn.MaxPool2d(2, stride=2),
nn.ReLU(True),
nn.Conv2d(8, 10, kernel_size=5),
nn.MaxPool2d(2, stride=2),
nn.ReLU(True)
)
# Regressor for the 3 * 2 affine matrix
self.fc_loc = nn.Sequential(
nn.Linear(10 * 3 * 3, 32),
nn.ReLU(True),
nn.Linear(32, 3 * 2)
)
# Initialize the weights/bias with identity transformation
self.fc_loc[2].weight.data.zero_()
self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))