mxnet下如何查看中間結果


https://blog.csdn.net/disen10/article/details/79376631

固定權重:https://www.cnblogs.com/chenyliang/p/6780019.html

固定權重:https://discuss.gluon.ai/t/topic/1164

查看權重

在訓練過程中,有時候我們為了debug而需要查看中間某一步的權重信息,在mxnet中,我們可以很方便的調用get_params()方法來得到權重信息。

  1.  
    '''
  2.  
    查看權重示例代碼
  3.  
    轉載時注明地址:http://blog.csdn.net/u010414386?viewmode=contents
  4.  
    '''
  5.  
    import mxnet as mx
  6.  
    sym, arg_params, aux_params = mx.model.load_checkpoint( 'resnet-50',0)#載入模型
  7.  
    mod = mx.mod.Module(symbol=sym,context=mx.gpu()) #創建Module
  8.  
    mod.bind(for_training= False,data_shapes=[('data',(1,3,224,224))]) #綁定,此代碼為預測代碼,所以training參數設為False
  9.  
    mod.set_params(arg_params,aux_params)
  10.  
    import numpy as np
  11.  
    import cv2
  12.  
    def get_image(filename):
  13.  
    img = cv2.imread(filename)
  14.  
    img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
  15.  
    img = cv2.resize(img,( 224,224))
  16.  
    img = np.swapaxes(img, 0,2)
  17.  
    img = np.swapaxes(img, 1,2)
  18.  
    img = img[np.newaxis,:]
  19.  
    return img
  20.  
    from collections import namedtuple
  21.  
    Batch = namedtuple( 'Batch',['data'])
  22.  
    img = get_image( 'val_1000/0.jpg') #獲取圖片
  23.  
    mod.forward(Batch([mx.nd.array(img)])) #預測結果
  24.  
    ################################################
  25.  
    #debug模式下,獲取權重信息
  26.  
    keys = mod.get_params()[ 0].keys() # 列出所有權重名稱
  27.  
    conv_w = mod.get_params()[ 0]['conv0_weight'] #獲取想要查看的權重信息,如conv_weight
  28.  
    print conv_w.asnumpy() #查看具體數值
  29.  
    ################################################
  30.  
    prob = mod.get_outputs()[ 0].asnumpy()
  31.  
    y = np.argsort(np.squeeze(prob))[:: -1]
  32.  
    print( 'truth label %d; top-1 predict label %d' % (val_label[0], y[0]))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33

查看中間輸出結果

由於mxnet的網絡由symbol組成,而symbol又屬於符號式編程,所以我們不能像上面查看權重一樣直接查看,我們需要把我們想看的輸出結果保存下來。

  1.  
    '''
  2.  
    方法一
  3.  
    查看中間結果代碼
  4.  
    轉載時注明地址:http://blog.csdn.net/u010414386?viewmode=contents
  5.  
    '''
  6.  
    import mxnet as mx
  7.  
    net = mx.symbol.Variable( 'data')
  8.  
    fc1 = mx.symbol.FullyConnected(data=net, name= 'fc1', num_hidden=128)
  9.  
    net = mx.symbol.Activation(data=fc1, name= 'relu1', act_type="relu")
  10.  
    net = mx.symbol.FullyConnected(data=net, name= 'fc2', num_hidden=64)
  11.  
    out = mx.symbol.SoftmaxOutput(data=net, name= 'softmax')
  12.  
    # 通過把兩個輸出組成一個group來得到自己需要查看的中間層輸出結果
  13.  
    group = mx.symbol.Group([fc1, out])
  14.  
    print group.list_outputs()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  1.  
    '''
  2.  
    方法二
  3.  
    有時候我們使用別人的模型,所以無法像方法一一樣在定義模型的時候就確定需要查看的中間層輸出結果,
  4.  
    這時候我們使用get_internals()方法來查找自己需要查看的中間層
  5.  
    轉載時注明地址:http://blog.csdn.net/u010414386?viewmode=contents
  6.  
    '''
  7.  
    import mxnet as mx
  8.  
    sym, arg_params, aux_params = mx.model.load_checkpoint( 'resnet-50',0)#載入模型
  9.  
    ########################################################################
  10.  
    args = sym.get_internals().list_outputs() #獲得所有中間輸出
  11.  
    internals = model.symbol.get_internals()
  12.  
    fc1 = internals[ 'fc1_output']
  13.  
    conv = internals[ 'stage4_unit3_conv1_output']
  14.  
    group = mx.symbol.Group([fc1, sym, conv]) #把需要輸出的結果按group方式組合起來,這樣就可以得到中間層的輸出
  15.  
    #########################################################################
  16.  
    mod = mx.mod.Module(symbol=group,context=mx.gpu()) #創建Module
  17.  
    mod.bind(for_training= False,data_shapes=[('data',(1,3,224,224))]) #綁定,此代碼為預測代碼,所以training參數設為False
  18.  
    mod.set_params(arg_params,aux_params)
  19.  
    import numpy as np
  20.  
    import cv2
  21.  
    def get_image(filename):
  22.  
    img = cv2.imread(filename)
  23.  
    img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
  24.  
    img = cv2.resize(img,( 224,224))
  25.  
    img = np.swapaxes(img, 0,2)
  26.  
    img = np.swapaxes(img, 1,2)
  27.  
    img = img[np.newaxis,:]
  28.  
    return img
  29.  
    from collections import namedtuple
  30.  
    Batch = namedtuple( 'Batch',['data'])
  31.  
    img = get_image( 'val_1000/0.jpg') #獲取圖片
  32.  
    mod.forward(Batch([mx.nd.array(img)])) #預測結果
  33.  
    prob = mod.get_outputs()[ 0].asnumpy()
  34.  
    y = np.argsort(np.squeeze(prob))[:: -1]
  35.  
    print( 'truth label %d; top-1 predict label %d' % (val_label[0], y[0]))


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM