tf.boolean_mask 的作用是 通過布爾值 過濾元素
def boolean_mask(tensor, mask, name="boolean_mask", axis=None): """Apply boolean mask to tensor.
tensor:被過濾的元素
mask:一堆 bool 值,它的維度不一定等於 tensor
return: mask 為 true 對應的 tensor 的元素
當 tensor 與 mask 維度一致時,return 一維
先看個 一維 例子
# 1-D example tensor = [0, 1, 2, 3] mask = np.array([True, False, True, False]) out = tf.boolean_mask(tensor, mask) print(sess.run(out)) # [0, 2] print(out.shape) # (?,)
再看看 mask 與 tensor 維度不同的例子
tensor = [[1, 2], [3, 4], [5, 6]] mask = np.array([True, False, True]) # mask 與 tensor 維度不同 out2 = tf.boolean_mask(tensor, mask) print(sess.run(out2)) # [[1, 2], [5, 6]] print(out2.shape) # (?, 2)
mask 可以用一個函數代替
# 3-D tensor = tf.constant([ [[2,4],[4,1]], [[6,8],[2,1]]],tf.float32) mask = tensor > 2 # 濾波器 mask 與 tensor 相同維度 out3 = tf.boolean_mask(tensor, mask) print(sess.run(tensor)) print(sess.run(mask)) # [[[False True] [ True False]] # [[ True True] [False False]]] print(sess.run(out3)) # [4. 4. 6. 8.] 輸出一維 print(out3.shape) # (?,)
shape
上面的 shape 是怎么回事呢?有如下規則
假設 tensor.rank=4(m,n,p,q),則
(1)當mask.shape=(m,n,p,q),結果返回(?,)
(2)當mask.shape=(m,n,p),結果返回(?,q),表示 q 維度沒有過濾
(3)當mask.shape=(m,n),結果返回(?,p,q)
(4)當mask.shape=(m),結果返回(?,n,p,q)
參考資料:
https://blog.csdn.net/qq_29444571/article/details/84574526
https://www.w3cschool.cn/doc_tensorflow_python/tensorflow_python-tf-boolean_mask.html