『MXNet』第四彈_Gluon自定義層


一、不含參數層

通過繼承Block自定義了一個將輸入減掉均值的層:CenteredLayer類,並將層的計算放在forward函數里,

from mxnet import nd, gluon
from mxnet.gluon import nn

class CenteredLayer(nn.Block):
    def __init__(self, **kwargs):
        super(CenteredLayer, self).__init__(**kwargs)

    def forward(self, x):
        return x - x.mean()

# 直接使用這個層
layer = CenteredLayer()
# layer(nd.array([1, 2, 3, 4, 5]))

# 構建更復雜模型
net = nn.Sequential()
net.add(nn.Dense(128))
net.add(nn.Dense(10))
net.add(CenteredLayer())

# 初始化、運行……
net.initialize()
y = net(nd.random.uniform(shape=(4, 8)))

二、含參數層

注意,本節實現的自定義層不能自動推斷輸入尺寸,需要手動指定

見上節『MXNet』第三彈_Gluon模型參數在自定義層的時候我們常使用Block自帶的ParameterDict類添加成員變量params,如下,

from mxnet import gluon
from mxnet.gluon import nn

class MyDense(nn.Block):
    def __init__(self, units, in_units, **kwargs):
        super(MyDense, self).__init__(**kwargs)
        self.weight = self.params.get('weight', shape=(in_units, units))
        self.bias = self.params.get('bias', shape=(units,))        

    def forward(self, x):
        linear = nd.dot(x, self.weight.data()) + self.bias.data()
        return nd.relu(linear)

# 實際運行
dense = MyDense(5, in_units=10)

 如果不想使用ParameterDict類則需要一下操作

# self.weight = self.params.get('weight', shape=(in_units, units))
self.weight = gluon.Parameter('weight', shape=(in_units, units))
self.params.update({'weight':self.weight})

否則在net.initialize()初始化時是初始化不到ParameterDict外變量的。

 有關這一點詳見下面:

    def __init__(self, conv_arch, dropout_keep_prob, **kwargs):
        super(SSD, self).__init__(**kwargs)
        self.vgg_conv = nn.Sequential()
        self.vgg_conv.add(repeat(*conv_arch[0], pool=False))
        [self.vgg_conv.add(repeat(*conv_arch[i])) for i in range(1, len(conv_arch))]
        # 迭代器對象只能進行單次迭代,所以將之轉化為tuple,否則識別參數處迭代后forward再次迭代直接跳出循環
        # self.vgg_conv = tuple([repeat(*conv_arch[i])
        #                       for i in range(len(conv_arch))])
        # 只能識別實例屬性直接為mx層函數或者mx序列對象的參數,如果使用其他容器,需要將參數收集進參數字典
        # _ = [self.params.update(block.collect_params()) for block in self.vgg_conv]

    def forward(self, x, feat_layers):
        end_points = {'block0': x}
        for (index, block) in enumerate(self.vgg_conv):
            end_points.update({'block{:d}'.format(index+1): block(end_points['block{:d}'.format(index)])})
        return end_points

屬性對象是mxnet的對象時才能默認識別層中的參數,否則需要顯式收集進self.params中。

測試代碼:

if __name__ == '__main__':

    ssd = SSD(conv_arch=((2, 64), (2, 128), (3, 256), (3, 512), (3, 512)),
              dropout_keep_prob=0.5)
    ssd.initialize()
    X = mx.ndarray.random.uniform(shape=(1, 1, 304, 304))
    import pprint as pp
    pp.pprint([x[1].shape for x in ssd(X).items()])

自行驗證即可。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM