先來看一下有哪些接口用來進行張量的合並與分割:
tf.concat用來進行張量的拼接,tf.stack用來進行張量的堆疊,tf.split用來進行張量的分割,tf.unstack是tf.split的一種,也用來進行張量分割
1.tf.concat
參數axis代表將要合並的維度
# 假設a代表四個班的成績(每班35人,8個科目),b代表2個班的成績 a = tf.ones([4,35,8]) b = tf.ones([2,35,8]) # 使用concat進行合並得到6個班的成績 c = tf.concat([a,b],axis=0) # (6,35,8) print(c.shape)
2.tf.stack(用於創建一個新的維度)
# 假設a代表A學校的四個班的成績(每班35人,8個科目),b代表B學校四個班的成績 a = tf.ones([4,35,8]) b = tf.ones([4,35,8]) # 使用stack進行合並得到6個班的成績 c = tf.stack([a,b],axis=0) # (2,4,35,8) print(c.shape)
3.tf.unstack(對某維度進行等分)
# 假設a代表A學校的四個班的成績(每班35人,8個科目),b代表B學校四個班的成績 a = tf.ones([4,35,8]) b = tf.ones([4,35,8]) # 使用stack進行合並得到6個班的成績 c = tf.stack([a,b],axis=0) # (2,4,35,8) print(c.shape) aa,bb=tf.unstack(c,axis=0) # (4,35,8) print(aa.shape,bb.shape) res=tf.unstack(c,axis=3) # (2,4,35) print(res[0].shape,res[7].shape)
4.tf.split(按比例打散)
# 假設a代表A學校的四個班的成績(每班35人,8個科目),b代表B學校四個班的成績 a = tf.ones([4,35,8]) b = tf.ones([4,35,8]) # 使用stack進行合並得到6個班的成績 c = tf.stack([a,b],axis=0) # (2,4,35,8) print(c.shape) res = tf.split(c,axis=3,num_or_size_splits=2) # 2,(2,4,35,4) print(len(res),res[0].shape,res[1].shape) res = tf.split(c,axis=3,num_or_size_splits=[2,2,4]) # 3 (2,4,35,2) (2,4,35,2) (2,4,35,4) print(len(res),res[0].shape,res[1].shape,res[2].shape)