Note:后記
此權值共享非彼卷積共享。說的是layer實體間的參數共享。
Introduction
想將兩幅圖像”同時“經過同一模型,似乎之前有些聽聞的shared model沒有找到確鑿的痕跡,單個構建Variable然后每層設置,對debug階段(甚至使用階段)來說是場噩夢。能夠可行的只想到了,在set_params階段進行指定,如果簡單的將兩個load的symbol進行Group,然后進行bind會提示出現多個名稱。於是問題就是:如何生成同一結構內含指定符號名的symbol?
Exploration
此類非標准操作,更別指望mxnet的doc了,只有從dir()和src查起。
Change the name
首先想到的自然是改名:
本來是
a=mx.sym.Variable('x')
要改成與
a=mx.sym.Variable('y')
相同的效果。
關於名稱的接口:
import mxnet as mx
d=mx.sym.Variable('data')
conv1_w=mx.sym.Variable('kw')
conv1=mx.sym.Convolution(data=d,weight=conv1_w,kernel=(3,3),num_filter=num_filter,no_bias=True,name='conv1')
conv1.name
#'conv1'
How to change it
怎么改呢?看起來只有_set_attr靠譜些,先看看都有那些屬性:
conv1.list_attr()
#{'no_bias': 'True', 'kernel': '(3, 3)', 'num_filter': '1'}
。。。並沒有什么好結果出現,看起來還有一個接口:
conv1.attr_dict()
#{'conv1': {'no_bias': 'True', 'kernel': '(3, 3)', 'num_filter': '1'}}
那就試試,'conv1'?
>>>conv1._set_attr(conv1='yy')
>>>conv1.name # 有戲?!趕緊看看
'conv1' # 那剛才的是什么?
>>> conv1.list_attr()
{'no_bias': 'True', 'kernel': '(3, 3)', 'conv1': 'yy', 'num_filter': '1'} # 呵呵,被騙了...
Check the Src
來看看名字是到哪取的(~當然是家里取的...)
# python/mxnet/symbol.py
@property
def name(self):
ret = ctypes.c_char_p()
success = ctypes.c_int()
check_call(_LIB.MXSymbolGetName(
self.handle, ctypes.byref(ret), ctypes.byref(success)))
if success.value != 0:
return py_str(ret.value)
else:
return None
於是追尋MXSymbolGetName,雖然直覺告訴我很有可能不會有python接口了(很有可能是通過底層實現的名字獲取),但還是得看看。
//src/c_api/c_api_symbolic.cc
int MXSymbolGetName(SymbolHandle symbol,
const char** out,
int* success) {
return NNSymbolGetAttr(symbol, "name", out, success);
}
這不禁讓人浮想起來。。。趕緊試試:
>>> conv1._set_attr(name='yy')
>>> conv1.name
'yy'
被我發現了吧 😃
失敗
失敗的原因是,上述的操作只改變了node,但參數的名稱並沒有改變(可以.list_arguments()進行查看)。我當時想的是將參數名稱保持相同,然后在set_params的時候就可以直接調用,然而實際調用時,會報錯,提示檢測出了多個相同的名稱,所以此路基本封死。
從json入手
這是一個當時認為最慘的辦法——每次都要先對文件進行操作(非常粗野)。但今早發現symbol中還有操作json的接口(當然說的不是save,laod之類的):
sn_epoch_load=0
model_prefix='nin'
sym1, arg_params, aux_params = mx.mod.module.load_checkpoint(model_prefix, n_epoch_load)
sym=sym1.get_internals()['conv4_1024_output'].__copy__()
ss=sym.__getstate__()['handle']
ss1=ss.replace('\"name\": \"','\"name\": \"sha-')
sym2 = sym.__copy__()
h={'handle':ss1}
sym2.__setstate__(h)
>>> sym2.list_arguments()
['sha-data', 'sha-conv1_weight', 'sha-conv1_bias', 'sha-cccp1_weight', 'sha-cccp1_bias', 'sha-cccp2_weight', 'sha-cccp2_bias', 'sha-conv2_weight', 'sha-conv2_bias', 'sha-cccp3_weight', 'sha-cccp3_bias', 'sha-cccp4_weight', 'sha-cccp4_bias', 'sha-conv3_weight', 'sha-conv3_bias', 'sha-cccp5_weight', 'sha-cccp5_bias', 'sha-cccp6_weight', 'sha-cccp6_bias', 'sha-conv4_1024_weight', 'sha-conv4_1024_bias']
>>> sym2.attr_dict()
{'sha-cccp3': {'no_bias': 'False', 'kernel': '(1,1)', 'num_group': '1', 'dilate': '(1,1)', 'num_filter': '256', 'stride': '(1,1)', 'cudnn_off': 'False', 'pad': '(0,0)', 'workspace': '1024', 'cudnn_tune': 'off'}, 'sha-cccp2': {'no_bias': 'False', 'kernel': '(1,1)', 'num_group': '1', 'dilate': '(1,1)', 'num_filter': '96', 'stride': '(1,1)', 'cudnn_off': 'False', 'pad': '(0,0)', 'workspace': '1024', 'cudnn_tune': 'off'}, 'sha-drop': {'p': '0.5'}, 'sha-conv2': {'no_bias': 'False', 'kernel': '(5,5)', 'num_group': '1', 'dilate': '(1,1)', 'num_filter': '256', 'stri
# 示意一下就可
這樣看上去問題被解決了。
Solution
於是我們的答案就是:
import mxnet as mx
M,N=3,3
num_filter=1
kernel=mx.nd.array([ [1,2,3],[1,2,3],[1,2,3] ])
d=mx.sym.Variable('data')
conv1=mx.sym.Convolution(data=d,kernel=(3,3),num_filter=num_filter,no_bias=True,name='conv1')
loss=mx.sym.MakeLoss(data=conv1)
bch_kernel=kernel.reshape((1,1,M,N))
arg_params={'conv1_weight': bch_kernel}
def shareParams(sym,params):
sym1 = sym.__copy__()
new_params= {}
ss=sym1.__getstate__()['handle']
ss1=ss.replace('\"name\": \"','\"name\": \"sha-')
h={'handle':ss1}
sym1.__setstate__(h)
for i in params:
new_params['sha-'+i] = params[i]
new_params[i] = params[i]
return mx.sym.Group([sym,sym1]),new_params
sym,params = shareParams(loss,arg_params)
mod=mx.mod.Module(symbol=sym,data_names=('data','sha-data',))
mod.bind(data_shapes=[ ('data',[1,1,M,N]), ('sha-data',[1,1,M,N]),])
mod.init_params()
mod.set_params(arg_params=params, aux_params=[],allow_missing=True)
mod.init_optimizer()
mod.forward(mx.io.DataBatch([bch_kernel,bch_kernel],[]))
mod.get_outputs()[0].asnumpy()
#array([[[[ 42.]]]], dtype=float32)
mod.get_outputs()[1].asnumpy()
#array([[[[ 42.]]]], dtype=float32)
mod.backward()
mod.update()
mod.forward(mx.io.DataBatch([bch_kernel,bch_kernel],[]))
mod.get_outputs()[0].asnumpy()
#array([[[[ 41.57999802]]]], dtype=float32)
mod.get_outputs()[1].asnumpy()
#array([[[[ 41.57999802]]]], dtype=float32)
搞定 😃
22 Jul, 2017 記
關於這個問題,我后面還曾設想找段空閑時期,試着用mxnet內部機制進行封裝。最近發現,自己也是傻得可以。。。
兩張圖先進行batch
維的拼接,通過所需段后再拆分 (⊙﹏⊙)b