YOLOv5 Lightweight Improvement | Backbone | Integrating MobileNetV4 (multiple block structures and usage options)

Posted: 2024-10-28 07:55:50
```python
# -*- coding: utf-8 -*-
# @Time   : 2024/8/3 10:19
# @Author : sjh
# Original file header:
# @Time   : 2024/3/24 10:27
# @Author : 'IReverser'
# Reference: /jaiwei98/MobileNetV4-pytorch
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .model_config import MODEL_SPECS  # when running this file directly, change to: from model_config import MODEL_SPECS


def make_divisible(
        value: float,
        divisor: int,
        min_value: Optional[float] = None,
        round_down_protect: bool = True,
) -> int:
    """Copied from
    "/tensorflow/models/blob/master/official/vision/modeling/layers/nn_layers.py".
    Ensures that all layers have channel counts divisible by `divisor` (usually 8).

    Args:
        value: A `float` of the original value.
        divisor: An `int` divisor that the result must be divisible by.
        min_value: A `float` minimum value threshold.
        round_down_protect: A `bool` indicating whether rounding down by more
            than 10% is allowed.
    Returns:
        The adjusted value as an `int`, divisible by `divisor`.
    """
    if min_value is None:
        min_value = divisor
    new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
    # Make sure that rounding down does not reduce the value by more than 10%.
    if round_down_protect and new_value < 0.9 * value:
        new_value += divisor
    return int(new_value)


def conv2d(in_channels, out_channels, kernel_size=3, stride=1, groups=1, bias=False, norm=True, act=True):
    """Conv-BN-SiLU helper; norm/act can be switched off for projection layers."""
    conv = nn.Sequential()
    padding = (kernel_size - 1) // 2
    conv.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias, groups=groups))
    if norm:
        conv.append(nn.BatchNorm2d(out_channels))
    if act:
        conv.append(nn.SiLU())  # the reference implementation uses nn.ReLU6()
    return conv


class InvertedResidual(nn.Module):
    """Fused inverted residual block ("fused_ib"): 3x3 fused expansion + 1x1 projection."""

    def __init__(self, in_channels, out_channels, stride, expand_ratio, act=False, squeeze_excitation=False):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]
        hidden_dim = int(round(in_channels * expand_ratio))
        self.block = nn.Sequential()
        if expand_ratio != 1:
            # Fused expansion: one 3x3 conv replaces the 1x1 expand + 3x3 depthwise pair.
            self.block.add_module("exp_1x1", conv2d(in_channels, hidden_dim, kernel_size=3, stride=stride))
        if squeeze_excitation:
            self.block.add_module("conv_3x3",
                                  conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=stride, groups=hidden_dim))
        self.block.add_module("res_1x1", conv2d(hidden_dim, out_channels, kernel_size=1, stride=1, act=act))
        self.use_res_connect = self.stride == 1 and in_channels == out_channels

    def forward(self, x):
        if self.use_res_connect:
            return x + self.block(x)
        return self.block(x)
```
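A quick, hedged sanity check of the helpers above (the channel counts are arbitrary and not taken from any MobileNetV4 spec):

```python
# Illustrative only: arbitrary values, reusing the definitions above.
print(make_divisible(25, 8))   # -> 24 (rounding down stays within the 10% protection)
print(make_divisible(27, 8))   # -> 32 (rounding down to 24 would lose more than 10%, so round up)

x = torch.randn(1, 32, 80, 80)
fused = InvertedResidual(32, 32, stride=1, expand_ratio=4.0)
print(fused(x).shape)          # -> torch.Size([1, 32, 80, 80]); stride 1 and equal channels enable the residual
```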
```python
class UniversalInvertedBottleneckBlock(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 start_dw_kernel_size,
                 middle_dw_kernel_size,
                 middle_dw_downsample,
                 stride,
                 expand_ratio):
        """An inverted bottleneck block with optional depthwise convs.
        Referenced from
        /tensorflow/models/blob/master/official/vision/modeling/layers/nn_blocks.py
        """
        super(UniversalInvertedBottleneckBlock, self).__init__()
        # Starting depthwise conv (optional).
        self.start_dw_kernel_size = start_dw_kernel_size
        if self.start_dw_kernel_size:
            stride_ = stride if not middle_dw_downsample else 1
            self._start_dw_ = conv2d(in_channels, in_channels, kernel_size=start_dw_kernel_size,
                                     stride=stride_, groups=in_channels, act=False)
        # Expansion with 1x1 convs.
        expand_filters = make_divisible(in_channels * expand_ratio, 8)
        self._expand_conv = conv2d(in_channels, expand_filters, kernel_size=1)
        # Middle depthwise conv (optional).
        self.middle_dw_kernel_size = middle_dw_kernel_size
        if self.middle_dw_kernel_size:
            stride_ = stride if middle_dw_downsample else 1
            self._middle_dw = conv2d(expand_filters, expand_filters, kernel_size=middle_dw_kernel_size,
                                     stride=stride_, groups=expand_filters)
        # Projection with 1x1 convs.
        self._proj_conv = conv2d(expand_filters, out_channels, kernel_size=1, stride=1, act=False)
        # Ending depthwise conv (not used in this implementation):
        # self._end_dw = conv2d(out_channels, out_channels, kernel_size=_end_dw_kernel_size,
        #                       stride=stride, groups=in_channels, act=False)

    def forward(self, x):
        if self.start_dw_kernel_size:
            x = self._start_dw_(x)
        x = self._expand_conv(x)
        if self.middle_dw_kernel_size:
            x = self._middle_dw(x)
        x = self._proj_conv(x)
        return x
```
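The stride placement is easy to misread: when `middle_dw_downsample` is true the middle depthwise conv carries the stride, otherwise the starting depthwise conv does. A small shape check with made-up sizes:

```python
# Made-up sizes, purely to show where the stride lands.
x = torch.randn(1, 64, 40, 40)
uib = UniversalInvertedBottleneckBlock(in_channels=64, out_channels=96,
                                       start_dw_kernel_size=3, middle_dw_kernel_size=3,
                                       middle_dw_downsample=True, stride=2, expand_ratio=4.0)
print(uib(x).shape)  # -> torch.Size([1, 96, 20, 20]); the middle depthwise conv applies the stride
```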
""" super(MultiQueryAttentionLayerWithDownSampling, self).__init__() self.num_heads = num_heads self.key_dim = key_dim self.value_dim = value_dim self.query_h_strides = query_h_strides self.query_w_strides = query_w_strides self.kv_strides = kv_strides self.dw_kernel_size = dw_kernel_size self.dropout = dropout self.head_dim = self.key_dim // num_heads if self.query_h_strides > 1 or self.query_w_strides > 1: self._query_downsampling_norm = nn.BatchNorm2d(in_channels) self._query_proj = conv2d(in_channels, self.num_heads * self.key_dim, 1, 1, norm=False, act=False) if self.kv_strides > 1: self._key_dw_conv = conv2d(in_channels, in_channels, dw_kernel_size, kv_strides, groups=in_channels, norm=True, act=False) self._value_dw_conv = conv2d(in_channels, in_channels, dw_kernel_size, kv_strides, groups=in_channels, norm=True, act=False) self._key_proj = conv2d(in_channels, key_dim, 1, 1, norm=False, act=False) self._value_proj = conv2d(in_channels, key_dim, 1, 1, norm=False, act=False) self._output_proj = conv2d(num_heads * key_dim, in_channels, 1, 1, norm=False, act=False) self.dropout = nn.Dropout(p=dropout) def forward(self, x): bs, seq_len, _, _ = x.size() # print(()) if self.query_h_strides > 1 or self.query_w_strides > 1: q = F.avg_pool2d(self.query_h_strides, self.query_w_strides) q = self._query_downsampling_norm(q) q = self._query_proj(q) else: q = self._query_proj(x) px = q.size(2) q = q.view(bs, self.num_heads, -1, self.key_dim) # [batch_size, num_heads, seq_len, key_dim] if self.kv_strides > 1: k = self._key_dw_conv(x) k = self._key_proj(k) v = self._value_dw_conv(x) v = self._value_proj(v) else: k = self._key_proj(x) v = self._value_proj(x) k = k.view(bs, 1, self.key_dim, -1) # [batch_size, 1, key_dim, seq_length] v = v.view(bs, 1, -1, self.key_dim) # [batch_size, 1, seq_length, key_dim] # calculate attention score # print(, , ) attn_score = torch.matmul(q, k) / (self.head_dim ** 0.5) attn_score = self.dropout(attn_score) attn_score = F.softmax(attn_score, dim=-1) # context = ('bnhm,bmv->bnhv', attn_score, v) # print(attn_score.shape, ) context = torch.matmul(attn_score, v) context = context.view(bs, self.num_heads * self.key_dim, px, px) output = self._output_proj(context) # print() return output class MNV4layerScale(nn.Module): def __init__(self, init_value): """LayerScale as introduced in CaiT: /abs/2103.17239 Referenced from here /tensorflow/models/blob/master/official/vision/modeling/layers/nn_blocks.py As used in MobileNetV4. Attributes: init_value (float): value to initialize the diagonal matrix of LayerScale. 
""" super(MNV4layerScale, self).__init__() self.init_value = init_value def forward(self, x): gamma = self.init_value * torch.ones(x.size(-1), dtype=x.dtype, device=x.device) return x * gamma class MultiHeadSelfAttentionBlock(nn.Module): def __init__(self, in_channels, num_heads, key_dim, value_dim, query_h_strides, query_w_strides, kv_strides, use_layer_scale, use_multi_query, use_residual=True): super(MultiHeadSelfAttentionBlock, self).__init__() self.query_h_strides = query_h_strides self.query_w_strides = query_w_strides self.kv_strides = kv_strides self.use_layer_scale = use_layer_scale self.use_multi_query = use_multi_query self.use_residual = use_residual self._input_norm = nn.BatchNorm2d(in_channels) if self.use_multi_query: self.multi_query_attention = MultiQueryAttentionLayerWithDownSampling( in_channels, num_heads, key_dim, value_dim, query_h_strides, query_w_strides, kv_strides ) else: self.multi_head_attention = nn.MultiheadAttention(in_channels, num_heads, kdim=key_dim) if use_layer_scale: self.layer_scale_init_value = 1e-5 self.layer_scale = MNV4layerScale(self.layer_scale_init_value) def forward(self, x): # Not using CPE, skipped # input norm shortcut = x x = self._input_norm(x) # multi query if self.use_multi_query: # print(()) x = self.multi_query_attention(x) # print(()) else: x = self.multi_head_attention(x, x) # layer scale if self.use_layer_scale: x = self.layer_scale(x) # use residual if self.use_residual: x = x + shortcut return x def build_blocks(layer_spec): global msha if not layer_spec.get("block_name"): return nn.Sequential() block_names = layer_spec["block_name"] layers = nn.Sequential() if block_names == "convbn": schema_ = ["in_channels", "out_channels", "kernel_size", "stride"] for i in range(layer_spec["num_blocks"]): args = dict(zip(schema_, layer_spec["block_specs"][i])) layers.add_module(f"convbn_{i}", conv2d(**args)) elif block_names == "uib": schema_ = ["in_channels", "out_channels", "start_dw_kernel_size", "middle_dw_kernel_size", "middle_dw_downsample", "stride", "expand_ratio", "msha"] for i in range(layer_spec["num_blocks"]): args = dict(zip(schema_, layer_spec["block_specs"][i])) msha = args.pop("msha") if "msha" in args else 0 layers.add_module(f"uib_{i}", UniversalInvertedBottleneckBlock(**args)) if msha: msha_schema_ = [ "in_channels", "num_heads", "key_dim", "value_dim", "query_h_strides", "query_w_strides", "kv_strides", "use_layer_scale", "use_multi_query", "use_residual" ] args = dict(zip(msha_schema_, [args["out_channels"]] + (msha))) layers.add_module( f"msha_{i}", MultiHeadSelfAttentionBlock(**args) ) elif block_names == "fused_ib": schema_ = ["in_channels", "out_channels", "stride", "expand_ratio", "act"] for i in range(layer_spec["num_blocks"]): args = dict(zip(schema_, layer_spec["block_specs"][i])) layers.add_module(f"fused_ib_{i}", InvertedResidual(**args)) else: raise NotImplementedError return layers class MobileNetV4(nn.Module): def __init__(self, model, num_classes=1000): # MobileNetV4ConvSmall MobileNetV4ConvMedium MobileNetV4ConvLarge # MobileNetV4HybridMedium MobileNetV4HybridLarge """Params to initiate MobilenNetV4 Args: model : support 5 types of models as indicated in "/tensorflow/models/blob/master/official/vision/modeling/backbones/" """ super(MobileNetV4, self).__init__() # print(MODEL_SPECS.keys(), model not in MODEL_SPECS.keys()) assert model in MODEL_SPECS.keys() self.model = model self.num_classes = num_classes self.spec = MODEL_SPECS[self.model] # conv0 self.conv0 = build_blocks(self.spec["conv0"]) # layer1 
self.layer1 = build_blocks(self.spec["layer1"]) # layer2 self.layer2 = build_blocks(self.spec["layer2"]) # layer3 self.layer3 = build_blocks(self.spec["layer3"]) # layer4 self.layer4 = build_blocks(self.spec["layer4"]) # layer5 self.layer5 = build_blocks(self.spec["layer5"]) # classify [optional] self.fc = nn.Linear(1280, num_classes) self.channel = [i.size(1) for i in self.forward(torch.randn(1, 3, 640, 640), is_feat=True)] def forward(self, x, is_feat=True): x0 = self.conv0(x) x1 = self.layer1(x0) x2 = self.layer2(x1) x3 = self.layer3(x2) x4 = self.layer4(x3) x5 = self.layer5(x4) out = F.adaptive_avg_pool2d(x5, 1) out = self.fc(out.flatten(1)) if is_feat: return [x0, x1, x2, x3, x4]#, out else: return out def create_mobilenetv4(model_name: str, num_classes: int = 1000): model = MobileNetV4(model_name, num_classes) return model # MNV4ConvSmall, MNV4ConvMedium, MNV4ConvLarge, MNV4HybridMedium, MNV4HybridLarge if __name__ == '__main__': x = torch.rand((1, 3, 640, 640)) model = create_mobilenetv4(model_name="MNV4ConvSmall") feats, out = model(x, is_feat=True) print("logit: ", out.shape) for index, feat in enumerate(feats): print(f"{index}: ", feat.shape) # from torchsummary import summary # summary(create_mobilenetv4(model_name="MNV4ConvSmall"), (1, 3, 224, 224)) print(sum([i.numel() for i in model.parameters()]) / 1024 / 1024, "MB")