Take the Channel Attention Block (CAB) as an example:
import torch
import torch.nn as nn

class CAB(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(CAB, self).__init__()
        # Squeeze: global average pooling turns each channel into a single descriptor
        self.global_pooling = nn.AdaptiveAvgPool2d(output_size=1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x is a pair of feature maps: (high-level, low-level)
        x1, x2 = x  # high, low
        x = torch.cat([x1, x2], dim=1)  # concatenate along the channel dimension
        x = self.global_pooling(x)      # (N, C, 1, 1) channel descriptor
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmoid(x)             # per-channel attention weights in (0, 1)
        x2 = x * x2                     # re-weight the low-level features
        res = x2 + x1                   # fuse with the high-level features
        return res
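A minimal usage sketch follows (the batch size and channel counts are assumptions for illustration). Note that x1 and x2 must have the same number of channels so that the element-wise product and sum are valid, and in_channels must equal the sum of their channel counts because the two maps are concatenated before pooling. The single forward argument is a tuple holding the two tensors.

# Usage sketch; tensor shapes are illustrative assumptions.
# x1 (high-level) and x2 (low-level) both have 64 channels, so
# in_channels = 64 + 64 = 128 and out_channels = 64.
cab = CAB(in_channels=128, out_channels=64)
x1 = torch.randn(2, 64, 32, 32)  # high-level features
x2 = torch.randn(2, 64, 32, 32)  # low-level features
out = cab((x1, x2))              # forward receives one tuple of two tensors
print(out.shape)                 # torch.Size([2, 64, 32, 32])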
This PyTorch example of a forward method that takes two inputs is everything the editor has to share. We hope it serves as a useful reference, and thank you for your continued support of 服务器之家.
Original link: https://blog.csdn.net/weixin_41950276/article/details/89069659