mxnet symbol reshape用法


mx.symbol.reshape

對於給定輸入的array和其shape,可以返回一個含有新shape的一個copy。shape是整形元組類型,可以包含可選的幾個負數。

一些維度的可選值有:{0, -1, -2, -3, -4}

1. 維度0的作用是復制輸入的該維度到對應輸出:

data=mx.sym.Variable('data')   # 輸入symbol
data=mx.sym.Reshape(data=data, shape=(4,0,2))    # reshape目標
print(data.infer_shape(data=(2,3,4))[1])    # 用輸入形狀推理輸出形狀,infer_shape用法見這里~

輸出: (4,3,2)

data=mx.sym.Variable('data') data=mx.sym.Reshape(data=data, shape=(2,0,0)) print(data.infer_shape(data=(2,3,4))[1])

輸出:(2,3,4)

 

2. 維度-1的作用是利用剩余的維度來推斷該維度,要保持所有維度尺寸一樣:

data=mx.sym.Variable('data') data=mx.sym.Reshape(data=data, shape=(6,1,-1)) print(data.infer_shape(data=(2,3,4))[1])    

輸出:(6,1,4)

data=mx.sym.Variable('data') data=mx.sym.Reshape(data=data, shape=(3,-1,8)) print(data.infer_shape(data=(2,3,4))[1])

輸出:(3,1,8)

data=mx.sym.Variable('data') data=mx.sym.Reshape(data=data, shape=(-1)) print(data.infer_shape(data=(2,3,4))[1])

輸出:(24,)

 

3. 維度-2的作用是拷貝全部或剩余的維度到輸出

data=mx.sym.Variable('data') data=mx.sym.Reshape(data=data, shape=(-2)) print(data.infer_shape(data=(2,3,4))[1])

輸出:(2,3,4)

data=mx.sym.Variable('data') data=mx.sym.Reshape(data=data, shape=(2,-2)) print(data.infer_shape(data=(2,3,4))[1])

輸出:(2,3,4)

data=mx.sym.Variable('data') data=mx.sym.Reshape(data=data, shape=(-2,1,1)) print(data.infer_shape(data=(2,3,4))[1])

輸出:(2,3,4,1,1)

 

4. 維度-3的作用是利用兩個連續維度之積作為對應輸出維度

data=mx.sym.Variable('data') data=mx.sym.Reshape(data=data, shape=(-3,4)) print(data.infer_shape(data=(2,3,4))[1])

輸出:(6,4)

data=mx.sym.Variable('data') data=mx.sym.Reshape(data=data, shape=(-3,-3)) print(data.infer_shape(data=(2,3,4,5))[1])

輸出:(6,20)

data=mx.sym.Variable('data') data=mx.sym.Reshape(data=data, shape=(0,-3)) print(data.infer_shape(data=(2,3,4))[1])

輸出:(2,12)

data=mx.sym.Variable('data') data=mx.sym.Reshape(data=data, shape=(-3,-2)) print(data.infer_shape(data=(2,3,4))[1])

輸出:(6,4)

 

5. 維度-4的作用是將輸入的一個維度划分成后續的兩個維度(可含-1)

data=mx.sym.Variable('data') data=mx.sym.Reshape(data=data, shape=(-4,1,2,-2)) print(data.infer_shape(data=(2,3,4))[1])

輸出:(1,2,3,4)

這個稍難理解解釋一下:輸入是(2,3,4),reshape的目標是(-4,1,2,-2),且-4后續的兩個維度為1和2,即希望將-4對應的維度(對應輸入的2)分解成(1,2)。 此時-2將剩余的(3,4)拉過來就變成了(1,2,3,4)。

data=mx.sym.Variable('data') data=mx.sym.Reshape(data=data, shape=(2,-4,-1,3,-2)) print(data.infer_shape(data=(2,3,4))[1])

輸出:(2,1,3,4)

也可以連續reshape:

 

data=mx.sym.Variable('data') data=mx.sym.Reshape(data=data, shape=(-1, -4, -1, 1, 3, 256, 256)) print(data.infer_shape(data=(16, 8, 3, 256, 256))[1]) data=mx.sym.Reshape(data=data, shape=(-3,-3,-2)) print(data.infer_shape(data=(16, 8, 3, 256, 256))[1])

輸出:16,8,1,3,256,256

輸出:128,3,256,256

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM