关于MSDeformableAttention的理解与代码解读

时间:2024-07-08 07:04:12

Deformable Attention
MMDet-Deformable DETR源码解读
代码解析


MSDeformableAttention 是一个自定义的类,代表“多尺度可变形注意力模块”(Multi-Scale Deformable Attention Module)

  • 多尺度(Multi-Scale): 在深度学习中,多尺度通常指在不同空间分辨率或感受野上处理特征。这有助于模型捕获不同大小的上下文信息。
  • 可变形(Deformable): 在这里,“可变形”可能指的是注意力模块能够自适应地调整其关注的区域或采样点。这与传统的固定形状的卷积核或注意力机制形成对比,后者通常在固定的空间位置进行操作。
  • 注意力模块(Attention Module): 注意力模块是神经网络中的一个组件,它允许模型在处理信息时对不同部分赋予不同的权重。这有助于模型聚焦于输入中最相关或最有信息的部分。
class MSDeformableAttention(nn.Layer):
    def __init__(self,
                 embed_dim=256,
                 num_heads=8,
                 num_levels=4,
                 num_points=4,
                 lr_mult=0.1):
        """
        Multi-Scale Deformable Attention Module
        """
        super(MSDeformableAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_levels = num_levels
        self.num_points = num_points
        self.total_points = num_heads * num_levels * num_points

        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        self.sampling_offsets = nn.Linear(
            embed_dim,
            self.total_points * 2,
            weight_attr=ParamAttr(learning_rate=lr_mult),
            bias_attr=ParamAttr(learning_rate=lr_mult))

        self.attention_weights = nn.Linear(embed_dim, self.total_points)
        self.value_proj = nn.Linear(embed_dim, embed_dim)
        self.output_proj = nn.Linear(embed_dim, embed_dim)