官網默認定義如下:
one_hot(indices, depth, on_value=None, off_value=None, axis=None, dtype=None, name=None)
該函數的功能主要是轉換成one_hot類型的張量輸出。
參數功能如下:
1)indices中的元素指示on_value的位置,不指示的地方都為off_value。indices可以是向量、矩陣。
2)depth表示輸出張量的尺寸,indices中元素默認不超過(depth-1),如果超過,輸出為[0,0,···,0]
3)on_value默認為1
4)off_value默認為0
5)dtype默認為tf.float32
下面用幾個例子說明一下:
1. indices是向量
1 import tensorflow as tf 2 3 indices = [0,2,3,5] 4 depth1 = 6 # indices沒有元素超過(depth-1) 5 depth2 = 4 # indices有元素超過(depth-1) 6 a = tf.one_hot(indices,depth1) 7 b = tf.one_hot(indices,depth2) 8 9 with tf.Session() as sess: 10 print('a = \n',sess.run(a)) 11 print('b = \n',sess.run(b))
運行結果:
# 輸入是一維的,則輸出是一個二維的
a = [[1. 0. 0. 0. 0. 0.] [0. 0. 1. 0. 0. 0.] [0. 0. 0. 1. 0. 0.] [0. 0. 0. 0. 0. 1.]] # shape=(4,6) b = [[1. 0. 0. 0.] [0. 0. 1. 0.] [0. 0. 0. 1.] [0. 0. 0. 0.]] # shape=(4,4)
2. indices是矩陣
1 import tensorflow as tf 2 3 indices = [[2,3],[1,4]] 4 depth1 = 9 # indices沒有元素超過(depth-1) 5 depth2 = 4 # indices有元素超過(depth-1) 6 a = tf.one_hot(indices,depth1) 7 b = tf.one_hot(indices,depth2) 8 9 with tf.Session() as sess: 10 print('a = \n',sess.run(a)) 11 print('b = \n',sess.run(b))
運行結果:
# 輸入是二維的,則輸出是三維的
a = [[[0. 0. 1. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 1. 0. 0. 0. 0. 0.]] [[0. 1. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 1. 0. 0. 0. 0.]]] # shape=(2,2,9) b = [[[0. 0. 1. 0.] [0. 0. 0. 1.]] [[0. 1. 0. 0.] [0. 0. 0. 0.]]] # shape=(2,2,4)