YOLOv5改进系列(25)——添加LSKNet注意力机制(大选择性卷积核的领域首次探索)

时间:2025-04-02 09:14:56
import math
import warnings
from functools import partial

import torch
import torch.nn as nn
from torch.nn.modules.utils import _pair as to_2tuple
from mmcv.cnn import build_norm_layer
from mmcv.cnn.utils.weight_init import (constant_init, normal_init,
                                        trunc_normal_init)
from mmcv.runner import BaseModule
# NOTE: this later import rebinds to_2tuple, shadowing torch's _pair above.
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

from ..builder import ROTATED_BACKBONES
  • # 1. Mlp 模块
  • class Mlp():
  • def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=, drop=0.):
  • super().__init__()
  • out_features = out_features or in_features
  • hidden_features = hidden_features or in_features
  • self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
  • = DWConv(hidden_features)
  • = act_layer()
  • self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
  • = (drop)
  • def forward(self, x):
  • x = self.fc1(x)
  • x = (x)
  • x = (x)
  • x = (x)
  • x = self.fc2(x)
  • x = (x)
  • return x
  • # 2. LSKblock 模块
  • class LSKblock():
  • def __init__(self, dim):
  • super().__init__()
  • self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
  • self.conv_spatial = nn.Conv2d(dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)
  • self.conv1 = nn.Conv2d(dim, dim//2, 1)
  • self.conv2 = nn.Conv2d(dim, dim//2, 1)
  • self.conv_squeeze = nn.Conv2d(2, 2, 7, padding=3)
  • = nn.Conv2d(dim//2, dim, 1)
  • def forward(self, x):
  • attn1 = self.conv0(x)
  • attn2 = self.conv_spatial(attn1)
  • attn1 = self.conv1(attn1)
  • attn2 = self.conv2(attn2)
  • attn = ([attn1, attn2], dim=1)
  • avg_attn = (attn, dim=1, keepdim=True)
  • max_attn, _ = torch.max(attn, dim=1, keepdim=True)
  • agg = ([avg_attn, max_attn], dim=1)
  • sig = self.conv_squeeze(agg).sigmoid()
  • attn = attn1 * sig[:,0,:,:].unsqueeze(1) + attn2 * sig[:,1,:,:].unsqueeze(1)
  • attn = (attn)
  • return x * attn
  • # 3. Attention 模块
  • class Attention():
  • def __init__(self, d_model):
  • super().__init__()
  • self.proj_1 = nn.Conv2d(d_model, d_model, 1)
  • = ()
  • self.spatial_gating_unit = LSKblock(d_model)
  • self.proj_2 = nn.Conv2d(d_model, d_model, 1)
  • def forward(self, x):
  • shorcut = ()
  • x = self.proj_1(x)
  • x = (x)
  • x = self.spatial_gating_unit(x)
  • x = self.proj_2(x)
  • x = x + shorcut
  • return x
  • # 4. Block 模块
  • class Block():
  • def __init__(self, dim, mlp_ratio=4., drop=0.,drop_path=0., act_layer=, norm_cfg=None):
  • super().__init__()
  • if norm_cfg:
  • self.norm1 = build_norm_layer(norm_cfg, dim)[1]
  • self.norm2 = build_norm_layer(norm_cfg, dim)[1]
  • else:
  • self.norm1 = nn.BatchNorm2d(dim)
  • self.norm2 = nn.BatchNorm2d(dim)
  • = Attention(dim)
  • self.drop_path = DropPath(drop_path) if drop_path > 0. else ()
  • mlp_hidden_dim = int(dim * mlp_ratio)
  • = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
  • layer_scale_init_value = 1e-2
  • self.layer_scale_1 = (
  • layer_scale_init_value * ((dim)), requires_grad=True)
  • self.layer_scale_2 = (
  • layer_scale_init_value * ((dim)), requires_grad=True)
  • def forward(self, x):
  • x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * (self.norm1(x)))
  • x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * (self.norm2(x)))
  • return x
  • # 5. OverlapPatchEmbed 模块
  • class OverlapPatchEmbed():
  • """ Image to Patch Embedding
  • """
  • def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768, norm_cfg=None):
  • super().__init__()
  • patch_size = to_2tuple(patch_size)
  • = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
  • padding=(patch_size[0] // 2, patch_size[1] // 2))
  • if norm_cfg:
  • = build_norm_layer(norm_cfg, embed_dim)[1]
  • else:
  • = nn.BatchNorm2d(embed_dim)
  • def forward(self, x):
  • x = (x)
  • _, _, H, W =
  • x = (x)
  • return x, H, W
  • # 6. LSKNet 模块
  • @ROTATED_BACKBONES.register_module()
  • class LSKNet(BaseModule):
  • def __init__(self, img_size=224, in_chans=3, embed_dims=[64, 128, 256, 512],
  • mlp_ratios=[8, 8, 4, 4], drop_rate=0., drop_path_rate=0., norm_layer=partial(, eps=1e-6),
  • depths=[3, 4, 6, 3], num_stages=4,
  • pretrained=None,
  • init_cfg=None,
  • norm_cfg=None):
  • super().__init__(init_cfg=init_cfg)
  • assert not (init_cfg and pretrained), \
  • 'init_cfg and pretrained cannot be set at the same time'
  • if isinstance(pretrained, str):
  • ('DeprecationWarning: pretrained is deprecated, '
  • 'please use "init_cfg" instead')
  • self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
  • elif pretrained is not None:
  • raise TypeError('pretrained must be a str or None')
  • = depths
  • self.num_stages = num_stages
  • dpr = [() for x in (0, drop_path_rate, sum(depths))] # stochastic depth decay rule
  • cur = 0
  • for i in range(num_stages):
  • patch_embed = OverlapPatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)),
  • patch_size=7 if i == 0 else 3,
  • stride=4 if i == 0 else 2,
  • in_chans=in_chans if i == 0 else embed_dims[i - 1],
  • embed_dim=embed_dims[i], norm_cfg=norm_cfg)
  • block = ([Block(
  • dim=embed_dims[i], mlp_ratio=mlp_ratios[i], drop=drop_rate, drop_path=dpr[cur + j],norm_cfg=norm_cfg)
  • for j in range(depths[i])])
  • norm = norm_layer(embed_dims[i])
  • cur += depths[i]
  • setattr(self, f"patch_embed{i + 1}", patch_embed)
  • setattr(self, f"block{i + 1}", block)
  • setattr(self, f"norm{i + 1}", norm)
  • def init_weights(self):
  • print('init cfg', self.init_cfg)
  • if self.init_cfg is None:
  • for m in ():
  • if isinstance(m, ):
  • trunc_normal_init(m, std=.02, bias=0.)
  • elif isinstance(m, ):
  • constant_init(m, val=1.0, bias=0.)
  • elif isinstance(m, nn.Conv2d):
  • fan_out = m.kernel_size[0] * m.kernel_size[
  • 1] * m.out_channels
  • fan_out //=
  • normal_init(
  • m, mean=0, std=(2.0 / fan_out), bias=0)
  • else:
  • super(LSKNet, self).init_weights()
  • def freeze_patch_emb(self):
  • self.patch_embed1.requires_grad = False
  • @
  • def no_weight_decay(self):
  • return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better
  • def get_classifier(self):
  • return
  • def reset_classifier(self, num_classes, global_pool=''):
  • self.num_classes = num_classes
  • = (self.embed_dim, num_classes) if num_classes > 0 else ()
  • def forward_features(self, x):
  • B = [0]
  • outs = []
  • for i in range(self.num_stages):
  • patch_embed = getattr(self, f"patch_embed{i + 1}")
  • block = getattr(self, f"block{i + 1}")
  • norm = getattr(self, f"norm{i + 1}")
  • x, H, W = patch_embed(x)
  • for blk in block:
  • x = blk(x)
  • x = (2).transpose(1, 2)
  • x = norm(x)
  • x = (B, H, W, -1).permute(0, 3, 1, 2).contiguous()
  • (x)
  • return outs
  • def forward(self, x):
  • x = self.forward_features(x)
  • # x = (x)
  • return x
  • # 7. DWConv 模块
  • class DWConv():
  • def __init__(self, dim=768):
  • super(DWConv, self).__init__()
  • = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
  • def forward(self, x):
  • x = (x)
  • return x
  • # 8. _conv_filter 函数
  • def _conv_filter(state_dict, patch_size=16):
  • """ convert patch embedding weight from manual patchify + linear proj to conv"""
  • out_dict = {}
  • for k, v in state_dict.items():
  • if 'patch_embed.' in k:
  • v = (([0], 3, patch_size, patch_size))
  • out_dict[k] = v
  • return out_dict