定義tensorflow的輸入節點


定義tensorflow的輸入節點:

tensorflow的輸入節點定義方式基本上有三種,分別是:通過占位符定義、通過字典類型定義、直接定義。其中最常用的就是通過占位符定義、通過字典類型定義。這兩種的區別在於當輸入比較多的時候一般使用字典類型定義。下面通過代碼來進行詳細的解釋:

通過占位符來進行定義

1 X = tf.placeholder("float") # 代表x的輸入值
2 Y = tf.placeholder("float") #代表對應的真實值y

關於tf.placeholder():(參考自https://blog.csdn.net/kdongyi/article/details/82343712)

tf.placeholder(
    dtype,
    shape=None,     name=None )

參數:

  1. dtype:數據類型。常用的是tf.float32,tf.float64等數值類型
  2. shape:數據形狀。默認是None,就是一維值,也可以是多維(比如[2,3], [None, 3]表示列是3,行不定)
  3. name:名稱

為什么要用placeholder?
       Tensorflow的設計理念稱之為計算流圖,在編寫程序時,首先構築整個系統的graph,代碼並不會直接生效,這一點和python的其他數值計算庫(如Numpy等)不同,graph為靜態的,類似於docker中的鏡像。然后,在實際的運行時,啟動一個session,程序才會真正的運行。這樣做的好處就是:避免反復地切換底層程序實際運行的上下文,tensorflow幫你優化整個系統的代碼。我們知道,很多python程序的底層為C語言或者其他語言,執行一行腳本,就要切換一次,是有成本的,tensorflow通過計算流圖的方式,幫你優化整個session需要執行的代碼,還是很有優勢的。

       所以placeholder()函數是在神經網絡構建graph的時候在模型中的占位,此時並沒有把要輸入的數據傳入模型,它只會分配必要的內存。等建立session,在會話中,運行模型的時候通過feed_dict()函數向占位符喂入數據。

代碼示例:

import tensorflow as tf
import numpy as np input1 = tf.placeholder(tf.float32) #輸入值 input2 = tf.placeholder(tf.float32) #輸入值 output = tf.multiply(input1, input2) #輸出是兩個值的乘積 with tf.Session() as sess: print sess.run(output, feed_dict = {input1:[3.], input2: [4.]})

 通過字典類型來定義:

1 # 占位符
2 inputdict = {
3          'x': tf.placeholder("float")
4          'y': tf.placeholder("float")
5 }

直接定義輸入節點(這種方法不常用):

# 模型參數
W = tf.Variable(tf.random_normal([1]), name="weight") b = tf.Variable(tf.zeros([1]), name="bias") # 前向結構,在乘積函數中,train_X就是直接輸入的數據 z = tf.multiply(W, train_X)+ b

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM