pytorch初步學習(一):數據讀取


最近從tensorflow轉向pytorch,感受到了動態調試的方便,也感受到了一些地方的不同。

所有實驗都是基於uint16類型的單通道灰度圖片。

一開始嘗試用opencv中的cv.imread讀取圖片,發現會默認讀8位數據。。。后來還是改用了skimage讀取圖片。一個小坑。

在tensorflow中:

利用append得到數組x_test  [batchsize,width,hight]

x_test = x_test[:, :, :, np.newaxis]
# 占位符
x=tf.placeholder(tf.float32, shape=[None, w, h, 1], name='x')
# 送入網絡tensor維度依次為:batchsize,width,hight,channel

在pytorch中:

arr = np.asarray(img, dtype="float32")
data_x[i, :, :, :] = arr
i += 1
data_y.append(int(item[0]))
data_x = torch.from_numpy(data_x)
data_y = torch.from_numpy(data_y)
dataset = dataf.TensorDataset(data_x, data_y)
loader = dataf.DataLoader(dataset, batch_size=batchsize, shuffle=True)
# 送入網絡的tensor維度依次為:batchsize,channel, width,hight

在tensorflow中需要自己寫一個minibatch函數控制訓練,在pytorch中可以調用dataloader將數據變成torch需要的tensor形式,並且不需要額外寫minibatch函數。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM