from mxnet.gluon import nn
from mxnet import nd
class SliceLike(nn.HybridBlock):
    """Demo block showing how registered parameters reach ``hybrid_forward``.

    ``__init__`` registers two parameters and stores them as the attributes
    ``xs`` and ``ys``. Gluon collects them in ``self._reg_params`` keyed by
    those attribute names and passes them to ``hybrid_forward`` as keyword
    arguments, which is why its signature is ``(self, F, x, xs, ys)``.

    Args:
        xs: array used as the value of the constant parameter ``x_``; its
            ``shape`` is also used as the shape of the parameter ``y_``.
        **kwargs: forwarded unchanged to ``nn.HybridBlock.__init__``.
    """

    def __init__(self, xs, **kwargs):
        super().__init__(**kwargs)
        # Constant parameter whose value is the given array.
        self.xs = self.params.get_constant('x_', xs)
        # Regular parameter with the same shape; it is never used in the
        # computation and is only here to show up in `_reg_params`.
        self.ys = self.params.get('y_', shape=xs.shape)
        # Plain string attribute, not a Parameter.
        self.A = 'sl'

    def hybrid_forward(self, F, x, xs, ys):
        """Slice ``xs`` along axis 1 to the length of ``x``'s axis 1, then reshape.

        Args:
            F: operator namespace supplied by Gluon (``F.slice_like`` is used).
            x: input tensor; only its shape along axis 1 matters here.
            xs: value of the constant parameter registered as ``self.xs``.
            ys: value of the parameter registered as ``self.ys`` (unused).

        Returns:
            The sliced ``xs`` reshaped to ``(1, -1, 4)``.
        """
        # Demonstration only: show which parameters Gluon passes in.
        print(self._reg_params)
        # slice_like only reads the *shape* of its second argument, so `x`
        # can be passed as-is. `axes` is written as a real tuple: `(1)` would
        # just be the int 1.
        a = F.slice_like(xs, x, axes=(1,))
        # NOTE(review): reshape((1, -1, 4)) only covers every element when the
        # sliced array's size is a multiple of 4 — confirm this holds for the
        # intended inputs.
        return a.reshape((1, -1, 4))
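`hybrid_forward` 函数的参数形式如下：`(self, F, x, *args, **kwargs)`。下面解释一下本例中的 `(self, F, x, xs, ys)`：
首先，`self._reg_params` 会收集通过 `self.params.get_constant` 或 `self.params.get` 创建并赋值给属性的参数，组成一个以属性名（这里是 `xs` 和 `ys`，而不是参数名 `x_`、`y_`）为键的字典，然后将其作为关键字参数直接传入 `hybrid_forward` 中。因此 `hybrid_forward` 的形参名必须与这些属性名一致：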
# A 10x10 array holding 0..99, so row i starts at 10 * i. Generate exactly
# 100 elements instead of taking the first 100 of a much larger range.
xs = nd.arange(100).reshape((10, 10))
sx = SliceLike(xs)
# Allocate data for the registered parameters (`ys` only has a shape so far).
sx.initialize()
# Only the shape of `y` matters: its axis-1 length (1) sets the slice width.
y = nd.zeros((1, 1, 2, 3))
sx(y)
{'xs': Constant slicelike12_x_ (shape=(10, 10), dtype=<class 'numpy.float32'>), 'ys': Parameter slicelike12_y_ (shape=(10, 10), dtype=<class 'numpy.float32'>)}
[[[ 0. 10. 20. 30.]
[40. 50. 60. 70.]]]
<NDArray 1x2x4 @cpu(0)>