MXNet 查看模型 params 文件的网络结构

时间:2023-03-08 21:59:49
MXNet 查看模型 params 文件的网络结构
import mxnet as mx
import pdb
def load_checkpoint():
"""
Load model checkpoint from file.
:param prefix: Prefix of model name.
:param epoch: Epoch number of model we would like to load.
:return: (arg_params, aux_params)
arg_params : dict of str to NDArray
Model parameter, dict of name to NDArray of net's weights.
aux_params : dict of str to NDArray
Model parameter, dict of name to NDArray of net's auxiliary states.
"""
save_dict = mx.nd.load('model-0000.params')
arg_params = {}
aux_params = {}
for k, v in save_dict.items():
tp, name = k.split(':', 1)
if tp == 'arg':
arg_params[name] = v
if tp == 'aux':
aux_params[name] = v
return arg_params, aux_params def convert_context(params, ctx):
"""
:param params: dict of str to NDArray
:param ctx: the context to convert to
:return: dict of str of NDArray with context ctx
"""
new_params = dict()
for k, v in params.items():
new_params[k] = v.as_in_context(ctx)
#print new_params[0]
return new_params def load_param(convert=False, ctx=None):
"""
wrapper for load checkpoint
:param prefix: Prefix of model name.
:param epoch: Epoch number of model we would like to load.
:param convert: reference model should be converted to GPU NDArray first
:param ctx: if convert then ctx must be designated.
:return: (arg_params, aux_params)
"""
arg_params, aux_params = load_checkpoint()
if convert:
if ctx is None:
ctx = mx.cpu()
arg_params = convert_context(arg_params, ctx)
aux_params = convert_context(aux_params, ctx)
return arg_params, aux_params if __name__=='__main__':
result = load_param();
#pdb.set_trace()
print 'result is'
#print result
for dic in result:
for key in dic:
print(key,dic[key].shape)
# print 'one of results is:'
# print result[0]['fc2_weight'].asnumpy()

python showmxmodel.py 2>&1 | tee log.txt
result is
('stage3_unit2_bn1_beta', (256L,))
('stage3_unit2_bn3_beta', (256L,))
('stage3_unit11_bn1_gamma', (256L,))
('stage3_unit5_bn3_gamma', (256L,))
('stage3_unit3_conv1_weight', (256L, 256L, 3L, 3L))
('stage2_unit1_bn3_gamma', (128L,))
('stage3_unit4_conv1_weight', (256L, 256L, 3L, 3L))
('stage3_unit12_bn3_beta', (256L,))
('stage2_unit2_bn3_beta', (128L,))
('conv0_weight', (64L, 3L, 3L, 3L))
('stage3_unit11_relu1_gamma', (256L,))
('stage4_unit1_conv1sc_weight', (512L, 256L, 1L, 1L))
('stage3_unit1_conv1sc_weight', (256L, 128L, 1L, 1L))
('bn1_beta', (512L,))
('stage1_unit2_bn2_beta', (64L,))
('stage3_unit2_conv2_weight', (256L, 256L, 3L, 3L))
('stage1_unit2_conv1_weight', (64L, 64L, 3L, 3L))
('stage3_unit14_bn2_beta', (256L,))
('stage4_unit2_bn3_beta', (512L,))
('stage3_unit8_bn1_gamma', (256L,))
('stage3_unit7_bn1_gamma', (256L,))
('stage2_unit3_bn1_beta', (128L,))
('stage2_unit4_conv1_weight', (128L, 128L, 3L, 3L))
('stage3_unit2_bn2_gamma', (256L,))
('stage1_unit1_conv1_weight', (64L, 64L, 3L, 3L))
('stage3_unit9_conv2_weight', (256L, 256L, 3L, 3L))
('stage3_unit13_conv1_weight', (256L, 256L, 3L, 3L))
('stage3_unit1_relu1_gamma', (256L,))
('stage4_unit1_bn3_beta', (512L,))
('stage2_unit1_bn2_beta', (128L,))
('stage3_unit14_conv1_weight', (256L, 256L, 3L, 3L))
('stage3_unit8_bn1_beta', (256L,))
('stage3_unit11_conv1_weight', (256L, 256L, 3L, 3L))
('stage1_unit1_bn3_gamma', (64L,))
('stage2_unit2_conv2_weight', (128L, 128L, 3L, 3L))
('stage4_unit2_bn1_gamma', (512L,))
('stage3_unit3_bn1_gamma', (256L,))
('stage1_unit3_bn2_gamma', (64L,))
('stage1_unit3_bn3_gamma', (64L,))
('stage4_unit2_relu1_gamma', (512L,))
('stage3_unit10_conv2_weight', (256L, 256L, 3L, 3L))
('stage3_unit12_conv1_weight', (256L, 256L, 3L, 3L))
('stage3_unit2_relu1_gamma', (256L,))
('stage3_unit10_bn2_beta', (256L,))
('stage2_unit3_bn3_gamma', (128L,))
('stage2_unit3_bn2_beta', (128L,))
('stage3_unit8_bn3_beta', (256L,))
('fc1_gamma', (512L,))
('stage3_unit14_bn3_gamma', (256L,))
('stage3_unit9_bn3_gamma', (256L,))
('stage2_unit3_bn3_beta', (128L,))
('stage3_unit1_sc_gamma', (256L,))
('stage3_unit7_bn1_beta', (256L,))
('stage1_unit2_bn3_beta', (64L,))
('stage3_unit14_relu1_gamma', (256L,))
('stage3_unit13_bn2_beta', (256L,))
('stage2_unit1_conv1sc_weight', (128L, 64L, 1L, 1L))
('bn0_beta', (64L,))
('stage3_unit12_bn1_gamma', (256L,))
('stage2_unit1_sc_gamma', (128L,))
('relu0_gamma', (64L,))
('stage2_unit2_bn2_gamma', (128L,))
('stage3_unit4_relu1_gamma', (256L,))