Tensor的合並與分割


先來看一下有哪些接口用來進行張量的合並與分割:

tf.concat用來進行張量的拼接,tf.stack用來進行張量的堆疊,tf.split用來進行張量的分割,tf.unstack是tf.split的一種,也用來進行張量分割

1.tf.concat

參數axis代表將要合並的維度

# 假設a代表四個班的成績(每班35人,8個科目),b代表2個班的成績
a = tf.ones([4,35,8])
b = tf.ones([2,35,8])
# 使用concat進行合並得到6個班的成績
c = tf.concat([a,b],axis=0)
# (6,35,8)
print(c.shape)

2.tf.stack(用於創建一個新的維度)

# 假設a代表A學校的四個班的成績(每班35人,8個科目),b代表B學校四個班的成績
a = tf.ones([4,35,8])
b = tf.ones([4,35,8])
# 使用stack進行合並得到6個班的成績
c = tf.stack([a,b],axis=0)
# (2,4,35,8)
print(c.shape)

3.tf.unstack(對某維度進行等分)

# 假設a代表A學校的四個班的成績(每班35人,8個科目),b代表B學校四個班的成績
a = tf.ones([4,35,8])
b = tf.ones([4,35,8])
# 使用stack進行合並得到6個班的成績
c = tf.stack([a,b],axis=0)
# (2,4,35,8)
print(c.shape)
aa,bb=tf.unstack(c,axis=0)
# (4,35,8)
print(aa.shape,bb.shape)
res=tf.unstack(c,axis=3)
# (2,4,35)
print(res[0].shape,res[7].shape)

4.tf.split(按比例打散)

# 假設a代表A學校的四個班的成績(每班35人,8個科目),b代表B學校四個班的成績
a = tf.ones([4,35,8])
b = tf.ones([4,35,8])
# 使用stack進行合並得到6個班的成績
c = tf.stack([a,b],axis=0)
# (2,4,35,8)
print(c.shape)
res = tf.split(c,axis=3,num_or_size_splits=2)
# 2,(2,4,35,4)
print(len(res),res[0].shape,res[1].shape)
res = tf.split(c,axis=3,num_or_size_splits=[2,2,4])
# 3 (2,4,35,2) (2,4,35,2) (2,4,35,4)
print(len(res),res[0].shape,res[1].shape,res[2].shape)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM