使用tensorflow api生成one-hot標簽數據


轉自:http://www.terrylmay.com/2017/06/generate-one-hot-data/

使用tensorflow api生成one-hot標簽數據

在剛開始學習tensorflow的時候, 會有一個最簡單的手寫字符識別的程序供新手開始學習, 在tensorflow.example.tutorial.mnist中已經定義好了mnist的訓練數據以及測試數據. 並且標簽已經從原來的List變成了one-hot的二維矩陣的格式.看了源碼的就知道mnist.input_data.read_data()這個方法中使用的是numpy中的方法來實現標簽的one-hot矩陣化。那么如何使用tensorflow中自帶的api來實現呢?下面我們就來一起看一下需要用到的api吧。

tf.expand_dims 方法 這個函數主要給矩陣或者數組增加一維°, 看代碼可能更加清晰:

import tensorflow as tf
# 比如現在有一個列表
x_data = [1, 2, 3, 4]
x_data_expand = tf.expand_dims(x_data, 0) # x_data的shape是[4], 該函數表示在最前面的位置增加一維, 就會變成[1, 4]
# 而對於[1, 4] 的矩陣加上x_data本身的數據, 那么可以猜想到x_data_expand = [[1, 2, 3, 4]]

x_data_expand_axis1 = tf.expand_dims(x_data, axis=1) # x_data的shape是[4], 而axis=1表示在本來的矩陣的第1列加一維, 所以x_data_expand_axis1是[4, 1] 4行一列的矩陣, 並且把原始數據套進去可知: x_data_expand_axis1 = [[1], [2], [3], [4]], 但是這個axis的參數值不能大於矩陣的列數, 比如矩陣shape為[1, 2, 3] 那么axis=0 則會生成[1, 1, 2, 3], axis=1則會生成[1, 1, 2, 3], axis=2則會生成[1, 2, 1, 3], axis=3則會生成[1, 2, 3, 1]。就是在某一個位置插入一列

tf.concat(values, axis) 該函數用於將兩個相同維度的數據進行合並, 如果指定axis=0那么只需要列數相同即可.否則需要維度都相同 看如下代碼:


import
tensorflow as tf x_data = [[1, 2, 3], [4, 5, 6]] y_data = [[7, 8, 9], [10, 11, 12]] concat_result = tf.concat(values=[x_data, y_data], axis=0) # 這樣的話, 生成的數據是[[1, 2, 3], [7, 8, 9], [4, 5, 6], [10, 11, 12]] concat_result = tf.concat(values=[x_data, y_data], axis=1) # 這樣的話, 生成的數據是[[ 1 2 3 7 8 9], [ 4 5 6 10 11 12]], 三維的甚至更高維度的數據稍后再嘗試

tf.sparse_to_dense() def sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value=0, validate_indices=True, name=None): 該函數指定位置賦值, 並且生成一個維度為output_shape的矩陣;如果output_shape維度為1, 那么sparse_indices只能是一個列表, 如果output_shape為二維矩陣, 那么sparse_indices就可以是矩陣了.

比如如下代碼:

import tensorflow as tf

sparse_indices = [1, 2, 6]
output_shape = tf.zeros([10]).shape
sparse_output = tf.sparse_to_dense(sparse_indices, output_shape, 2, default_value=0) # 生成的結果為:sparse_output:[0 2 2 0 0 0 2 0 0 0] 就是在位置1, 2, 6的位置填充2 其余位置填充0

# 對於二維矩陣的填充也是一樣的, 比如:
sparse_indices = [[0, 1], [2, 4], [4 ,5], [6, 9]]
output_shape = tf.zeros([6, 10]).shape
sparse_output = tf.sparse_to_dense(sparse_indices, output_shape, 1, default_value=0) #生成的數據如下:# 生成的數據如下:sparse_output:
[[0 1 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 1 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 1 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 1]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]]

下面開始實現對原始標簽列表的one-hot化

  import tensorflow as tf

  labels  = [1, 3, 4, 8, 7, 5, 2, 9, 0, 8, 7]
  labels_expand = tf.expand_dims(labels, axis=1) # 這樣label_expand為[11, 1]的數據

  index_expand = tf.expand_dims(tf.range(len(labels)), axis=1) # 與label_expand中的元素一一對應

  concat_result = tf.concat(values=[index_expand, labels_expand], axis=1) # 將上述兩組數據組合在一起

  one_hot = tf.sparse_to_dense(sparse_indices=concat_result, output_shape=tf.zeros([len(labels), 10]).shape, sparse_values=1.0, default_value=0.0)

  session = tf.InteractiveSession()

  print('labels_expand:{}'.format(session.run(labels_expand)))
  print('index_expand:{}'.format(session.run(index_expand)))

  print('concat_result:{}'.format(session.run(concat_result)))
  print('one_hot_of_labels:{}'.format(session.run(one_hot)))

最后的結果如下打印: python labels_expand:[[1] [3] [4] [8] [7] [5] [2] [9] [0] [8] [7]] index_expand:[[ 0] [ 1] [ 2] [ 3] [ 4] [ 5] [ 6] [ 7] [ 8] [ 9] [10]] concat_result:[[ 0 1] [ 1 3] [ 2 4] [ 3 8] [ 4 7] [ 5 5] [ 6 2] [ 7 9] [ 8 0] [ 9 8] [10 7]] one_hot_of_labels:[[ 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.] [ 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.] [ 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.] [ 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.] [ 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.] [ 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.] [ 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.] [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.] [ 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.] [ 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.] [ 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]]

這樣就實現了labels的one-hot化。

使用Numpy來實現Label的one-hot化

    import numpy as np

    labels  = [1, 3, 4, 8, 7, 5, 2, 9, 0, 8, 7]
    one_hot_index = np.arange(len(labels)) * 10 + labels

    print ('one_hot_index:{}'.format(one_hot_index))

    one_hot = np.zeros((len(labels), 10))
    one_hot.flat[one_hot_index] = 1

    print('one_hot:{}'.format(one_hot))


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM