- 博主简介:努力学习的22级计算机科学与技术本科生一枚????
- 博主主页: @Yaoyao2024
- 往期回顾: 【机器学习】有监督学习·由浅入深讲解分类算法·Fisher算法讲解
- 每日一言????: 今天不想跑,所以才去跑,这才是长距离者的思维。
——村上春树
本文是对Google DeepMind 团队2015年发表的空间变换网络STN的详细讲解,作为初学者也是参考了很多博客,都在本文末尾给出,感谢前辈们的努力。
空间变换网络(Spatial Transformer Networks,简称STN)是一种深度学习模型,旨在增强网络对几何变换的适应能力。STN是由Max Jaderberg等人在2015年提出的,其核心思想是在传统的卷积神经网络(CNN)中嵌入一个可学习的模块,该模块能够显式地对输入图像进行空间变换,从而使得网络能够对输入图像的几何变形具有更好的适应性。STN的引入使得网络能够自动进行图像的校正,例如旋转、缩放、剪切等,这在很多视觉任务中是非常有用的,如图像识别、目标检测和图像分割等。
一、为什么提出(Why)
-
一个理想中的模型:我们希望鲁棒的图像处理模型具有空间不变性,当目标发生某种转化后,模型依然能给出同样的正确的结果
-
什么是空间不变性:举例来说,如下图所示,假设一个模型能准确把左图中的人物分类为凉宫春日,当这个目标做了放大、旋转、平移后,模型仍然能够正确分类,我们就说这个模型在这个任务上具有尺度不变性,旋转不变性,平移不变性
-
CNN在这方面的能力是不足的:maxpooling的机制给了CNN一点点这样的能力,当目标在池化单元内任意变换的话,激活的值可能是相同的,这就带来了一点点的不变性。但是池化单元一般都很小(一般是2*2),只有在深层的时候特征被处理成很小的feature map的时候这种情况才会发生
-
Spatial Transformer:本文提出的空间变换网络STN(Spatial Transformer Networks)STN可以使模型学习平移、缩放、旋转和更通用的扭曲的不变性。(二维空间变换网络)
二、STN是什么(What)
- STN对feature map(包括输入图像)进行空间变换,输出一张新的图像。
- 我们希望STN对feature map进行变换后能把图像纠正到成理想的图像,然后丢进NN去识别,举例来说,如下图所示,输入模型的图像可能是摆着各种姿势,摆在不同位置的凉宫春日,我们希望STN把它纠正到图像的正*,放大,占满整个屏幕,然后再丢进CNN去识别。
- 这个网络可以作为单独的模块,可以在CNN的任何地方插入(即插即用),所以STN的输入不止是输入图像,可以是CNN中间层的feature map
三、STN是怎么做的(How)
STN可以通过为每个输入样本生成适当的变换来主动对图像(或特征图)进行空间变换。然后在整个特征图上(非局部)执行变换,并且可以包括缩放、裁剪、旋转以及非刚性变形。这使得包含空间变换器的网络不仅可以选择图像中最相关(注意力)的区域,还可以将这些区域转换为规范的预期姿势,以简化后续层中的推理。
如上图所示,STN的输入为 U U U,输出为 V V V,因为输入可能是中间层的feature map,所以画成了立方体(多channel),STN主要分为下述三个步骤
-
定位网络(Localization Network)
:这一部分是STN的核心,其任务是学习输入图像的空间变换参数。定位网络可以是任意的网络结构,它接受输入图像,并输出空间变换所需的参数。这些参数定义了一个变换矩阵,用于调整图像的空间位置。(是一个自己定义的网络,它输入 U U U,输出变化参数 Θ \Theta Θ,这个参数用来映射 U U U和 V V V的坐标关系)。 -
网格生成器(Grid Generator)
:接收定位网络输出的变换参数,并生成一个对应于输出图像的坐标网格。这个坐标网格对应于输入图像中的每一个像素位置。根据 V V V中的坐标点和变化参数 Θ \Theta Θ,计算出 U U U中的坐标点。这里是因为 V V V的大小是自己先定义好的,当然可以得到 V V V的所有坐标点,而填充 V V V 中每个坐标点的像素值的时候,要从 U U U中去取,所以根据 V V V中每个坐标点和变化参数 Θ \Theta Θ进行运算,得到一个坐标。在sampler中就是根据这个坐标去 U U U中找到像素值,这样子来填充 V V V -
Sampler
:要做的是填充 V V V,根据Grid generator得到的一系列坐标和原图 U U U(因为像素值要从 U U U中取)来填充,因为计算出来的坐标可能为小数,要用另外的方法来填充,比如双线性插值。从输入图像中采样像素来产生变换后的输出图像。这一步骤确保了图像的空间变换是可微分的,从而可以通过反向传播算法进行训练。
下面针对每个模块阐述一下
1、Localisation net
这个模块就是输入 U U U,输出一个变换参数 Θ \Theta Θ,那么这个 Θ \Theta Θ具体是指什么呢?
我们知道线性代数里,图像的平移,旋转和缩放都可以用矩阵运算来做
举例来说,如果想放大图像中的目标,可以这么运算,把(x,y)中的像素值填充到(x’,y’)上去,比如把原来(2,2)上的像素点,填充到(4,4)上去。
[
x
′
y
′
]
=
[
2
0
0
2
]
[
x
y
]
+
[
0
0
]
\begin{bmatrix}x^{'}\\y^{'}\end{bmatrix}=\begin{bmatrix}2&0\\0&2\end{bmatrix}\begin{bmatrix}x\\y\end{bmatrix}+\begin{bmatrix}0\\0\end{bmatrix}
[x′y′]=[2002][xy]+[00]
如果想旋转图像中的目标,可以这么运算(可以在极坐标系中推出来,证明放到最后的附录)
[
x
′
y
′
]
=
[
c
o
s
Θ
−
s
i
n
Θ
s
i
n
Θ
c
o
s
Θ
]
[
x
y
]
+
[
0
0
]
\begin{bmatrix}x^{'}\\y^{'}\end{bmatrix}=\begin{bmatrix}cos\Theta&-sin\Theta\\sin\Theta&cos\Theta\end{bmatrix}\begin{bmatrix}x\\y\end{bmatrix}+\begin{bmatrix}0\\0\end{bmatrix}
[x′y′]=[cosΘsinΘ−sinΘcosΘ][xy]+[00]
这些都是属于仿射变换(affine transformation)
[ x ′ y ′ ] = [ a b c d ] [ x y ] + [ e f ] \begin{bmatrix}x^{^{\prime}}\\y^{^{\prime}}\end{bmatrix}=\begin{bmatrix}a&b\\c&d\end{bmatrix}\begin{bmatrix}x\\y\end{bmatrix}+\begin{bmatrix}e\\f\end{bmatrix} [x′y′]=[acbd][xy]+[ef]
在仿射变化中,变化参数就是这6个变量, Θ = { a , b , c , d , e , f } (此 Θ 跟上述旋转变化里的角度 Θ 无关) \Theta=\{a,b,c,d,e,f\}\text{(此}\Theta\text{跟上述旋转变化里的角度}\Theta\text{无关)} Θ={a,b,c,d,e,f}(此Θ跟上述旋转变化里的角度Θ无关)
这6个变量就是用来映射输入图和输出图之间的坐标点的关系的,我们在第二步grid generator就要根据这个变化参数,来获取原图的坐标点。
总结如下:
- 功能:定位网络的主要任务是预测空间变换的参数。根据输入图像,这个网络会输出一组参数,这些参数定义了一个空间变换,可以是平移、旋转、缩放等或者更复杂的仿射变换或者非线性变换。
- 结构:定位网络通常是一个小型的卷积神经网络或全连接网络,其具体结构可以根据任务的复杂度和输入数据的特性来定制。网络的输出大小是固定的,对应于特定变换所需的参数数量。
2、Grid generator
有了第一步的变化参数,这一步是做个矩阵运算,这个运算是 以目标图 V V V的所有坐标点为自变量,以为参数做一个矩阵运算,得到输入图 U U U的坐标点
( x i s y i s ) = Θ ( x i t y i t 1 ) = [ Θ 11 Θ 12 Θ 13 Θ 21 Θ 22 Θ 23 ] ( x i t y i t 1 ) \begin{pmatrix}x_i^s\\y_i^s\end{pmatrix}=\Theta\begin{pmatrix}x_i^t\\y_i^t\\1\end{pmatrix}=\begin{bmatrix}\Theta_{11}&\Theta_{12}&\Theta_{13}\\\Theta_{21}&\Theta_{22}&\Theta_{23}\end{bmatrix}\begin{pmatrix}x_i^t\\y_i^t\\1\end{pmatrix} (xisyis)=Θ xityit1 =[Θ11Θ21Θ12Θ22Θ13Θ23