最近遇到的一個pytorch報錯:
然后報錯了,這個幾行代碼就是從一個圖片中讀入數據,把bgr模式圖片矩陣轉換為rgb模式,這里采用的是改變矩陣索引,索引倒排 [..., ::-1] 。
看了這個報錯有些懵,因為確實沒想明白這么簡單的操作都會報錯。
后來查了查有些搞明白了,就是pytorch框架通過numpy的array對象生成tensor時要求傳入的numpy的array對象是內存連續的。
如上述:
img = cv2.imread(r"timg.jpg") 得到一個內存連續的array 對象,但是
img = cv2.imread(r"timg.jpg")[...,::-1]
這時將得到的連續的array對象的最后一維索引倒排了,
這樣的話倒排索引得到的array對象如果按照倒排后的索引來看就是內存不連續的,
這樣傳入pytorch中就會報錯了。
這里的解決方法就是傳入一個內存連續的array對象。
方法1.
import numpy as np import torch # 對彩色圖片RGB 進行像素點的kmeans聚類 import cv2 img = cv2.imread(r"timg.jpg")[..., ::-1] img2 = np.ascontiguousarray(img) data = torch.from_numpy(img2).float() img[...]=0 print(img) print(img2)
可以看到 調用
np.ascontiguousarray
我們得到的一個新的內存連續的array, 新舊array不共享內存。
方法2. 同理
import numpy as np import torch # 對彩色圖片RGB 進行像素點的kmeans聚類 import cv2 img = cv2.imread(r"timg.jpg")[..., ::-1].copy() data = torch.from_numpy(img).float()
直接對倒排索引的array對象進行 copy 操作,這樣得到不共享內存的新的array對象, 新生成的array對象自然是內存連續的。
----------------------------------------------------------------------
為什么 pytorch 要求傳入的numpy的array對象必須是內存連續的呢?
可以看:
https://blog.csdn.net/zz2230633069/article/details/93170271
https://zhuanlan.zhihu.com/p/59767914
https://www.cnblogs.com/peixu/articles/13455350.html
大致意思就是說內存連續的array或tensor對象在進行矩陣運算時速度更快。
-------------------------------------------------------------
參考:
https://blog.csdn.net/e01528/article/details/86067489
https://blog.csdn.net/qq_36891953/article/details/95482539
https://blog.csdn.net/u011622208/article/details/89707828