pytorch調用caffe的lmdb


代碼:https://github.com/liangX-box/pytorchReadLmdb.git

一. 處理好訓練集和驗證集后,通過caffe的convert_imageset生成lmdb:

1 /usr/softwares/caffe/build/tools/convert_imageset --resize_width=224 --resize_height=224 --gray=true --shuffle=true --encoded=true /usr/data_path/ /usr/data/data_lt.txt /usr/lmdb

(1) /usr/softwares/caffe/build/tools/convert_imageset:caffe中convert_imagese腳本的路徑;

(2) --resize_width, --resize_height: 網絡輸入圖像的寬和高;

(3) --gray: 是否灰度化,如果沒設這個參數,則默認lmdb中像素值為三通道;

(4) --shuffle: 是否隨機打亂輸入圖像列表的順序;

(5) --encoded: 是否對像素做編碼;

(6) /usr/data_path/: 存放圖像的主目錄路徑;

(7) /usr/data/data_lt.txt: 圖像的列表,存放的是圖像的相對路徑和label;

(8) /usr/lmdb: 生成的lmdb的絕對路徑。

二. 提取lmdb中的key,生成key列表:

1 fl = open(saveKey_path, "w")
2 lmdb_env = lmdb.open(lmdb_path)
3 lmdb_txn = lmdb_env.begin()
4 lmdb_cursor = lmdb_txn.cursor()
5 for key, value in lmdb_cursor:
6     fl.write("%s\n" %(key))
7     count += 1
8 fl.close()

  根據caffe中讀取lmdb的方式讀取lmdb中每個元素的key值,詳見代碼get_lmdbKey.py。

三. 根據上述key值列表讀取batch(詳見文件readLmdb.py):

(1) 得到所有圖像的索引,並shuffle;

1 self.indices_total = np.arange(self.dataset_size)
2 np.random.seed(321)
3 np.random.shuffle(self.indices_total)

(2) 定義圖像batch和label batch,從lmdb中取數據放入這兩個batch中:

 //定義image batch和label batch blob(N C H W),其中label_batch后的dtype是為了方便將其轉化為pytorch接受的tensor形式的變量
1
image_batch = np.zeros((self.batch_size, 1, 224, 224), dtype=np.float32) 2 label_batch = np.zeros((self.batch_size), dtype=getattr(np, 'long'))
//按照batch值依次從shuffle后的key list中讀取圖像
3 for i in range(self.batch_size): 4 ind = self.indices_total[self.data_idx] 5 self.data_idx += 1
//若變量data_idx取到了list的最后一個值,則重新shuffle key list 6 if self.data_idx == self.dataset_size: 7 self.data_idx = 0 8 np.random.shuffle(self.indices_total)
//注意以下代碼可能python2和python3的有差異,此處代碼基於python3
9 temp = (self.key_lt[ind]).encode() 10 value = self.txn.get(temp) 11 datum = caffe_pb2.Datum() 12 datum.ParseFromString(value)
//讀取label,並放入label batch中
13 label = float(datum.label)
//讀取圖像數據,對其預處理后放入image batch中
14 encoded = datum.encoded 15 if encoded: 16 stream = BytesIO(datum.data) 17 img = np.uint8(Image.open(stream)) 18 img = img[...,::-1] 19 else: 20 data = caffe.io.datum_to_array(datum) 21 img = np.transpose(data, (1, 2, 0)) 22 img_tmp = img.copy() 23 img_tmp = np.float64(img_tmp) 24 img_tmp -= 127.5 25 img_tmp *= 0.0078125 26 img_tmp = img_tmp[np.newaxis, :] 27 image_batch[i, :] = img_tmp 28 label_batch[i] = label

 (3) 在訓練程序中,送入網絡滿足pytorch的數據形式:

1 train_inputs, train_label = batchClass.GetBatch()
2 pytorch_inputs = (torch.tensor(train_inputs)).cuda()
3 pytorch_labels = (torch.tensor(train_label)).cuda()

 將自己定義的blob通過torch.tensor轉化為pytorch網絡接受的數據形式。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM