關於類型為numpy,TensorFlow.tensor,torch.tensor的shape變化以及相互轉化


https://blog.csdn.net/zz2230633069/article/details/82669546

1.numpy類型:numpy.ndarray  對於圖片讀取之后(H,W,C)或者(batch,H,W,C)

(1)在元素總數不變的情況下:numpy類型的可以直接使用方法numpy.reshape任意改變大小,numpy.expand_dims增加維度,大小是1(這個函數可以參考numpy.expand_dims的用法

(2)元素總數可以變化:scipy.misc.imresize(a,size)

2.TensorFlow的類型:tensorflow.python.framework.ops.tensor  圖片的計算格式(H,W,C)或者(batch,H,W,C)

(1)在元素總數不變的情況下:numpy可以直接作為Tensor的輸入,一旦被放在tf的函數下則失去了numpy的使用方法。tf.expand_dims在指定維度增加1維,大小為1;tf.squeeze剛好相反,刪掉維度為1的軸(這兩個函數可以參考tf.expand_dims和tf.squeeze函數);

(2)元素總數可以變化:

  1. '''
  2. tf和numpy之間的轉化
  3. '''
  4. import tensorflow as tf
  5.  
  6. a= tf.zeros(( 3,2))
  7. sess=tf.Session()
  8. sess.run(tf.global_variables_initializer())
  9.  
  10. print( "type(a)=",type(a)) # type(a)= <class 'tensorflow.python.framework.ops.Tensor'>
  11.  
  12. #轉化為numpy數組
  13. a_np=a.eval(session=sess)
  14. print( "type(a_np)=",type(a_np)) # type(a_np)= <class 'numpy.ndarray'>
  15. #轉化為tensor
  16. a2= tf.convert_to_tensor(a_np)
  17. print( "type(a2)=",type(a2)) # type(a2)= <class 'tensorflow.python.framework.ops.Tensor'>
  18.  
  19.  

3.torch類型:torch.tensor  圖片的計算格式是(C,H,W)或者(batch,C,H,W)

numpy類型不能直接作為Tensor的輸入,所以在運用torch之前一定要進行轉化。

  1. from PIL import Image
  2. import torch
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. a=Image.open( '/home/zzp/um_lane_000000.png') # 加載圖片數據,返回的是一個PIL類型
  6. b=np.array(a).astype(np.float32) # 先將PIL類型轉化成numpy類型,並且把數據變成浮點數
  7. c=b.transpose(( 2,0,1)) # 調整成torch的通道
  8. d=torch.from_numpy(c).float() # 再將numpy類型轉化成torch.tensor類型
  9.  
  10. # 或者另外一種加載圖片的方式
  11. import scipy.misc
  12. import torch
  13. import numpy as np
  14. a=scipy.misc.imread( '/home/zzp/um_lane_000000.png') # 加載圖片數據,返回的是一個numpy類型
  15. c=a.transpose(( 2,0,1)).astype(np.float32) # 直接調整成torch的通道,不需要轉化成numpy類型了,還是要變為浮點數
  16. d=torch.from_numpy(c).float() # 再將numpy類型轉化成torch.tensor類型
  17.  
  18. # 三種加載圖像的方法
  19. a=Image.open( '/home/zzp/um_lane_000000.png')
  20. b=scipy.misc.imread( '/home/zzp/um_lane_000000.png')
  21. c=plt.imread( '/home/zzp/um_lane_000000.png')
  22. #顯示

(1)在元素總數不變的情況下

 

(2)元素總數可以變化


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM