PyTorch 实现图像版多头注意力(Multi-Head Attention)和自注意力(Self-Attention)-PyTorch 实现图像输入的自注意力机制(Self-Attention)

时间:2025-04-07 16:50:39

本节介绍一种适用于图像输入 (B, C, H, W) 的自注意力机制实现,适合卷积神经网络与 Transformer 的融合模块,如 Self-Attention ConvNet、BAM、CBAM、ViT 前层等。

自注意力机制(图像维度)代码

import torch
import torch.nn as nn
import torch.nn.functional as F

class ImageSelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(ImageSelfAttention, self).__init__()
        self.in_channels = in_channels
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key_conv   = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))  # 可学习缩放因子

    def forward(self, x):
        # 输入 x: (B, C, H, W)
        B, C, H, W = x.size()

        # 生成 Q, K, V
        proj_query = self.query_conv(x).view(B, -1, H * W).permute(0, 2, 1)  # (B, N, C//8)
        proj_key   = self.key_conv(x).view(B, -1, H * W)                      # (B, C//8, N)
        proj_value = self.value_conv(x).view(B, -1, H * W)                    # (B, C, N)

        # 注意力矩阵:Q * K^T
        energy = torch.bmm(proj_query, proj_key)         # (B, N, N)
        attention = F.softmax(energy, dim=-1)             # (B, N, N)

        # 加权求和 V
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))  # (B, C, N)
        out = out.view(B, C, H, W)

        # 残差连接 + 缩放因子
        out = self.gamma * out + x
        return out
        
#测试用例
x = torch.randn(2, 64, 32, 32)  # 输入一张图像:B=2, C=64, H=W=32
self_attn = ImageSelfAttention(in_channels=64)
out = self_attn(x)

print(out.shape)  # 输出形状应为 (2, 64, 32, 32)

• 本模块基于图像 (B, C, H, W) 进行自注意力计算
• 使用卷积进行 Q/K/V 提取,保持局部感知力
• gamma 是可学习缩放因子,用于残差连接控制注意力贡献度


自注意力中**缩放因子(scale factor)的处理,在序列维度(如 ViT)和图片维度(如 Self-Attention Conv)**中有点不一样。下面我们来详细解释一下原因,并对两种写法做一个统一和对比分析

两种缩放因子的区别
  1. 序列维度的缩放因子
scale = head_dim ** 0.5  # 或者 embed_dim ** 0.5
attn = (Q @ K.T) / scale

• 来源:Transformer 原始论文(Attention is All You Need)
• 原因:在高维向量内积中,为了避免 dot product 的结果数值过大导致梯度不稳定,需要除以 sqrt(d_k)
• 使用场景:多头注意力机制,输入是 (B, N, C),应用在 NLP、ViT 等序列结构

  1. 图片维度(C, H, W)的注意力机制中没有缩放,或者使用 softmax 平衡
attn = softmax(Q @ K.T)   # 无 scale,或者手动调节

• 来源:Non-local Net、Self-Attention Conv、BAM 等 CNN + Attention 融合方法
• 原因:Q 和 K 都通过 1x1 conv 压缩成 C//8 或更小的维度,内积的值本身不会太大;同时图像 attention 主要用 softmax 控制权重范围
• 缩放因子的控制通常用 γ(gamma)作为残差通道缩放,不是 QK 内部的数值缩放


???? 如果你觉得这篇整理有帮助,欢迎点赞收藏!