视觉Transformer经典论文——ViT、DeiT的与原理解读与实现
最近ChatGPT、文心一言等大模型爆火,追究其原理还是绕不开2017年提出的Transformer结构。Transformer算法自从提出后,在各个领域的相关工作还是非常多的,这里分享之前在其他平台的一篇笔记给大家,详细解读CV领域的两个经典Transformer系列工作——ViT和DeiT。
ViT算法综述
论文地址:An Image is Worth 16x16 Words:Transformers for Image Recognition at Scale
之前的算法大都是保持CNN整体结构不变,在CNN中增加attention模块或者使用attention模块替换CNN中的某些部分。ViT算法中,作者提出没有必要总是依赖于CNN,仅仅使用Transformer结构也能够在图像分类任务中表现很好。
受到NLP领域中Transformer成功应用的启发,ViT算法中尝试将标准的Transformer结构直接应用于图像,并对整个图像分类流程进行最少的修改。具体来讲,ViT算法中,会将整幅图像拆分成小图像块,然后把这些小图像块的线性嵌入序列作为Transformer的输入送入网络,然后使用监督学习的方式进行图像分类的训练。ViT算法的整体结构如 图1 所示。
图1:ViT算法结构示意图
该算法在中等规模(例如ImageNet)以及大规模(例如ImageNet-21K、JFT-300M)数据集上进行了实验验证,发现:
- Tranformer相较于CNN结构,缺少一定的平移不变性和局部感知性,因此在数据量不充分时,很难达到同等的效果。具体表现为使用中等规模的ImageNet训练的Tranformer会比ResNet在精度上低几个百分点。
- 当有大量的训练样本时,结果则会发生改变。使用大规模数据集进行预训练后,再使用迁移学习的方式应用到其他数据集上,可以达到或超越当前的SOTA水平。
图2 为大家展示了使用大规模数据集预训练后的 ViT 算法,迁移到其他小规模数据集进行训练,与使用 CNN 结构的SOTA算法精度对比。
图2:ViT模型精度对比
图中前3列为不同尺度的ViT模型,使用不同的大规模数据集进行预训练,并迁移到各个子任务上的结果。第4列为BiT算法基于JFT-300M数据集预训练后,迁移到各个子任务的结果。第5列为2020年提出的半监督算法 Noisy Student 在 ImageNet 和 ImageNet ReaL 数据集上的结果。
说明:
BiT 与 Noisy Student 均为2020年提出的 SOTA 算法。
BiT算法:使用大规模数据集 JFT-300M 对 ResNet 结构进行预训练,其中,作者发现模型越大,预训练效果越好,最终指标最高的为4倍宽、152层深的 R e s N e t 152 × 4 ResNet152 \times 4 ResNet152×4。论文地址:Big Transfer (BiT): General Visual Representation Learning
Noisy Student 算法:使用知识蒸馏的技术,基于 EfficientNet 结构,利用未标签数据,提高训练精度。论文地址:Self-training with Noisy Student improves ImageNet classification
接下来,分别看一下ViT算法的各个组成部分。
图像分块嵌入
考虑到之前课程中学习的,Transformer结构中,输入需要是一个二维的矩阵,矩阵的形状可以表示为 ( N , D ) (N,D) (N,D),其中 N N N 是sequence的长度,而 D D D 是sequence中每个向量的维度。因此,在ViT算法中,首先需要设法将 H × W × C H \times W \times C H×W×C 的三维图像转化为 ( N , D ) (N,D) (N,D) 的二维输入。
ViT中的具体实现方式为:将 H × W × C H \times W \times C H×W×C 的图像,变为一个 N × ( P 2 ∗ C ) N \times (P^2 * C) N×(P2∗C) 的序列。这个序列可以看作是一系列展平的图像块,也就是将图像切分成小块后,再将其展平。该序列中一共包含了 N = H W / P 2 N=HW/P^2 N=HW/P2 个图像块,每个图像块的维度则是 ( P 2 ∗ C ) (P^2*C) (P2∗C)。其中 P P P 是图像块的大小, C C C 是通道数量。经过如上变换,就可以将 N N N 视为sequence的长度了。
但是,此时每个图像块的维度是 ( P 2 ∗ C ) (P^2*C) (P2∗C),而我们实际需要的向量维度是 D D D,因此我们还需要对图像块进行 Embedding。这里 Embedding 的方式非常简单,只需要对每个 ( P 2 ∗ C ) (P^2*C) (P2∗C) 的图像块做一个线性变换,将维度压缩为 D D D 即可。
上述对图像进行分块以及 Embedding 的具体方式如 图3 所示。
图3:图像分块嵌入示意图
具体代码实现如下所示。其中,使用了大小为 P P P 的卷积来代替对每个大小为 P P P 图像块展平后使用全连接进行运算的过程。
# coding=utf-8
# 导入环境
import os
import numpy as np
import cv2
from PIL import Image
import paddle
from paddle.io import Dataset
from paddle.nn import Conv2D, MaxPool2D, Linear, Dropout, BatchNorm, AdaptiveAvgPool2D, AvgPool2D
import paddle.nn.functional as F
import paddle.nn as nn
# 图像分块、Embedding
class PatchEmbed(nn.Layer):
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
# 原始大小为int,转为tuple,即:img_size原始输入224,变换后为[224,224]
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
# 图像块的个数
num_patches = (img_size[1] // patch_size[1]) * \
(img_size[0] // patch_size[0])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
# kernel_size=块大小,即每个块输出一个值,类似每个块展平后使用相同的全连接层进行处理
# 输入维度为3,输出维度为块向量长度
# 与原文中:分块、展平、全连接降维保持一致
# 输出为[B, C, H, W]
self.proj = nn.Conv2D(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
# [B, C, H, W] -> [B, C, H*W] ->[B, H*W, C]
x = self.proj(x).flatten(2).transpose((0, 2, 1))
return x
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
Multi-head Attention
将图像转化为 N × ( P 2 ∗ C ) N \times (P^2 * C) N×(P2∗C) 的序列后,就可以将其输入到 Tranformer 结构中进行特征提取了。 Tranformer 结构中最重要的结构就是 Multi-head Attention,即多头注意力结构,如 图4 所示。
图4:Multi-head Attention 示意图
具有2个head的 Multi-head Attention 结构如 图5 所示。输入 a i a^i ai 经过转移矩阵,并切分生成 q ( i , 1 ) q^{(i,1)} q(i,1)、 q ( i , 2 ) q^{(i,2)} q(i,2)、 k ( i , 1 ) k^{(i,1)} k(i,1)、 k ( i , 2 ) k^{(i,2)} k(i,2)、 v ( i , 1 ) v^{(i,1)} v(i,1)、 v ( i , 2 ) v^{(i,2)} v(i,2),然后 q ( i , 1 ) q^{(i,1)} q(i,1) 与 k ( i , 1 ) k^{(i,1)} k(i,1) 做 attention,得到权重向量 α \alpha α,将 α \alpha α 与 v ( i , 1 ) v^{(i,1)} v(i,1) 进行加权求和,得到最终的 b ( i , 1 ) ( i = 1 , 2 , … , N ) b^{(i,1)}(i=1,2,…,N) b(i,1)(i=1,2,…,N),同理可以得到 b ( i , 2 ) ( i = 1 , 2 , … , N ) b^{(i,2)}(i=1,2,…,N) b(i,2)(i=1,2,…,N)。接着将它们拼接起来,通过一个线性层进行处理,得到最终的结果。
图5:Multi-head Attention结构
其中,使用 q ( i , j ) q^{(i,j)} q(i,j)、 k ( i , j ) k^{(i,j)} k(i,j) 与 v ( i , j ) v^{(i,j)} v(i,j) 计算 b ( i , j ) ( i = 1 , 2 , … , N ) b^{(i,j)}(i=1,2,…,N) b(i,j)(i=1,2,…,N) 的方法是 Scaled Dot-Product Attention。 结构如 图6 所示。首先使用每个 q ( i , j ) q^{(i,j)} q(i,j) 去与 k ( i , j ) k^{(i,j)} k(i,j) 做 attention,这里说的 attention 就是匹配这两个向量有多接近,具体的方式就是计算向量的加权内积,得到 α ( i , j ) \alpha_{(i,j)} α(i,j)。这里的加权内积计算方式如下所示:
α ( 1 , i ) = q 1 ∗ k i / d \alpha_{(1,i)} = q^1 * k^i / \sqrt{d} α(1,i)=q1∗ki/d
其中, d d d 是 q q q 和 k k k 的维度,因为 q ∗ k q*k q∗k 的数值会随着维度的增大而增大,因此除以 d \sqrt{d} d 的值也就相当于归一化的效果。
接下来,把计算得到的 α ( i , j ) \alpha_{(i,j)} α(i,j) 取 softmax 操作,再将其与 v ( i , j ) v^{(i,j)} v(i,j) 相乘。
图6:Scaled Dot-Product Attention
具体代码实现如下所示。
# Multi-head Attention
class Attention(nn.Layer):
def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.,
proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
# 计算 q,k,v 的转移矩阵
self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
# 最终的线性层
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
N, C = x.shape[1:]
# 线性变换
qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C //
self.num_heads)).transpose((2, 0, 3, 1, 4))
# 分割 query key value
q, k, v = qkv[0], qkv[1], qkv[2]
# Scaled Dot-Product Attention
# Matmul + Scale
attn = (q.matmul(k.transpose((0, 1, 3, 2)))) * self.scale
# SoftMax
attn = nn.functional.softmax(attn, axis=-1)
attn = self.attn_drop(attn)
# Matmul
x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((-1, N, C))
# 线性变换
x = self.proj(x)
x = self.proj_drop(x)
return x
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
多层感知机(MLP)
Tranformer 结构中还有一个重要的结构就是 MLP,即多层感知机,如 图7 所示。
图7:多层感知机
多层感知机由输入层、输出层和至少一层的隐藏层构成。网络中各个隐藏层中神经元可接收相邻前序隐藏层中所有神经元传递而来的信息,经过加工处理后将信息输出给相邻后续隐藏层中所有神经元。在多层感知机中,相邻层所包含的神经元之间通常使用“全连接”方式进行连接。多层感知机可以模拟复杂非线性函数功能,所模拟函数的复杂性取决于网络隐藏层数目和各层中神经元数目。多层感知机的结构如 图8 所示。
图8:多层感知机结构
具体代码实现如下所示。
class Mlp(nn.Layer):
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
# 输入层:线性变换
x = self.fc1(x)
# 应用激活函数
x = self.act(x)
# Dropout
x = self.drop(x)
# 输出层:线性变换
x = self.fc2(x)
# Dropout
x = self.drop(x)
return x
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
基础模块
基于上面实现的 Attention、MLP 和 DropPath 模块就可以组合出 Vision Transformer 模型的一个基础模块,如 图9 所示。
图9:Transformer 基础模块
这里使用了DropPath(Stochastic Depth)来代替传统的Dropout结构,DropPath可以理解为一种特殊的 Dropout。其作用是在训练过程中随机丢弃子图层(randomly drop a subset of layers),而在预测时正常使用完整的 Graph。
具体实现如下:
def drop_path(x, drop_prob=0., training=False):
if drop_prob == 0. or not training:
return x
keep_prob = paddle.to_tensor(1 - drop_prob)
shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
random_tensor = paddle.floor(random_tensor)
output = x.divide(keep_prob) * random_tensor
return output
class DropPath(nn.Layer):
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
基础模块的具体实现如下:
class Block(nn.Layer):
def __init__(self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer='',
epsilon=1e-5):
super().__init__()
self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
# Multi-head Self-attention
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop)
# DropPath
self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)
def forward(self, x):
# Multi-head Self-attention, Add, LayerNorm
x = x + self.drop_path(self.attn(self.norm1(x)))
# Feed Forward, Add, LayerNorm
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
定义ViT网络
基础模块构建好后,就可以构建完整的ViT网络了。ViT的完整结构如 图10 所示。
图10:ViT网络结构
在实现完整网络结构之前,还需要给大家介绍几个模块:
- Class Token
可以看到,假设我们将原始图像切分成 3 × 3 3 \times 3 3×3 共9个小图像块,最终的输入序列长度却是10,也就是说我们这里人为的增加了一个向量进行输入,我们通常将人为增加的这个向量称为 Class Token。那么这个 Class Token 有什么作用呢?
我们可以想象,如果没有这个向量,也就是将 N = 9 N=9 N=9 个向量输入 Transformer 结构中进行编码,我们最终会得到9个编码向量,可对于图像分类任务而言,我们应该选择哪个输出向量进行后续分类呢?
由于选择9个中的哪个都不合适,所以ViT算法中,提出了一个可学习的嵌入向量 Class Token,将它与9个向量一起输入到 Transformer 结构中,输出10个编码向量,然后用这个 Class Token 进行分类预测即可。
其实这里也可以理解为:ViT 其实只用到了 Transformer 中的 Encoder,而并没有用到 Decoder,而 Class Token 的作用就是寻找其他9个输入向量对应的类别。
- Positional Encoding
按照 Transformer 结构中的位置编码习惯,这个工作也使用了位置编码。不同的是,ViT 中的位置编码没有采用原版 Transformer 中的 s i n c o s sincos sincos 编码,而是直接设置为可学习的 Positional Encoding。对训练好的 Positional Encoding 进行可视化,如 图11 所示。我们可以看到,位置越接近,往往具有更相似的位置编码。此外,出现了行列结构,同一行/列中的 patch 具有相似的位置编码。
图11:Positional Encoding
- MLP Head
得到输出后,ViT中使用了 MLP Head对输出进行分类处理,这里的 MLP Head 由 LayerNorm 和两层全连接层组成,并且采用了 GELU 激活函数。
具体代码如下所示。
首先构建基础模块部分,包括:参数初始化配置、独立的不进行任何操作的网络层。
# 参数初始化配置
trunc_normal_ = nn.initializer.TruncatedNormal(std=.02)
zeros_ = nn.initializer.Constant(value=0.)
ones_ = nn.initializer.Constant(value=1.)
# 将输入 x 由 int 类型转为 tuple 类型
def to_2tuple(x):
return tuple([x] * 2)
# 定义一个什么操作都不进行的网络层
class Identity(nn.Layer):
def __init__(self):
super(Identity, self).__init__()
def forward(self, input):
return input
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
完整代码如下所示。
class VisionTransformer(nn.Layer):
def __init__(self,
img_size=384,
patch_size=16,
in_chans=3,
class_dim=1000,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4,
qkv_bias=False,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer='',
epsilon=1e-5,
**args):
super().__init__()
self.class_dim = class_dim
self.num_features = self.embed_dim = embed_dim
# 图片分块和降维,块大小为patch_size,最终块向量维度为768
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim)
# 分块数量
num_patches = self.patch_embed.num_patches
# 可学习的位置编码
self.pos_embed = self.create_parameter(
shape=(1, num_patches + 1, embed_dim), default_initializer=zeros_)
self.add_parameter("pos_embed", self.pos_embed)
# 人为追加class token,并使用该向量进行分类预测
self.cls_token = self.create_parameter(
shape=(1, 1, embed_dim), default_initializer=zeros_)
self.add_parameter("cls_token", self.cls_token)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = np.linspace(0, drop_path_rate, depth)
# transformer
self.blocks = nn.LayerList([
Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
epsilon=epsilon) for i in range(depth)
])
self.norm = eval(norm_layer)(embed_dim, epsilon=epsilon)
# Classifier head
self.head = nn.Linear(embed_dim,
class_dim) if class_dim > 0 else Identity()
trunc_normal_(self.pos_embed)
trunc_normal_(self.cls_token)
self.apply(self._init_weights)
# 参数初始化
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
zeros_(m.bias)
ones_(m.weight)
# 获取图像特征
def forward_features(self, x):
B = paddle.shape(x)[0]
# 将图片分块,并调整每个块向量的维度
x = self.patch_embed(x)
# 将class token与前面的分块进行拼接
cls_tokens = self.cls_token.expand((B, -1, -1))
x = paddle.concat((cls_tokens, x), axis=1)
# 将编码向量中加入位置编码
x = x + self.pos_embed
x = self.pos_drop(x)
# 堆叠 transformer 结构
for blk in self.blocks:
x = blk(x)
# LayerNorm
x = self.norm(x)
# 提取分类 tokens 的输出
return x[:, 0]
def forward(self, x):
# 获取图像特征
x = self.forward_features(x)
# 图像分类
x = self.head(x)
return x
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
- 92
- 93
- 94
- 95
- 96
- 97
- 98
- 99
DeiT
DeiT 算法综述
论文地址:Training data-efficient image transformers & distillation through attention
在对 ViT 的介绍中,我们了解到,ViT 算法要想取得一个较好的指标,需要先使用 JFT-300 或者 ImageNet-21K 这样的超大规模数据集进行预训练,然后再迁移到其他中等或较小规模的数据集上。而当不使用像 JFT-300 这样的巨大的数据集时,效果是不如 CNN 模型的,也就反映出 Transformer 结构在 CV 领域的一个局限性。对于大多数的研究者而言,使用如此大规模的数据集意味着需要很昂贵的计算资源,一旦无法获取到这些计算资源,不能使用这么大规模的数据集进行预训练,就无法复现出算法应有的效果。所以,出于这个动机,研究者针对 ViT 算法进行了改进,提出了DeiT。
在 DeiT 中,作者在 ViT 的基础上改进了训练策略,并使用了蒸馏学习的方式,只需要在 ImageNet 上进行训练,就可以得到一个有竞争力的 Transformer 模型,而且在单台计算机上,训练时间不到3天即可。这里先简单看一下 DeiT 在 ImageNet 数据集上与之前 ViT 以及 CNN 算法中的 EfficientNet 的精度对比,如 图13 所示。
图13:DeiT 模型精度
上图中的指标均为在 ImageNet 数据集上进行训练,且在 ImageNet 数据集上评估的结果。其中,Ours(Deit) 为使用与 ViT 完全一致的网络结构,但是改进了训练策略;而 Ours⚗(DeiT⚗) 则是在 DeiT 的基础上继续使用了蒸馏学习的方式进行改进。可以看到,ViT 算法在这种中等规模的数据集上,指标远不如 CNN 网络 EfficientNet,而通过改变训练策略,使用蒸馏学习,网络结构与 ViT 基本一致的 DeiT 性能有了很大的提升,超过了 EfficientNet。
网络优化方法
- 蒸馏学习
蒸馏分为两种,分别是软蒸馏(soft distillation)和硬蒸馏(hard distillation)。软蒸馏就是将学生网络的输出结果与教师网络的 softmax 输出结果取 KL Loss;而硬蒸馏就是将学生网络的输出结果与教师网络的标签取交叉熵损失,公式分别如下所示。在 DeiT 中,分别对网络使用了两种蒸馏策略进行对比实验,最终选择了硬蒸馏方式。
L
g
l
o
b
a
l
S
o
f
t
D
i
s
t
i
l
l
=
(
1
−
λ
)
L
C
E
(
ψ
(
Z
s
)
,
y
)
+
λ
τ
2
K
L
(
ψ
(
Z
s
/
τ
)
,
ψ
(
Z
t
/
τ
)
)
L_{global}^{SoftDistill} = (1-\lambda)L_{CE}(\psi(Z_s),y) + \lambda \tau^2 KL(\psi(Z_s/\tau),\psi(Z_t/\tau))
LglobalSoftDistill=(1−λ)LCE(ψ(Zs),y)+λτ2KL(ψ(Zs/τ),ψ(Zt/τ))
L
g
l
o
b
a
l
H
a
r
d
D
i
s
t
i
l
l
=
1
2
L
C
E
(
ψ
(
Z
s
)
,
y
)
+
1
2
L
C
E
(
ψ
(
Z
s
)
,
y
t
)
L_{global}^{HardDistill} = \frac{1}{2}L_{CE}(\psi(Z_s),y) + \frac{1}{2}L_{CE}(\psi(Z_s),y_t)
LglobalHardDistill=21LCE(ψ(Zs),y)+21LCE(ψ(Zs),yt)
上式中, y t = a r g m a x c Z t ( c ) y_t=argmax_cZ_t(c) yt=argmaxcZt(c)。这里的硬标签也可以通过标签平滑技术 (Label smoothing) 转换成软标签,其中真值对应的标签被认为具有 1 − e p s i l o n 1-epsilon 1−epsilon 的概率,剩余的 e p s i l o n epsilon epsilon 则由剩余的类别共享。 e p s i l o n epsilon epsilon 是一个超参数,这里取0.1。
- 网络结构
接下来,看一下 DeiT 算法相较于 ViT 算法,在网络结构上的优化方式,如 图14 所示。
图14:DeiT 网络结构
通过对比 图14 与 图1 中ViT的网络结构可以发现,DeiT 与 ViT 的主要差异在于引入了一个 distillation token,顾名思义,其主要用于网络训练中的蒸馏学习。这个 distillation token 与 class token 很像,其在 self-attention layers 中会跟 class token 以及图像 patch 不断交互。而 distillation token 与 class token 唯一的区别在于,class token 的目标是跟真实的 label 一致,而 distillation token 是要跟蒸馏学习中教师网络预测的 label 一致。这个 distillation token 允许我们的模型从教师网络的输出中学习,就像在常规的蒸馏中一样,同时也作为一种对 class token 的补充。
由于具有了 distillation token,因此在损失函数方面,DeiT 相较于 ViT 也有了一定的改变:
- 在 ViT 算法中,模型的最终输出是使用 class token 进行 softmax 计算,获取预测结果属于各个类别的概率分布。此时,损失函数的计算方式与传统的多分类任务一致,直接使用交叉熵损失即可。
- 在 DeiT 算法中,由于使用了蒸馏学习方法,因此在交叉熵损失的基础上,还需要加上一个蒸馏损失。
论文中进行了实验,发现了一个有趣的现象,class token 和 distillation token 是朝着不同的方向收敛的,在开始时,这两个 token 计算余弦相似度只有0.006。随着网络层数的增加,直到最后一层,这两个 token 的余弦相似度变为了0.93。也就可以认为是相似但是不相同的两个 token。
同时,网络为了验证 distillation token 的确给模型添加了某些有益信息,论文中也进行了实验。作者将 distillation token 替换为一个简单的 class token,然后发现,即便这两个 class token 分别进行独立的随机初始化,它们最终也是会随机地收敛到一个几乎一摸一样的结果(余弦相似度为0.999),同时最终性能没有明显提升。
此时,还有最后一个问题,网络既会输出 class token 的结果,也会输出 distillation token 的结果,在最终预测时,我们取谁来作为最终结果呢?答案是将两者的 softmax 结果进行相加,即可简单地得到算法的最终预测结果。
论文实验
DeiT 论文中,进行了一系列实验来选取最优的训练策略。实验中,不同大小的 DeiT 结构超参数设置如 图15 所示。其中,最大的结构是 DeiT-B,与 ViT-B的结构相同,但是 embedding dimension 调整为 768,head 数量为12,每个 head 对应的 embedding dimensions 则为64。DeiT-S 和 DeiT-Ti 则是两个较小的模型,它们调整了 head 的数量,但是保持每个 head 对应的 embedding dimensions 大小不变。
图15:DeiT 实验参数
- 实验1: CNN 与 Transformer 结构谁更适合做 teacher model?
实验1的结果如 图16 所示。实验中,对比了教师网络使用 DeiT-B 和不同的 RegNetY 时,等到的学生网络预训练性能以及 finetune 性能。其中,右侧第一列为预训练的网络指标,第二列则为使用 384 × 384 384 \times 384 384×384大小的分辨率进行 finetune 后的模型指标。可以看到,使用 CNN 作为教师网络能够取得比使用 Transformer 结构作为教师网络更好的结果。
图16:教师模型选择
- 实验2: 哪种蒸馏策略效果更好?
实验2的结果如 图17 所示。
图17:不同蒸馏策略的结果对比
表中前3行对应的分别是:不使用 distillation token 进行训练时,不使用蒸馏学习、使用软蒸馏以及使用硬蒸馏的性能对比。此时的网络训练方法如 图18 所示,即只是相当于在原来 ViT 的基础上给损失函数增加上蒸馏的部分,其他结构保持不变。可以看到,使用硬蒸馏的方式,网络的性能会明显优于不使用蒸馏或使用软蒸馏方式,此时即便没有使用 distillation token,硬蒸馏也可以达到 83.0% 的 top-1 准确率。
图18:不使用 distillation token 进行训练的3种形式
表中后三行则分别对应了:仅使用 class token、仅使用 distillation token 以及既使用class token 又使用 distillation token的性能对比。此时网络的训练方法如 图19 所示。
此时从结果中可以发现:
- class token 和 distillation token 都可以提供对分类有用的信息。
- 仅使用 distillation token 效果略强于仅使用 class token。
- 既使用class token 又使用 distillation token 效果可以达到最优。
图19:使用不同的 token 进行训练的3种形式
同时,有趣的是,学生模型的最终效果甚至可以超过教师网络。
- 实验3: Transformer 能否学到 CNN 的归纳假设?
论文中没有给出最终结论,但是通过 图20 的表格,我们也可以进行简单分析。表格中的数字反映了不同设置之间的决策不一致性。通过第2行的后三列可以看出,使用 distillation token 分类的 DeiT 与 CNN 的不一致性比用 class token 的更小,而两个都用的 DeiT 居中。而通过第2行的第3列与后三列进行对比,可以发现,蒸馏后的 DeiT 与 CNN 的不一致性比蒸馏前的 DeiT 更小。
图20:蒸馏前的DeiT,CNN teacher和蒸馏后的DeiT之间决策的不一致性
- 实验4: 性能对比
论文中,作者对比了经典的 Transformer 结构以及 CNN 结构的性能,如 图21 所示。可以发现,在参数量相当的情况下,CNN 结构比 Transformer 结构更慢,这其实是因为 Transformer 中大的矩阵乘法比 CNN 结构中的小卷积提供了更多的优化机会。而 EffcientNet-B4 和 DeiT-B⚗ 的速度相似,在3个数据集的性能也比较接近。
图21:不同模型性能的数值比较
- 实验5: 迁移学习的性能对比
作者还对比了不同模型(包括一些CNN模型)使用 ImageNet 数据集进行预训练,然后再迁移到不同任务上的性能,如 图22 所示。
图22:不同模型迁移学习能力对比
- 实验6: 训练策略的对比实验
论文中使用了多种训练策略进行网络训练,性能对比如 图23 所示。众所周知,Transformer 网络的训练需要大量的数据,如果想要在不那么大的数据集上取得比较好的性能,就需要大量的数据增强处理。因此,论文中作者进行了一系列数据增强实验。可以发现,几乎所有评测过的数据增强方法都可以提升算法性能。同时,作者还比较了使用不同优化器以及正则化策略时的算法性能,总的来说,使用 AdamW 比使用 SGD 性能更好,使用 Stochastic Depth 有利于收敛,而 Mixup 和 CutMix 也都可以提高性能。
图23:训练策略的对比实验
网络实现的最终代码如下所示。
# DeiT 结构,继承了 ViT 结构
class DistilledVisionTransformer(VisionTransformer):
def __init__(self,
img_size=384,
patch_size=16,
class_dim=1000,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4,
qkv_bias=False,
norm_layer='',
epsilon=1e-5,
**kwargs):
# ViT 结构
super().__init__(
img_size=img_size,
patch_size=patch_size,
class_dim=class_dim,
embed_dim=embed_dim,
depth=depth,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
epsilon=epsilon,
**kwargs)
# 由于增加了 distillation token,所以也需要调整位置编码的长度
self.pos_embed = self.create_parameter(
shape=(1, self.patch_embed.num_patches + 2, self.embed_dim),
default_initializer=zeros_)
self.add_parameter("pos_embed", self.pos_embed)
# distillation token
self.dist_token = self.create_parameter(
shape=(1, 1, self.embed_dim), default_initializer=zeros_)
self.add_parameter("cls_token", self.cls_token)
# Classifier head
self.head_dist = nn.Linear(
self.embed_dim,
self.class_dim) if self.class_dim > 0 else Identity()
trunc_normal_(self.dist_token)
trunc_normal_(self.pos_embed)
self.head_dist.apply(self._init_weights)
# 获取图像特征
def forward_features(self, x):
B = paddle.shape(x)[0]
# 将图片分块,并调整每个块向量的维度
x = self.patch_embed(x)
# 将class token、distillation token与前面的分块进行拼接
cls_tokens = self.cls_token.expand((B, -1, -1))
dist_token = self.dist_token.expand((B, -1, -1))
x = paddle.concat((cls_tokens, dist_token, x), axis=1)
# 将编码向量中加入位置编码
x = x + self.pos_embed
x = self.pos_drop(x)
# 堆叠 transformer 结构
for blk in self.blocks:
x = blk(x)
# LayerNorm
x = self.norm(x)
# 提取class token以及distillation token的输出
return x[:, 0], x[:, 1]
def forward(self, x):
# 获取图像特征
x, x_dist = self.forward_features(x)
# 图像分类
x = self.head(x)
x_dist = self.head_dist(x_dist)
# 取 class token以及distillation token 的平均值作为结果
return (x + x_dist) / 2
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72