MXNet——symbol


參考資料:有基礎(Pytorch/TensorFlow基礎)mxnet+gluon快速入門

symbol

symbol 是一個重要的概念,可以理解為符號,就像我們平時使用的代數符號 xyz 一樣。一個簡單的類比,一個函數 \(f(x) = x^{2}\),符號 x 就是 symbol,而具體 x 的值就是 ndarray,關於 symbol 的是 mxnet.sym,具體可參照官方API文檔

基本操作

  • 使用 mxnet.sym.Variable() 傳入名稱可建立一個 symbol
  • 使用 mxnet.viz.plot_network(symbol=) 傳入 symbol 可以繪制運算圖
import os
os.environ["PATH"] += os.pathsep + 'C:/Program Files (x86)/Graphviz/bin/'  # 解決 path 錯誤
import mxnet as mx

a = mx.sym.Variable('a')
b = mx.sym.Variable('b')
c = mx.sym.add_n(a,b,name="c")
mx.viz.plot_network(symbol=c)

output_2_0.svg-2.2kB

帶入 ndarray

使用 mxnet.sym.bind() 方法可以獲得一個帶入操作數的對象,再使用 forward() 方法可運算出數值

x = c.bind(ctx=mx.cpu(),args={"a": mx.nd.ones(5),"b":mx.nd.ones(5)})
result = x.forward()
print(result)
[
[2. 2. 2. 2. 2.]
<NDArray 5 @cpu(0)>]

mxnet 的數據載入

深度學習中數據的載入方式非常重要,mxnet 提供了 mxnet.io 的一系列 dataiter 用於處理數據載入,詳細可參照官方API文檔。同時,動態圖接口gluon 也提供了 mxnet.gluon.data 系列的 dataiter 用於數據載入,詳細可參照官方API文檔

mxnet.io 數據載入

mxnet.io的數據載入核心是 mxnet.io.DataIter 類及其派生類,例如 ndarray 的 iter:NDArrayIter

  • 參數 data:傳入一個(名稱-數據)的數據 dict
  • 參數 label:傳入一個(名稱-標簽)的標簽 dict
  • 參數 batch_size:傳入 batch 大小
dataset = mx.io.NDArrayIter(data={'data':mx.nd.ones((10,5))},label={'label':mx.nd.arange(10)},batch_size=5) 
for i in dataset: 
    print(i) 
    print(i.data,type(i.data[0])) 
    print(i.label,type(i.label[0])) 
DataBatch: data shapes: [(5, 5)] label shapes: [(5,)]
[
[[1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]]
<NDArray 5x5 @cpu(0)>] <class 'mxnet.ndarray.ndarray.NDArray'>
[
[0. 1. 2. 3. 4.]
<NDArray 5 @cpu(0)>] <class 'mxnet.ndarray.ndarray.NDArray'>
DataBatch: data shapes: [(5, 5)] label shapes: [(5,)]
[
[[1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]]
<NDArray 5x5 @cpu(0)>] <class 'mxnet.ndarray.ndarray.NDArray'>
[
[5. 6. 7. 8. 9.]
<NDArray 5 @cpu(0)>] <class 'mxnet.ndarray.ndarray.NDArray'>

gluon.data 數據載入

gluon 的數據 API 幾乎與 pytorch 相同,均是 Dataset+DataLoader 的方式:

  • Dataset:存儲數據,使用時需要繼承該基類並重載 __len__(self)__getitem__(self,idx) 方法
  • DataLoader:將 Dataset 變成能產生 batch 的可迭代對象
dataset = mx.gluon.data.ArrayDataset(mx.nd.ones((10,5)),mx.nd.arange(10)) 
loader = mx.gluon.data.DataLoader(dataset,batch_size=5) 
for i,data in enumerate(loader): 
    print(i) 
    print(data) 
0
[
[[1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]]
<NDArray 5x5 @cpu(0)>, 
[0. 1. 2. 3. 4.]
<NDArray 5 @cpu(0)>]
1
[
[[1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]]
<NDArray 5x5 @cpu(0)>, 
[5. 6. 7. 8. 9.]
<NDArray 5 @cpu(0)>]
class TestSet(mx.gluon.data.Dataset):
    def __init__(self): 
        self.x = mx.nd.zeros((10,5)) 
        self.y = mx.nd.arange(10) 
        
    def __getitem__(self,i): 
        return self.x[i],self.y[i] 
    
    def __len__(self): 
        return 10 
    
    
for i,data in enumerate(mx.gluon.data.DataLoader(TestSet(),batch_size=5)): 
    print(data) 
[
[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]
<NDArray 5x5 @cpu(0)>, 
[[0.]
 [1.]
 [2.]
 [3.]
 [4.]]
<NDArray 5x1 @cpu(0)>]
[
[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]
<NDArray 5x5 @cpu(0)>, 
[[5.]
 [6.]
 [7.]
 [8.]
 [9.]]
<NDArray 5x1 @cpu(0)>]

網絡搭建

mxnet 網絡搭建

mxnet 網絡搭建類似於 TensorFlow,使用 symbol 搭建出網絡,再用一個 module 封裝

data = mx.sym.Variable('data') 
# layer1 
conv1 = mx.sym.Convolution(data=data, kernel=(5,5), num_filter=32,name="conv1")
relu1 = mx.sym.Activation(data=conv1,act_type="relu",name="relu1") 
pool1 = mx.sym.Pooling(data=relu1,pool_type="max",kernel=(2,2),stride=(2,2),name="pool1") 
# layer2 
conv2 = mx.sym.Convolution(data=pool1, kernel=(3,3), num_filter=64,name="conv2") 
relu2 = mx.sym.Activation(data=conv2,act_type="relu",name="relu2") 
pool2 = mx.sym.Pooling(data=relu2,pool_type="max",kernel=(2,2),stride=(2,2),name="pool2") 
# layer3 
fc1 = mx.symbol.FullyConnected(data=mx.sym.flatten(pool2), num_hidden=256,name="fc1") 
relu3 = mx.sym.Activation(data=fc1, act_type="relu",name="relu3") 
# layer4 
fc2 = mx.symbol.FullyConnected(data=relu3, num_hidden=10,name="fc2")
out = mx.sym.SoftmaxOutput(data=fc2, label=mx.sym.Variable("label"),name='softmax') 
mxnet_model = mx.mod.Module(symbol=out,label_names=["label"],context=mx.gpu()) 
mx.viz.plot_network(symbol=out) 

output_11_0.svg-10kB

福利:剛剛發現一個解決路徑錯誤的方法:只需要將 *\Anaconda3\Library\bin\graphviz 添加到 Path 環境變量之下即可 (安裝后記得重啟,環境變量修改才可以生效,調用庫,即可成功)!


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM