[開發技巧]·TensorFlow中numpy與tensor數據相互轉化
個人主頁–> https://xiaosongshine.github.io/
- 問題描述
在我們使用TensorFlow進行深度學習訓練時,很多時候都是與Numpy數據打招呼,例如我們csv或者照片數據等。
但是我們都知道,TensorFlow訓練時都是使用Tensor來存儲變量的,並且網絡輸出的結果也是Tensor。
一般情況下我們不會感受到Numpy與Tensor之間的區別,因為TensorFlow網絡在輸入Numpy數據時會自動轉換為Tensor來處理。
但是在輸出網絡時,輸出的結果仍為Tensor,當我們要用這些結果去執行只能由Numpy數據來執行的操作時就會出現莫名其妙的錯誤。
例如,當我想要用自編碼器與解碼器輸出的結果使用matplotlib顯示時就會報錯
TypeError: Image data cannot be converted to float
解決方法
有時候解決起來很簡單,就是錯誤比較難找到,所以我推薦的方法為將數據進行顯式的轉化。
- Numpy2Tensor
雖然TensorFlow網絡在輸入Numpy數據時會自動轉換為Tensor來處理,但是我們自己也可以去顯式的轉換:
data_tensor= tf.convert_to_tensor(data_numpy)
- Tensor2Numpy
網絡輸出的結果仍為Tensor,當我們要用這些結果去執行只能由Numpy數據來執行的操作時就會出現莫名其妙的錯誤。解決方法:
with tf.Session() as sess: data_numpy = data_tensor.eval()