tf.py_func函數總結


參考鏈接:https://blog.csdn.net/jiongnima/article/details/80555387

tf.py_func(

  func,

  inp,

  Tout,

  stateful=True,

  name=None

)

在使用tf.py_func的過程中,主要核心是使用前三個參數。

第一個參數func,是最重要的,是一個用戶自定義的函數,輸入numpy array, 輸出也是numpy array, 在該函數中,可以自由使用np.操作。

第二個參數inp,是func函數接收的輸入,是一個列表

第三個參數Tout,指定了func函數返回的numpy array轉化成tensor后的格式,如果是返回個值,就是一個列表元組;如果只有一個返回值,就是一個單獨的dtype類型(當然也可以用列表括起來)

最后來看看tf.py_func的輸出:

Returns:

A list of Tensor or a single Tensor which func computes.

輸出是一個tensor(張量)列表或單個tensor。

到這里,tf.py_func的原理也就逐漸明晰了。首先,tf.py_func接收的是tensor(張量),然后將其轉化為numpy array送入func(就是自定義的那條函數),最后再將func函數輸出的Numpy array轉化為tensor返回。

在使用過程中,有兩個需要注意的地方,第一就是func函數的返回值類型一定要和Tout指定的tensor類型一致。第二就是,如下圖所示,tf.py_func()中的func是脫離Graph的。在func中不能定義可訓練的參數參與網絡訓練(反傳)。

 

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM