爆改 YOLOv8 | 利用 ResNet18、34、50、101 替换 YOLO 主干网络

时间:2025-04-11 09:00:54
from collections import OrderedDict

import torch.nn as nn
import torch.nn.functional as F
  • class ConvNormLayer():
  • def __init__(self,
  • ch_in,
  • ch_out,
  • filter_size,
  • stride,
  • groups=1,
  • act=None):
  • super(ConvNormLayer, self).__init__()
  • self.act = act
  • self.conv = 2d(
  • in_channels=ch_in,
  • out_channels=ch_out,
  • kernel_size=filter_size,
  • stride=stride,
  • padding=(filter_size - 1) // 2,
  • groups=groups)
  • self.norm = 2d(ch_out)
  • def forward(self, inputs):
  • out = self.conv(inputs)
  • out = self.norm(out)
  • if self.act:
  • out = getattr(F, self.act)(out)
  • return out
  • class SELayer():
  • def __init__(self, ch, reduction_ratio=16):
  • super(SELayer, self).__init__()
  • self.avg_pool = 2d(1)
  • self.fc = nn.Sequential(
  • (ch, ch // reduction_ratio, bias=False),
  • (inplace=True),
  • (ch // reduction_ratio, ch, bias=False),
  • ()
  • )
  • def forward(self, x):
  • b, c, _, _ = x.size()
  • y = self.avg_pool(x).view(b, c)
  • y = self.fc(y).view(b, c, 1, 1)
  • return x * y.expand_as(x)
  • class BasicBlock():
  • expansion = 1
  • def __init__(self,
  • ch_in,
  • ch_out,
  • stride,
  • shortcut,
  • act='relu',
  • variant='b',
  • att=False):
  • super(BasicBlock, self).__init__()
  • self.shortcut = shortcut
  • if not shortcut:
  • if variant == 'd' and stride == 2:
  • self.short = nn.Sequential()
  • self.short.add_sublayer(
  • 'pool',
  • 2d(
  • kernel_size=2, stride=2, padding=0, ceil_mode=True))
  • self.short.add_sublayer(
  • 'conv',
  • ConvNormLayer(
  • ch_in=ch_in,
  • ch_out=ch_out,
  • filter_size=1,
  • stride=1))
  • else:
  • self.short = ConvNormLayer(
  • ch_in=ch_in,
  • ch_out=ch_out,
  • filter_size=1,
  • stride=stride)
  • self.branch2a = ConvNormLayer(
  • ch_in=ch_in,
  • ch_out=ch_out,
  • filter_size=3,
  • stride=stride,
  • act='relu')
  • self.branch2b = ConvNormLayer(
  • ch_in=ch_out,
  • ch_out=ch_out,
  • filter_size=3,
  • stride=1,
  • act=None)
  • self.att = att
  • if self.att:
  • self.se = SELayer(ch_out)
  • def forward(self, inputs):
  • out = self.branch2a(inputs)
  • out = self.branch2b(out)
  • if self.att:
  • out = self.se(out)
  • if self.shortcut:
  • short = inputs
  • else:
  • short = self.short(inputs)
  • out = out + short
  • out = (out)
  • return out
  • class BottleNeck():
  • expansion = 4
  • def __init__(self, ch_in, ch_out, stride, shortcut, act='relu', variant='d', att=False):
  • super().__init__()
  • if variant == 'a':
  • stride1, stride2 = stride, 1
  • else:
  • stride1, stride2 = 1, stride
  • width = ch_out
  • self.branch2a = ConvNormLayer(ch_in, width, 1, stride1, act=act)
  • self.branch2b = ConvNormLayer(width, width, 3, stride2, act=act)
  • self.branch2c = ConvNormLayer(width, ch_out * self.expansion, 1, 1)
  • self.shortcut = shortcut
  • if not shortcut:
  • if variant == 'd' and stride == 2:
  • self.short = nn.Sequential(OrderedDict([
  • ('pool', 2d(2, 2, 0, ceil_mode=True)),
  • ('conv', ConvNormLayer(ch_in, ch_out * self.expansion, 1, 1))
  • ]))
  • else:
  • self.short = ConvNormLayer(ch_in, ch_out * self.expansion, 1, stride)
  • self.att = att
  • if self.att:
  • self.se = SELayer(ch_out)
  • def forward(self, x):
  • out = self.branch2a(x)
  • out = self.branch2b(out)
  • out = self.branch2c(out)
  • if self.att:
  • out = self.se(out)
  • if self.shortcut:
  • short = x
  • else:
  • short = self.short(x)
  • out = out + short
  • out = (out)
  • return out
  • class Blocks():
  • def __init__(self,
  • ch_in,
  • ch_out,
  • count,
  • block,
  • stage_num,
  • att=False,
  • variant='b'):
  • super(Blocks, self).__init__()
  • self.blocks = ()
  • block = globals()[block]
  • for i in range(count):
  • self.(
  • block(
  • ch_in,
  • ch_out,
  • stride=2 if i == 0 and stage_num != 2 else 1,
  • shortcut=False if i == 0 else True,
  • variant=variant,
  • att=att)
  • )
  • if i == 0:
  • ch_in = ch_out * block.expansion
  • def forward(self, inputs):
  • block_out = inputs
  • for block in self.blocks:
  • block_out = block(block_out)
  • return block_out