MXNet 中的 hybird_forward 的一个使用技巧


from mxnet.gluon import nn
from mxnet import nd
class SliceLike(nn.HybridBlock):
    def __init__(self, xs, **kwargs):
        super().__init__(**kwargs)
        self.xs = self.params.get_constant('x_', xs)
        self.ys = self.params.get('y_', shape=xs.shape)
        self.A = 'sl'

    def hybrid_forward(self, F, x, xs, ys):
        print(self._reg_params)
        a = F.slice_like(xs, x * 0, axes=(1))
        return a.reshape((1, -1, 4))

hybrid_forward 函数的参数如下形式:(self, F, x, *args, **kwargs)

下面解释一下 (self, F, x, xs, ys):首先 self._reg_params 会收集 self.params.get_constant 或者 self.params.get 创建的参数字典,然后直接传入 hybrid_forward 中:

xs = nd.arange(6e4).reshape((10, 10))
sx = SliceLike(xs)
sx.initialize()
y = nd.zeros((1, 1, 2, 3))
sx(y)
{'xs': Constant slicelike12_x_ (shape=(10, 10), dtype=<class 'numpy.float32'>), 'ys': Parameter slicelike12_y_ (shape=(10, 10), dtype=<class 'numpy.float32'>)}






[[[ 0. 10. 20. 30.]
  [40. 50. 60. 70.]]]
<NDArray 1x2x4 @cpu(0)>


免责声明!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系本站邮箱yoyou2525@163.com删除。



 
粤ICP备18138465号  © 2018-2025 CODEPRJ.COM