最近再写openpose,它的网络结构是多阶段的网络,所以写网络的时候很想用列表的方式,但是直接使用列表不能将网络中相应的部分放入到cuda中去。
其实这个问题很简单的,使用moduleList就好了。
1 我先是定义了一个函数,用来根据超参数,建立一个基础网络结构
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
|
stage = [[ 3 , 3 , 3 , 1 , 1 ], [ 7 , 7 , 7 , 7 , 7 , 1 , 1 ]]
branches_cfg = [[[ 128 , 128 , 128 , 512 , 38 ], [ 128 , 128 , 128 , 512 , 19 ]],
[[ 128 , 128 , 128 , 128 , 128 , 128 , 38 ], [ 128 , 128 , 128 , 128 , 128 , 128 , 19 ]]]
# used for add two branches as well as adapt to certain stage
def add_extra(i, branches_cfg, stage):
"""
only add CNN of brancdes S & L in stage Ti at the end of net
:param in_channels:the input channels & out
:param stage: size of filter
:param branches_cfg: channels of image
:return:list of layers
"""
in_channels = i
layers = []
for k in range ( len (stage)):
padding = stage[k] / / 2
conv2d = nn.Conv2d(in_channels, branches_cfg[k], kernel_size = stage[k], padding = padding)
layers + = [conv2d, nn.ReLU(inplace = True )]
in_channels = branches_cfg[k]
return layers
|
2 然后用普通列表装载他们
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
|
conf_bra_list = []
paf_bra_list = []
# param for branch network
in_channels = 128
for i in range (all_stage):
if i > 0 :
branches = branches_cfg[ 1 ]
conv_sz = stage[ 1 ]
else :
branches = branches_cfg[ 0 ]
conv_sz = stage[ 0 ]
conf_bra_list.append(nn.Sequential( * add_extra(in_channels, branches[ 0 ], conv_sz)))
paf_bra_list.append(nn.Sequential( * add_extra(in_channels, branches[ 1 ], conv_sz)))
in_channels = 185
|
3 再然后,使用moduleList方法,把普通列表专成pytorch下的模块
1
2
3
|
# to list
self .conf_bra = nn.ModuleList(conf_bra_list)
self .paf_bra = nn.ModuleList(paf_bra_list)
|
4 最后,调用就好了
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
|
out_0 = x
# the base transform
for k in range ( len ( self .vgg)):
out_0 = self .vgg[k](out_0)
# local name space
name = locals ()
confs = []
pafs = []
outs = []
length = len ( self .conf_bra)
for i in range (length):
name[ 'conf_%s' % (i + 1 )] = self .conf_bra[i](name[ 'out_%s' % i])
name[ 'paf_%s' % (i + 1 )] = self .paf_bra[i](name[ 'out_%s' % i])
name[ 'out_%s' % (i + 1 )] = torch.cat([name[ 'conf_%s' % (i + 1 )], name[ 'paf_%s' % (i + 1 )], out_0], 1 )
confs.append( 'conf_%s' % (i + 1 ))
pafs.append( 'paf_%s' % (i + 1 ))
outs.append( 'out_%s' % (i + 1 ))
|
5 顺便装了一下,使用了python局部变量命名空间,name = locals(),其实完全使用普通列表保存变量就好了,高兴就好。
以上这篇对pytorch网络层结构的数组化详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/daniaokuye/article/details/78827436