目录
导言
导入
超参数
加载并准备 CIFAR-10 数据集
数据扩增
位置嵌入模块
变压器的 MLP 模块
令牌学习器模块
变换器组
带有 TokenLearner 模块的 ViT 模型
培训实用程序
使用 TokenLearner 培训和评估 ViT
实验结果
参数数量
最终说明
政安晨的个人主页:政安晨
欢迎 ????点赞✍评论⭐收藏
收录专栏: TensorFlow与Keras机器学习实战
希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!
本文目标:为 "视觉变换器 "自适应生成较少数量的令牌。
导言
视觉变换器(Dosovitskiy 等人)和许多其他基于变换器的架构(Liu 等人、Yuan 等人)在图像识别方面取得了显著成果。
下面将简要介绍用于图像分类的视觉变换器架构所涉及的组件:
—— 从输入图像中提取小块图像。
—— 线性投影这些斑块。
—— 为这些线性投影添加位置嵌入。
—— 通过一系列 Transformer(Vaswani 等人)模块运行这些投影。
—— 最后,从最后的 Transformer 模块中提取表示并添加分类头。
如果我们获取 224x224 的图像并提取 16x16 的补丁,那么每张图像总共会得到 196 个补丁(也称为标记)。
随着分辨率的提高,补丁的数量也会增加,从而导致内存占用增加。
我们能否在不影响性能的情况下减少补丁的数量呢?Ryoo 等人在 TokenLearner 中研究了这个问题:视频的自适应时空标记化》中研究了这个问题。他们引入了一个名为 TokenLearner 的新模块,该模块能以自适应的方式帮助减少视觉转换器(ViT)使用的补丁数量。将 TokenLearner 纳入标准 ViT 架构后,他们能够减少模型使用的计算量(以 FLOPS 衡量)。
在本示例中,我们实现了 TokenLearner 模块,并用迷你 ViT 和 CIFAR-10 数据集演示了其性能。
导入
import keras
from keras import layers
from keras import ops
from tensorflow import data as tf_data
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import math
超参数
请随时更改超参数并检查结果。对架构产生直觉的最好方法就是进行实验。
# DATA
BATCH_SIZE = 256
AUTO = tf_data.AUTOTUNE
INPUT_SHAPE = (32, 32, 3)
NUM_CLASSES = 10
# OPTIMIZER
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4
# TRAINING
EPOCHS = 1
# AUGMENTATION
IMAGE_SIZE = 48 # We will resize input images to this size.
PATCH_SIZE = 6 # Size of the patches to be extracted from the input images.
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2
# ViT ARCHITECTURE
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 128
NUM_HEADS = 4
NUM_LAYERS = 4
MLP_UNITS = [
PROJECTION_DIM * 2,
PROJECTION_DIM,
]
# TOKENLEARNER
NUM_TOKENS = 4
加载并准备 CIFAR-10 数据集
# Load the CIFAR-10 dataset.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
(x_train, y_train), (x_val, y_val) = (
(x_train[:40000], y_train[:40000]),
(x_train[40000:], y_train[40000:]),
)
print(f"Training samples: {len(x_train)}")
print(f"Validation samples: {len(x_val)}")
print(f"Testing samples: {len(x_test)}")
# Convert to tf.data.Dataset objects.
train_ds = tf_data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.shuffle(BATCH_SIZE * 100).batch(BATCH_SIZE).prefetch(AUTO)
val_ds = tf_data.Dataset.from_tensor_slices((x_val, y_val))
val_ds = val_ds.batch(BATCH_SIZE).prefetch(AUTO)
test_ds = tf_data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = test_ds.batch(BATCH_SIZE).prefetch(AUTO)
Training samples: 40000
Validation samples: 10000
Testing samples: 10000
数据扩增
扩增管道包括:
重新缩放
调整大小
随机裁剪(固定大小或随机大小)
随机水平翻转
请注意,图像数据增强层在推理时不应用数据转换。这意味着在调用这些层时,如果训练=假,它们的行为会有所不同。
位置嵌入模块
Transformer 架构由多头自我关注层和全连接前馈网络(MLP)作为主要组成部分。这两个组件都具有排列不变性:它们不考虑特征顺序。
为了克服这一问题,我们为标记注入了位置信息。position_embedding 函数将位置信息添加到线性投影的标记中。
class PatchEncoder(layers.Layer):
def __init__(self, num_patches, projection_dim):
super().__init__()
self.num_patches = num_patches
self.position_embedding = layers.Embedding(
input_dim=num_patches, output_dim=projection_dim
)
def call(self, patch):
positions = ops.expand_dims(
ops.arange(start=0, stop=self.num_patches, step=1), axis=0
)
encoded = patch + self.position_embedding(positions)
return encoded
def get_config(self):
config = super().get_config()
config.update({"num_patches": self.num_patches})
return config
变换器的 MLP 模块
这是变换器的全连接前馈模块。
def mlp(x, dropout_rate, hidden_units):
# Iterate over the hidden units and
# add Dense => Dropout.
for units in hidden_units:
x = layers.Dense(units, activation=ops.gelu)(x)
x = layers.Dropout(dropout_rate)(x)
return x
令牌学习器模块
下图是该模块的图示概览(来源)。
TokenLearner 模块将图像形状的张量作为输入。然后,它将其通过多个单通道卷积层,提取不同的空间注意力图,并将注意力集中在输入的不同部分。然后,将这些注意力图按元素顺序与输入相乘,并对结果进行汇集。汇集后的输出可以看作是输入的汇总,其补丁数量(例如 8 个)远远少于原始输出(例如 196 个)。
使用多个卷积层有助于提高表现力。施加一种空间注意力有助于保留输入的相关信息。这两个部分对 TokenLearner 的运行都至关重要,尤其是当我们要大幅减少贴片数量时。
def token_learner(inputs, number_of_tokens=NUM_TOKENS):
# Layer normalize the inputs.
x = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(inputs) # (B, H, W, C)
# Applying Conv2D => Reshape => Permute
# The reshape and permute is done to help with the next steps of
# multiplication and Global Average Pooling.
attention_maps = keras.Sequential(
[
# 3 layers of conv with gelu activation as suggested
# in the paper.
layers.Conv2D(
filters=number_of_tokens,
kernel_size=(3, 3),
activation=ops.gelu,
padding="same",
use_bias=False,
),
layers.Conv2D(
filters=number_of_tokens,
kernel_size=(3, 3),
activation=ops.gelu,
padding="same",
use_bias=False,
),
layers.Conv2D(
filters=number_of_tokens,
kernel_size=(3, 3),
activation=ops.gelu,
padding="same",
use_bias=False,
),
# This conv layer will generate the attention maps
layers.Conv2D(
filters=number_of_tokens,
kernel_size=(3, 3),
activation="sigmoid", # Note sigmoid for [0, 1] output
padding="same",
use_bias=False,
),
# Reshape and Permute
layers.Reshape((-1, number_of_tokens)), # (B, H*W, num_of_tokens)
layers.Permute((2, 1)),
]
)(
x
) # (B, num_of_tokens, H*W)
# Reshape the input to align it with the output of the conv block.
num_filters = inputs.shape[-1]
inputs = layers.Reshape((1, -1, num_filters))(inputs) # inputs == (B, 1, H*W, C)
# Element-Wise multiplication of the attention maps and the inputs
attended_inputs = (
ops.expand_dims(attention_maps, axis=-1) * inputs
) # (B, num_tokens, H*W, C)
# Global average pooling the element wise multiplication result.
outputs = ops.mean(attended_inputs, axis=2) # (B, num_tokens, C)
return outputs
变换器组
def transformer(encoded_patches):
# Layer normalization 1.
x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)
# Multi Head Self Attention layer 1.
attention_output = layers.MultiHeadAttention(
num_heads=NUM_HEADS, key_dim=PROJECTION_DIM, dropout=0.1
)(x1, x1)
# Skip connection 1.
x2 = layers.Add()([attention_output, encoded_patches])
# Layer normalization 2.
x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)
# MLP layer 1.
x4 = mlp(x3, hidden_units=MLP_UNITS, dropout_rate=0.1)
# Skip connection 2.
encoded_patches = layers.Add()([x4, x2])
return encoded_patches
带有 TokenLearner 模块的 ViT 模型
def create_vit_classifier(use_token_learner=True, token_learner_units=NUM_TOKENS):
inputs = layers.Input(shape=INPUT_SHAPE) # (B, H, W, C)
# Augment data.
augmented = data_augmentation(inputs)
# Create patches and project the pathces.
projected_patches = layers.Conv2D(
filters=PROJECTION_DIM,
kernel_size=(PATCH_SIZE, PATCH_SIZE),
strides=(PATCH_SIZE, PATCH_SIZE),
padding="VALID",
)(augmented)
_, h, w, c = projected_patches.shape
projected_patches = layers.Reshape((h * w, c))(
projected_patches
) # (B, number_patches, projection_dim)
# Add positional embeddings to the projected patches.
encoded_patches = PatchEncoder(
num_patches=NUM_PATCHES, projection_dim=PROJECTION_DIM
)(
projected_patches
) # (B, number_patches, projection_dim)
encoded_patches = layers.Dropout(0.1)(encoded_patches)
# Iterate over the number of layers and stack up blocks of
# Transformer.
for i in range(NUM_LAYERS):
# Add a Transformer block.
encoded_patches = transformer(encoded_patches)
# Add TokenLearner layer in the middle of the
# architecture. The paper suggests that anywhere
# between 1/2 or 3/4 will work well.
if use_token_learner and i == NUM_LAYERS // 2:
_, hh, c = encoded_patches.shape
h = int(math.sqrt(hh))
encoded_patches = layers.Reshape((h, h, c))(
encoded_patches
) # (B, h, h, projection_dim)
encoded_patches = token_learner(
encoded_patches, token_learner_units
) # (B, num_tokens, c)
# Layer normalization and Global average pooling.
representation = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)
representation = layers.GlobalAvgPool1D()(representation)
# Classify outputs.
outputs = layers.Dense(NUM_CLASSES, activation="softmax")(representation)
# Create the Keras model.
model = keras.Model(inputs=inputs, outputs=outputs)
return model
如令牌学习器论文所示,将令牌学习器模块置于网络中间几乎总是有利的。
培训实用程序
def run_experiment(model):
# Initialize the AdamW optimizer.
optimizer = keras.optimizers.AdamW(
learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)
# Compile the model with the optimizer, loss function
# and the metrics.
model.compile(
optimizer=optimizer,
loss="sparse_categorical_crossentropy",
metrics=[
keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
],
)
# Define callbacks
checkpoint_filepath = "/tmp/checkpoint.weights.h5"
checkpoint_callback = keras.callbacks.ModelCheckpoint(
checkpoint_filepath,
monitor="val_accuracy",
save_best_only=True,
save_weights_only=True,
)
# Train the model.
_ = model.fit(
train_ds,
epochs=EPOCHS,
validation_data=val_ds,
callbacks=[checkpoint_callback],
)
model.load_weights(checkpoint_filepath)
_, accuracy, top_5_accuracy = model.evaluate(test_ds)
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
使用 TokenLearner 培训和评估 ViT
vit_token_learner = create_vit_classifier()
run_experiment(vit_token_learner)
157/157 ━━━━━━━━━━━━━━━━━━━━ 303s 2s/step - accuracy: 0.1158 - loss: 2.4798 - top-5-accuracy: 0.5352 - val_accuracy: 0.2206 - val_loss: 2.0292 - val_top-5-accuracy: 0.7688
40/40 ━━━━━━━━━━━━━━━━━━━━ 5s 133ms/step - accuracy: 0.2298 - loss: 2.0179 - top-5-accuracy: 0.7723
Test accuracy: 22.9%
Test top 5 accuracy: 77.22%
实验结果
我们进行了实验,在我们实现的迷你 ViT 中使用和不使用 TokenLearner(超参数与本例中介绍的相同)。以下是我们的结果:
TokenLearner | # tokens in TokenLearner |
Top-1 Acc (Averaged across 5 runs) |
GFLOPs | TensorBoard |
---|---|---|---|---|
N | - | 56.112% | 0.0184 | Link |
Y | 8 | 56.55% | 0.0153 | Link |
N | - | 56.37% | 0.0184 | Link |
Y | 4 | 56.4980% | 0.0147 | Link |
N | - (# Transformer layers: 8) | 55.36% | 0.0359 | Link |
在没有模块的情况下,TokenLearner 的性能始终优于我们的迷你 ViT。同样有趣的是,它也能超越我们的迷你 ViT 的更深版本(有 8 层)。咱们以前也看到类似的观察结果,并将其归功于 TokenLearner 的适应性。
我们还应该注意到,随着令牌学习器模块的加入,FLOPs 数量大大减少。随着 FLOPs 数的减少,TokenLearner 模块能够提供更好的结果。这与作者的研究结果非常吻合。
此外,咱们还为较小的训练数据机制推出了更新版本的 TokenLearner。
该版本不再使用 4 个具有小通道的卷积层来实现空间注意力,而是使用 2 个具有更多通道的分组卷积层。它还使用了 softmax 而不是 sigmoid。我们证实,在训练数据有限的情况下,比如使用 ImageNet1K 从头开始训练时,该版本的效果更好。
我们对该模块进行了实验,并在下表中对实验结果进行了总结:
# Groups | # Tokens | Top-1 Acc | GFLOPs | TensorBoard |
---|---|---|---|---|
4 | 4 | 54.638% | 0.0149 | Link |
8 | 8 | 54.898% | 0.0146 | Link |
4 | 8 | 55.196% | 0.0149 | Link |
请注意,我们在本例中使用了相同的超参数。我们的实现可以在本笔记本中找到。我们承认,使用这个新的 TokenLearner 模块得出的结果比预期的略有偏差,这可能会随着超参数的调整而有所缓解。
(注:为了计算模型的 FLOPs,我们使用了这个软件库中的实用程序。)
参数数量
你可能已经注意到,添加 TokenLearner 模块会增加基础网络的参数数量。但这并不意味着效率会降低,正如 Dehghani 等人的研究所示。贝洛等人也报告了类似的发现。令牌学习器模块有助于减少整个网络的 FLOPS,从而有助于减少内存占用。
最终说明
TokenFuser论文作者还提出了另一个名为 TokenFuser 的模块。
该模块有助于将 TokenLearner 输出的表示重映射回其原始空间分辨率。
要在 ViT 架构中重复使用 TokenLearner,TokenFuser 是必不可少的。
我们首先从令牌学习器中学习令牌,从变换器层建立令牌的表示,然后将表示重映射到原始空间分辨率,这样令牌学习器就能再次使用它。
请注意,在整个 ViT 模型中,如果不与 TokenFuser 配对,只能使用一次 TokenLearner 模块。
在视频中使用这些模块:咱们还建议 TokenFuser 与 Vision Transformers for Videos(阿纳布等人)搭配使用。