infer_shape for symbol
形狀推斷是mxnet的一特色,即使撇開這樣做的原因是mxnet強制要求的,其提供的功能也是很helpful的。
infer_shape通常是被封裝起來供其內部使用,但也可以把symbol.infer_shape單獨提出來,作為函數:
import mxnet as mx
d=mx.sym.Variable('data')
conv1_w=mx.sym.Variable('kw')
conv1=mx.sym.Convolution(data=d,weight=conv1_w,kernel=(3,3),num_filter=num_filter,no_bias=True,name='conv1')
loss=mx.sym.MakeLoss(data=conv1)
in_shape,out_shape,uax_shape=loss.infer_shape(data=(1,1,30,30),kw=(1,1,3,3)) # 直接寫參數名, 此處 kw 可省略
in_shape,out_shape,uax_shape
# ([(1L, 1L, 30L, 30L), (1L, 1L, 3L, 3L)], [(1L, 1L, 28L, 28L)], [])
for module
另外,上面用的是 symbol,有時需要從打包好的module里面提取symbol(mxnet的doc實在是...AWS摻和進來草根本性也不見提升啊):
import mxnet as mx
d=mx.sym.Variable('data')
conv1_w=mx.sym.Variable('kw')
conv1=mx.sym.Convolution(data=d,weight=conv1_w,kernel=(3,3),num_filter=num_filter,no_bias=True,name='conv1')
loss=mx.sym.MakeLoss(data=conv1)
mod = mx.mod.Moudle(symbol=loss)
get_conv1 = mod.symbol.get_internals()['conv1_output']
get_conv1
#<Symbol conv1>