pytorch SENet实现案例

时间:2022-04-19 02:23:26

我就废话不多说了,大家还是直接看代码吧~

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from torch import nn
 
class SELayer(nn.Module):
 def __init__(self, channel, reduction=16):
  super(SELayer, self).__init__()
 
  //返回1X1大小的特征图,通道数不变
  self.avg_pool = nn.AdaptiveAvgPool2d(1)
  self.fc = nn.Sequential(
   nn.Linear(channel, channel // reduction, bias=False),
   nn.ReLU(inplace=True),
   nn.Linear(channel // reduction, channel, bias=False),
   nn.Sigmoid()
  )
 
 def forward(self, x):
  b, c, _, _ = x.size()
 
  //全局平均池化,batch和channel和原来一样保持不变
  y = self.avg_pool(x).view(b, c)
 
  //全连接层+池化
  y = self.fc(y).view(b, c, 1, 1)
 
  //和原特征图相乘
  return x * y.expand_as(x)

补充知识:pytorch 实现 SE Block

论文模块图

pytorch SENet实现案例

代码

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch.nn as nn
class SE_Block(nn.Module):
 def __init__(self, ch_in, reduction=16):
  super(SE_Block, self).__init__()
  self.avg_pool = nn.AdaptiveAvgPool2d(1)               # 全局自适应池化
  self.fc = nn.Sequential(
   nn.Linear(ch_in, ch_in // reduction, bias=False),
   nn.ReLU(inplace=True),
   nn.Linear(ch_in // reduction, ch_in, bias=False),
   nn.Sigmoid()
  )
 
 def forward(self, x):
  b, c, _, _ = x.size()
  y = self.avg_pool(x).view(b, c)
  y = self.fc(y).view(b, c, 1, 1)
  return x * y.expand_as(x)

现在还有许多关于SE的变形,但大都大同小异

以上这篇pytorch SENet实现案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。

原文链接:https://blog.csdn.net/qq_35985044/article/details/90142431