Tensorflow 學習筆記 -----tf.where


TensorFlow函數:tf.where

在之前版本對應函數tf.select

官方解釋:

 

 1 tf.where(input, name=None)`
 2 Returns locations of true values in a boolean tensor.
 3 
 4 This operation returns the coordinates of true elements in input. The coordinates are returned in a 2-D tensor where the first dimension (rows) represents the number of true elements, and the second dimension (columns) represents the coordinates of the true elements. Keep in mind, the shape of the output tensor can vary depending on how many true values there are in input. Indices are output in row-major order.
 5 
 6 For example:
 7 # 'input' tensor is [[True, False]
 8 #                    [True, False]]
 9 # 'input' has two true values, so output has two coordinates.
10 # 'input' has rank of 2, so coordinates have two indices.
11 where(input) ==> [[0, 0],
12                   [1, 0]]
13 
14 # `input` tensor is [[[True, False]
15 #                     [True, False]]
16 #                    [[False, True]
17 #                     [False, True]]
18 #                    [[False, False]
19 #                     [False, True]]]
20 # 'input' has 5 true values, so output has 5 coordinates.
21 # 'input' has rank of 3, so coordinates have three indices.
22 where(input) ==> [[0, 0, 0],
23                   [0, 1, 0],
24                   [1, 0, 1],
25                   [1, 1, 1],
26                   [2, 1, 1]]

 

有兩種用法:

1、tf.where(tensor)

tensor 為一個bool 型張量,where函數將返回其中為true的元素的索引。如上圖官方注釋

2、tf.where(tensor,a,b)

a,b為和tensor相同維度的tensor,將tensor中的true位置元素替換為a中對應位置元素,false的替換為b中對應位置元素。

例:

import tensorflow as tf
import numpy as np
sess=tf.Session()

a=np.array([[1,0,0],[0,1,1]])
a1=np.array([[3,2,3],[4,5,6]])

print(sess.run(tf.equal(a,1)))
print(sess.run(tf.where(tf.equal(a,1),a1,1-a1)))

 

>>[[true,false,false],[false,true,true]]

>>[[3,-1,-2],[-3,5,6]]

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM