MXNet 查看 Symbol 的 shape(使用 infer_shape)

时间:2024-11-14 21:04:35

import mxnet as mx
import numpy as np
import random
import mxnet as mx
import sys
# Inspect the shapes of a small symbolic graph: two stacked 2x2 max-pooling
# layers applied to a batch of 60000 single-channel 28x28 images.

# Shape of every input the graph needs, keyed by the input symbol's name.
data_shape = {'data': (60000, 1, 28, 28)}

data = mx.sym.var('data')
pool0 = mx.sym.Pooling(
    data=data, pool_type="max", kernel=(2, 2), stride=(2, 2), name='pool0'
)
pool1 = mx.sym.Pooling(
    data=pool0, pool_type="max", kernel=(2, 2), stride=(2, 2), name='pool1'
)

# infer_shape() propagates the known input shapes through the graph and
# returns three lists: argument (input) shapes, output shapes, and
# auxiliary-state shapes.
in_shape, out_shape, aux_shape = pool1.infer_shape(**data_shape)

# Print explicitly so the result is visible when run as a plain script,
# not only when the last expression is echoed by a REPL/notebook.
# Each 2x2/stride-2 pooling halves the spatial size: 28 -> 14 -> 7.
print("arg shapes:", in_shape)
print("out shapes:", out_shape)
print("aux shapes:", aux_shape)