tf.concat, tf.stack和tf.unstack的用法


tf.concat, tf.stack和tf.unstack的用法

tf.concat相當於numpy中的np.concatenate函數,用於將兩個張量在某一個維度(axis)合並起來,例如:

a = tf.constant([[1,2,3],[3,4,5]]) # shape (2,3) b = tf.constant([[7,8,9],[10,11,12]]) # shape (2,3) ab1 = tf.concat([a,b], axis=0) # shape(4,3) ab2 = tf.concat([a,b], axis=1) # shape(2,6)
  • 1
  • 2
  • 3
  • 4

tf.stack其作用類似於tf.concat,都是拼接兩個張量,而不同之處在於,tf.concat拼接的是兩個shape完全相同的張量,並且產生的張量的階數不會發生變化,而tf.stack則會在新的張量階上拼接,產生的張量的階數將會增加,例如:

a = tf.constant([[1,2,3],[3,4,5]]) # shape (2,3) b = tf.constant([[7,8,9],[10,11,12]]) # shape (2,3) ab = tf.stack([a,b], axis=0) # shape (2,2,3)
  • 1
  • 2
  • 3

改變參數axis為2,有:

import tensorflow as tf a = tf.constant([[1,2,3],[3,4,5]]) # shape (2,3) b = tf.constant([[7,8,9],[10,11,12]]) # shape (2,3) ab = tf.stack([a,b], axis=2) # shape (2,3,2)

 

 

所以axis是決定其層疊(stack)張量的維度方向的。

tf.unstacktf.stack的操作相反,是將一個高階數的張量在某個axis上分解為低階數的張量,例如:

a = tf.constant([[1,2,3],[3,4,5]]) # shape (2,3) b = tf.constant([[7,8,9],[10,11,12]]) # shape (2,3) ab = tf.stack([a,b], axis=0) # shape (2,2,3) a1 = tf.unstack(ab, axis=0)

 

其a1的輸出為

[<tf.Tensor 'unstack_1:0' shape=(2, 3) dtype=int32>, <tf.Tensor 'unstack_1:1' shape=(2, 3) dtype=int32>]


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM