在我們在MXnet中定義好symbol、寫好dataiter並且准備好data之后,就可以開開心的去訓練了。一般訓練一個網絡有兩種常用的策略,基於model的和基於module的。今天,我想談一談他們的使用。
一、Model
按照老規矩,直接從官方文檔里面拿出來的代碼看一下:
# configure a two layer neuralnetwork
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(fc1, name='relu1', act_type='relu')
fc2 = mx.symbol.FullyConnected(act1, name='fc2', num_hidden=64)
softmax = mx.symbol.SoftmaxOutput(fc2, name='sm')
# create a model using sklearn-style two-step way
#創建一個model
model = mx.model.FeedForward(
softmax,
num_epoch=num_epoch,
learning_rate=0.01)
#開始訓練
model.fit(X=data_set)
具體的API參照http://mxnet.io/api/python/model.html。
然后呢,model這部分就說完了。。。之所以這么快主要有兩個原因:
1.確實東西不多,一般都是查一查文檔就可以了。
2.model的可定制性不強,一般我們是很少使用的,常用的還是module。
二、Module
Module真的是一個很棒的東西,雖然深入了解后,你會覺得“哇,好厲害,但是感覺沒什么鳥用呢”這種想法。。實際上我就有過,現在回想起來,從代碼的設計和使用的角度來講,Module確實是一個非常好的東西,它可以為我們的網絡計算提高了中級、高級的接口,這樣一來,就可以有很多的個性化配置讓我們自己來做了。
Module有四種狀態:
1.初始化狀態,就是顯存還沒有被分配,基本上啥都沒做的狀態。
2.binded,在把data和label的shape傳到Bind函數里並且執行之后,顯存就分配好了,可以准備好計算能力。
3.參數初始化。就是初始化參數
3.Optimizer installed 。就是傳入SGD,Adam這種optimuzer中去進行訓練
先上一個簡單的代碼:
import mxnet as mx
# construct a simple MLP
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
fc2 = mx.symbol.FullyConnected(act1, name = 'fc2', num_hidden = 64)
act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu")
fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10)
out = mx.symbol.SoftmaxOutput(fc3, name = 'softmax')
# construct the module
mod = mx.mod.Module(out)
mod.bind(data_shapes=train_dataiter.provide_data,
label_shapes=train_dataiter.provide_label)
mod.init_params()
mod.fit(train_dataiter, eval_data=eval_dataiter,
optimizer_params={'learning_rate':0.01, 'momentum': 0.9},
num_epoch=n_epoch)
分析一下:首先是定義了一個簡單的MLP,symbol的名字就叫做out,然后可以直接用mx.mod.Module來創建一個mod。之后mod.bind的操作是在顯卡上分配所需的顯存,所以我們需要把data_shapehe label_shape傳遞給他,然后初始化網絡的參數,再然后就是mod.fit開始訓練了。這里補充一下。fit這個函數我們已經看見兩次了,實際上它是一個集成的功能,mod.fit()實際上它內部的核心代碼是這樣的:
for epoch in range(begin_epoch, num_epoch):
tic = time.time()
eval_metric.reset()
for nbatch, data_batch in enumerate(train_data):
if monitor is not None:
monitor.tic()
self.forward_backward(data_batch) #網絡進行一次前向傳播和后向傳播
self.update() #更新參數
self.update_metric(eval_metric, data_batch.label) #更新metric
if monitor is not None:
monitor.toc_print()
if batch_end_callback is not None:
batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
eval_metric=eval_metric,
locals=locals())
for callback in _as_list(batch_end_callback):
callback(batch_end_params)
正是因為module里面我們可以使用很多intermediate的interface,所以可以做出很多改進,舉個最簡單的例子:如果我們的訓練網絡是大小可變怎么辦? 我們可以實現一個mutumodule,基本上就是,每次data的shape變了的時候,我們就重新bind一下symbol,這樣訓練就可以照常進行了。
總結:實際上學一個框架的關鍵還是使用它,要說訣竅的話也就是多看看源碼和文檔了,我寫這些博客的目的,一是為了記錄一些東西,二是讓后來者少走一些彎路。所以有些東西不會說的很全。。
