TensorFlow的tf.where函數詳解與例子


官方說明:
If both x and y are None, then this operation returns the coordinates of true elements of condition. The coordinates are returned in a 2-D tensor where the first dimension (rows) represents the number of true elements, and the second dimension (columns) represents the coordinates of the true elements. Keep in mind, the shape of the output tensor can vary depending on how many true values there are in input. Indices are output in row-major order.

If both non-None, condition, x and y must be broadcastable to the same shape.

The condition tensor acts as a mask that chooses, based on the value at each element, whether the corresponding element / row in the output should be taken from x (if true) or y (if false).

官方文檔很抽象,必須結合例子來理解。一共有兩種用法,分別是帶有xy參數和不帶這兩個參數的用法。

用法1

a1=np.array([[1,0,0],[0,1,1]]) 
a2=np.array([[3,2,3],[4,5,6]])
tf.where(tf.equal(a1,1),a1,a2)

輸出的結果是

<tf.Tensor: id=13, shape=(2, 3), dtype=int64, numpy=
array([[1, 2, 3],
       [4, 1, 1]])>

也就是,當condition為真,也就是tf.equal(a1,1,即a1中的元素為1,返回的數組中所在位置元素來自a1,否則來自b1。輸出的數組中,原數組a1不等於1的元素被替換成了對應位置b1中的元素。
再來一個例子,

tf.where(tf.equal(a1,1),a1,100+a1)

輸出的結果是

<tf.Tensor: id=19, shape=(2, 3), dtype=int64, numpy=
array([[  1, 100, 100],
       [100,   1,   1]])>

數組a1中不等於1的元素,其值加上100。

用法2

不帶xy參數的時候,返回滿足condition的元素所在位置。需要關注的是返回值的形式。

tf.where(tf.equal(a,1)) 

輸出結果

<tf.Tensor: id=55, shape=(3, 2), dtype=int64, numpy=
array([[0, 0],
       [1, 1],
       [1, 2]])>

這是一個(3, 2)數組,行數表示滿足條件的元素的數目a1中一共有3個元素為1,所有行數為3。每一列代表的是符合條件的元素的坐標,比如第一個元素[0,0],表示第一個滿足條件的元素的index是(0,0)。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM