官方說明:
If both x and y are None, then this operation returns the coordinates of true elements of condition. The coordinates are returned in a 2-D tensor where the first dimension (rows) represents the number of true elements, and the second dimension (columns) represents the coordinates of the true elements. Keep in mind, the shape of the output tensor can vary depending on how many true values there are in input. Indices are output in row-major order.
If both non-None, condition, x and y must be broadcastable to the same shape.
The condition tensor acts as a mask that chooses, based on the value at each element, whether the corresponding element / row in the output should be taken from x (if true) or y (if false).
官方文檔很抽象,必須結合例子來理解。一共有兩種用法,分別是帶有x
和y
參數和不帶這兩個參數的用法。
用法1
a1=np.array([[1,0,0],[0,1,1]])
a2=np.array([[3,2,3],[4,5,6]])
tf.where(tf.equal(a1,1),a1,a2)
輸出的結果是
<tf.Tensor: id=13, shape=(2, 3), dtype=int64, numpy=
array([[1, 2, 3],
[4, 1, 1]])>
也就是,當condition
為真,也就是tf.equal(a1,1
,即a1
中的元素為1,返回的數組中所在位置元素來自a1
,否則來自b1
。輸出的數組中,原數組a1
不等於1的元素被替換成了對應位置b1
中的元素。
再來一個例子,
tf.where(tf.equal(a1,1),a1,100+a1)
輸出的結果是
<tf.Tensor: id=19, shape=(2, 3), dtype=int64, numpy=
array([[ 1, 100, 100],
[100, 1, 1]])>
數組a1
中不等於1的元素,其值加上100。
用法2
不帶x
和y
參數的時候,返回滿足condition的元素所在位置。需要關注的是返回值的形式。
tf.where(tf.equal(a,1))
輸出結果
<tf.Tensor: id=55, shape=(3, 2), dtype=int64, numpy=
array([[0, 0],
[1, 1],
[1, 2]])>
這是一個(3, 2)
數組,行數表示滿足條件的元素的數目。a1
中一共有3個元素為1,所有行數為3。每一列代表的是符合條件的元素的坐標,比如第一個元素[0,0]
,表示第一個滿足條件的元素的index是(0,0)。