MXNet 中的 hybird_forward 的一個使用技巧


from mxnet.gluon import nn
from mxnet import nd
class SliceLike(nn.HybridBlock):
    def __init__(self, xs, **kwargs):
        super().__init__(**kwargs)
        self.xs = self.params.get_constant('x_', xs)
        self.ys = self.params.get('y_', shape=xs.shape)
        self.A = 'sl'

    def hybrid_forward(self, F, x, xs, ys):
        print(self._reg_params)
        a = F.slice_like(xs, x * 0, axes=(1))
        return a.reshape((1, -1, 4))

hybrid_forward 函數的參數如下形式:(self, F, x, *args, **kwargs)

下面解釋一下 (self, F, x, xs, ys):首先 self._reg_params 會收集 self.params.get_constant 或者 self.params.get 創建的參數字典,然后直接傳入 hybrid_forward 中:

xs = nd.arange(6e4).reshape((10, 10))
sx = SliceLike(xs)
sx.initialize()
y = nd.zeros((1, 1, 2, 3))
sx(y)
{'xs': Constant slicelike12_x_ (shape=(10, 10), dtype=<class 'numpy.float32'>), 'ys': Parameter slicelike12_y_ (shape=(10, 10), dtype=<class 'numpy.float32'>)}






[[[ 0. 10. 20. 30.]
  [40. 50. 60. 70.]]]
<NDArray 1x2x4 @cpu(0)>


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM