MXNet 中的 hybrid_forward 的一个使用技巧

时间:2021-08-01 12:53:57
from mxnet.gluon import nn
from mxnet import nd
class SliceLike(nn.HybridBlock):
    """Toy block showing how registered parameters reach ``hybrid_forward``.

    Each Parameter / Constant created through ``self.params`` and assigned
    to an attribute of the block is recorded in ``self._reg_params`` under
    the *attribute* name (here ``xs`` and ``ys``), not under the parameter
    name (``x_`` / ``y_``). Those entries are handed to ``hybrid_forward``
    as keyword arguments, which is why its signature is
    ``(self, F, x, xs, ys)``.

    Parameters
    ----------
    xs : NDArray
        Value stored in the constant ``x_``; its shape is also used as the
        shape of the regular parameter ``y_``.
    **kwargs
        Forwarded unchanged to ``nn.HybridBlock.__init__``.
    """

    def __init__(self, xs, **kwargs):
        super().__init__(**kwargs)
        # Constant parameter initialised with the given value.
        self.xs = self.params.get_constant('x_', xs)
        # Regular parameter of the same shape. It is not used in the forward
        # computation; it is here to show that it is passed in as well.
        self.ys = self.params.get('y_', shape=xs.shape)
        # Plain attribute (not a parameter), so it does not appear in
        # ``_reg_params``; unused by this block.
        self.A = 'sl'

    def hybrid_forward(self, F, x, xs, ys):
        """Slice ``xs`` to ``x``'s size on axis 1 and reshape to rows of 4.

        Parameters
        ----------
        F : module
            Operator namespace supplied by Gluon.
        x : NDArray or Symbol
            Input; only its shape matters here.
        xs, ys : NDArray or Symbol
            Values of the registered parameters, passed by attribute name.

        Returns
        -------
        NDArray or Symbol
            The sliced data reshaped to ``(1, -1, 4)``; the element count of
            the slice must be divisible by 4.
        """
        # Debug aid: shows the attribute-name -> parameter mapping.
        print(self._reg_params)
        # Only the shape of the second argument is used by slice_like.
        # ``axes`` is written as a tuple: ``(1)`` is just the int 1.
        a = F.slice_like(xs, x * 0, axes=(1,))
        return a.reshape((1, -1, 4))

hybrid_forward 函数的参数如下形式:(self, F, x, *args, **kwargs)

下面解释一下 (self, F, x, xs, ys) 的含义:通过 self.params.get_constant 或 self.params.get 创建、并赋值给 Block 属性的参数,会被收集到 self._reg_params 字典中。该字典的键是属性名(如 xs、ys),而不是参数名(x_、y_)。这些参数随后以关键字参数的形式直接传入 hybrid_forward,因此 hybrid_forward 的参数名必须与属性名保持一致:

# Demo: build the block with a 10x10 constant and run one forward pass.
# arange needs exactly 100 elements for the reshape to (10, 10) to be valid.
xs = nd.arange(100).reshape((10, 10))
sx = SliceLike(xs)
sx.initialize()
y = nd.zeros((1, 1, 2, 3))
# NOTE(review): with a (10, 10) ``xs`` and ``y.shape[1] == 1``, slice_like
# on axis 1 presumably yields 10 elements, which cannot be reshaped to
# (1, -1, 4). Confirm against the output shown below.
sx(y)
{'xs': Constant slicelike12_x_ (shape=(10, 10), dtype=<class 'numpy.float32'>), 'ys': Parameter slicelike12_y_ (shape=(10, 10), dtype=<class 'numpy.float32'>)}

[[[ 0. 10. 20. 30.]
[40. 50. 60. 70.]]]
<NDArray 1x2x4 @cpu(0)>