# iRMB
import math
from functools import partial
from einops import rearrange
from import *
from import DropPath
from .efficientnet_builder import _parse_ksize
from .efficientnet_blocks import num_groups, SqueezeExcite as SE
# ========== 1.LayerNorm2d类:实现对输入张量进行二维的 Layer Normalization 操作 ==========
class LayerNorm2d():
def __init__(self, normalized_shape, eps=1e-6, elementwise_affine=True):
super().__init__()
= (normalized_shape, eps, elementwise_affine)
def forward(self, x):
x = rearrange(x, 'b c h w -> b h w c').contiguous()
x = (x)
x = rearrange(x, 'b h w c -> b c h w').contiguous()
return x
def get_norm(norm_layer='in_1d'):
eps = 1e-6
norm_dict = {
'none': ,
'in_1d': partial(nn.InstanceNorm1d, eps=eps),
'in_2d': partial(nn.InstanceNorm2d, eps=eps),
'in_3d': partial(nn.InstanceNorm3d, eps=eps),
'bn_1d': partial(nn.BatchNorm1d, eps=eps),
'bn_2d': partial(nn.BatchNorm2d, eps=eps),
# 'bn_2d': partial(, eps=eps),
'bn_3d': partial(nn.BatchNorm3d, eps=eps),
'gn': partial(, eps=eps),
'ln_1d': partial(, eps=eps),
'ln_2d': partial(LayerNorm2d, eps=eps),
}
return norm_dict[norm_layer]
def get_act(act_layer='relu'):
act_dict = {
'none': ,
'sigmoid': Sigmoid,
'swish': Swish,
'mish': Mish,
'hsigmoid': HardSigmoid,
'hswish': HardSwish,
'hmish': HardMish,
'tanh': Tanh,
'relu': ,
'relu6': nn.ReLU6,
'prelu': PReLU,
'gelu': GELU,
'silu':
}
return act_dict[act_layer]
# ========== 类:实现卷积、规范化和激活操作的集合 ==========
class ConvNormAct():
def __init__(self, dim_in, dim_out, kernel_size, stride=1, dilation=1, groups=1, bias=False,
skip=False, norm_layer='bn_2d', act_layer='relu', inplace=True, drop_path_rate=0.):
super(ConvNormAct, self).__init__()
self.has_skip = skip and dim_in == dim_out
padding = ((kernel_size - stride) / 2)
= nn.Conv2d(dim_in, dim_out, kernel_size, stride, padding, dilation, groups, bias)
= get_norm(norm_layer)(dim_out)
= get_act(act_layer)(inplace=inplace)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate else ()
def forward(self, x):
shortcut = x
x = (x)
x = (x)
x = (x)
if self.has_skip:
x = self.drop_path(x) + shortcut
return x
# ========== 类:反向残差注意力机制 ==========
class iRMB():
def __init__(self, dim_in, dim_out, norm_in=True, has_skip=True, exp_ratio=1.0, norm_layer='bn_2d',
act_layer='relu', v_proj=True, dw_ks=3, stride=1, dilation=1, se_ratio=0.0, dim_head=64, window_size=7,
attn_s=True, qkv_bias=False, attn_drop=0., drop=0., drop_path=0., v_group=False, attn_pre=False,inplace=True):
'''
dim_in: 输入特征的维度。
dim_out: 输出特征的维度。
norm_in: 是否对输入进行标准化。
has_skip: 是否使用跳跃连接。
exp_ratio: 扩展比例。
norm_layer: 标准化层的类型。
act_layer: 激活函数的类型。
v_proj: 是否对V进行投影。
dw_ks: 深度可分离卷积的卷积核大小。
stride: 卷积的步幅。
dilation: 卷积的膨胀率。
se_ratio: SE 模块的比例。
dim_head: 注意力头的维度。
window_size: 窗口大小。
attn_s: 是否使用注意力机制。
qkv_bias: 是否在注意力机制中使用偏置。
attn_drop: 注意力机制中的dropout比例。
drop: 全连接层的dropout比例。
drop_path: DropPath 的比例。
v_group: 是否对 V 进行分组卷积。
attn_pre: 是否将注意力机制应用到输入之前。
inplace: 是否原地执行操作。
'''
super().__init__() # 调用父类的构造函数
= get_norm(norm_layer)(dim_in) if norm_in else () # 条件判断,返回一个标准化层(例如 BatchNorm、LayerNorm 等)或使用空操作
dim_mid = int(dim_in * exp_ratio) # 计算中间维度大小
self.has_skip = (dim_in == dim_out and stride == 1) and has_skip # 条件判断,是否使用跳跃连接
self.attn_s = attn_s # 是否使用空间注意力机制的标志
# 如果使用注意力机制
if self.attn_s:
assert dim_in % dim_head == 0, 'dim should be divisible by num_heads' # 确保输入维度 dim_in 可以被 dim_head 整除
self.dim_head = dim_head # 设置每个头的维度为 dim_head
self.window_size = window_size # 设置窗口大小
self.num_head = dim_in // dim_head # 计算头数 self.num_head
= self.dim_head ** -0.5 # 计算缩放因子 ,用于调节注意力分数
self.attn_pre = attn_pre # 设定是否在注意力机制之前重新排列数据 self.attn_pre
# 创建 QK 卷积层、V 卷积层、注意力机制的 dropout 等
= ConvNormAct(dim_in, int(dim_in * 2), kernel_size=1, bias=qkv_bias, norm_layer='none',
act_layer='none')
= ConvNormAct(dim_in, dim_mid, kernel_size=1, groups=self.num_head if v_group else 1, bias=qkv_bias,
norm_layer='none', act_layer=act_layer, inplace=inplace)
self.attn_drop = (attn_drop)
# 如果不使用注意力机制
else:
# 如果需要进行 V 投影,则创建 V 卷积层;否则使用 () 空操作
if v_proj: # 如果使用V投影
= ConvNormAct(dim_in, dim_mid, kernel_size=1, bias=qkv_bias, norm_layer='none',
act_layer=act_layer, inplace=inplace) # 创建V卷积层
else:
= () # 使用空操作
self.conv_local = ConvNormAct(dim_mid, dim_mid, kernel_size=dw_ks, stride=stride, dilation=dilation,
groups=dim_mid, norm_layer='bn_2d', act_layer='silu', inplace=inplace) # 创建局部卷积层
= SE(dim_mid, rd_ratio=se_ratio, act_layer=get_act(act_layer)) if se_ratio > 0.0 else () # 创建空间激励模块或使用空操作
self.proj_drop = (drop)
= ConvNormAct(dim_mid, dim_out, kernel_size=1, norm_layer='none', act_layer='none', inplace=inplace)
self.drop_path = DropPath(drop_path) if drop_path else ()
def forward(self, x):
shortcut = x # 保存输入的快捷连接
x = (x) # 应用标准化层
# 提取输入 x 的形状信息
B, C, H, W =
if self.attn_s: # 如果使用了注意力机制
# padding
if self.window_size <= 0:
window_size_W, window_size_H = W, H
else:
window_size_W, window_size_H = self.window_size, self.window_size
# 计算填充的大小
pad_l, pad_t = 0, 0
pad_r = (window_size_W - W % window_size_W) % window_size_W
pad_b = (window_size_H - H % window_size_H) % window_size_H
x = (x, (pad_l, pad_r, pad_t, pad_b, 0, 0,)) # 对输入进行填充
n1, n2 = (H + pad_b) // window_size_H, (W + pad_r) // window_size_W
x = rearrange(x, 'b c (h1 n1) (w1 n2) -> (b n1 n2) c h1 w1', n1=n1, n2=n2).contiguous() # 重新排列输入数据
# attention
b, c, h, w =
qk = (x) # 计算查询和键的表示
qk = rearrange(qk, 'b (qk heads dim_head) h w -> qk b heads (h w) dim_head', qk=2, heads=self.num_head,
dim_head=self.dim_head).contiguous() # 重排查询和键的表示
q, k = qk[0], qk[1]
attn_spa = (q @ (-2, -1)) * # 计算空间注意力矩阵
attn_spa = attn_spa.softmax(dim=-1) # 对注意力矩阵进行 softmax
attn_spa = self.attn_drop(attn_spa) # 应用注意力 dropout
if self.attn_pre:
x = rearrange(x, 'b (heads dim_head) h w -> b heads (h w) dim_head', heads=self.num_head).contiguous() # 重排输入特征
x_spa = attn_spa @ x # 应用注意力矩阵到输入特征
x_spa = rearrange(x_spa, 'b heads (h w) dim_head -> b (heads dim_head) h w', heads=self.num_head, h=h,
w=w).contiguous() # 重排输出特征
x_spa = (x_spa) # 对输出特征应用值的表示
else:
v = (x) # 计算值的表示
v = rearrange(v, 'b (heads dim_head) h w -> b heads (h w) dim_head', heads=self.num_head).contiguous() # 重排值的表示
x_spa = attn_spa @ v # 应用注意力矩阵到值的表示
x_spa = rearrange(x_spa, 'b heads (h w) dim_head -> b (heads dim_head) h w', heads=self.num_head, h=h,
w=w).contiguous() # 重排输出特征
# unpadding
x = rearrange(x_spa, '(b n1 n2) c h1 w1 -> b c (h1 n1) (w1 n2)', n1=n1, n2=n2).contiguous() # 重新排列输出特征
if pad_r > 0 or pad_b > 0:
x = x[:, :, :H, :W].contiguous() # 移除填充部分
else: # 如果不使用注意力机制
x = (x) # 计算值的表示
# 应用空间激励模块和局部卷积层
x = x + (self.conv_local(x)) if self.has_skip else (self.conv_local(x))
# 应用输出投影的 dropout
x = self.proj_drop(x) # 应用 dropout
x = (x) # 应用输出投影
# 添加快捷连接并应用路径丢弃
x = (shortcut + self.drop_path(x)) if self.has_skip else x # 添加快捷连接并应用路径丢弃
return x # 返回处理后的结果