代碼:https://github.com/liangX-box/pytorchReadLmdb.git
一. 處理好訓練集和驗證集后,通過caffe的convert_imageset生成lmdb:
1 /usr/softwares/caffe/build/tools/convert_imageset --resize_width=224 --resize_height=224 --gray=true --shuffle=true --encoded=true /usr/data_path/ /usr/data/data_lt.txt /usr/lmdb
(1) /usr/softwares/caffe/build/tools/convert_imageset:caffe中convert_imagese腳本的路徑;
(2) --resize_width, --resize_height: 網絡輸入圖像的寬和高;
(3) --gray: 是否灰度化,如果沒設這個參數,則默認lmdb中像素值為三通道;
(4) --shuffle: 是否隨機打亂輸入圖像列表的順序;
(5) --encoded: 是否對像素做編碼;
(6) /usr/data_path/: 存放圖像的主目錄路徑;
(7) /usr/data/data_lt.txt: 圖像的列表,存放的是圖像的相對路徑和label;
(8) /usr/lmdb: 生成的lmdb的絕對路徑。
二. 提取lmdb中的key,生成key列表:
1 fl = open(saveKey_path, "w") 2 lmdb_env = lmdb.open(lmdb_path) 3 lmdb_txn = lmdb_env.begin() 4 lmdb_cursor = lmdb_txn.cursor() 5 for key, value in lmdb_cursor: 6 fl.write("%s\n" %(key)) 7 count += 1 8 fl.close()
根據caffe中讀取lmdb的方式讀取lmdb中每個元素的key值,詳見代碼get_lmdbKey.py。
三. 根據上述key值列表讀取batch(詳見文件readLmdb.py):
(1) 得到所有圖像的索引,並shuffle;
1 self.indices_total = np.arange(self.dataset_size) 2 np.random.seed(321) 3 np.random.shuffle(self.indices_total)
(2) 定義圖像batch和label batch,從lmdb中取數據放入這兩個batch中:
//定義image batch和label batch blob(N C H W),其中label_batch后的dtype是為了方便將其轉化為pytorch接受的tensor形式的變量
1 image_batch = np.zeros((self.batch_size, 1, 224, 224), dtype=np.float32) 2 label_batch = np.zeros((self.batch_size), dtype=getattr(np, 'long'))
//按照batch值依次從shuffle后的key list中讀取圖像 3 for i in range(self.batch_size): 4 ind = self.indices_total[self.data_idx] 5 self.data_idx += 1
//若變量data_idx取到了list的最后一個值,則重新shuffle key list 6 if self.data_idx == self.dataset_size: 7 self.data_idx = 0 8 np.random.shuffle(self.indices_total)
//注意以下代碼可能python2和python3的有差異,此處代碼基於python3 9 temp = (self.key_lt[ind]).encode() 10 value = self.txn.get(temp) 11 datum = caffe_pb2.Datum() 12 datum.ParseFromString(value)
//讀取label,並放入label batch中 13 label = float(datum.label)
//讀取圖像數據,對其預處理后放入image batch中 14 encoded = datum.encoded 15 if encoded: 16 stream = BytesIO(datum.data) 17 img = np.uint8(Image.open(stream)) 18 img = img[...,::-1] 19 else: 20 data = caffe.io.datum_to_array(datum) 21 img = np.transpose(data, (1, 2, 0)) 22 img_tmp = img.copy() 23 img_tmp = np.float64(img_tmp) 24 img_tmp -= 127.5 25 img_tmp *= 0.0078125 26 img_tmp = img_tmp[np.newaxis, :] 27 image_batch[i, :] = img_tmp 28 label_batch[i] = label
(3) 在訓練程序中,送入網絡滿足pytorch的數據形式:
1 train_inputs, train_label = batchClass.GetBatch() 2 pytorch_inputs = (torch.tensor(train_inputs)).cuda() 3 pytorch_labels = (torch.tensor(train_label)).cuda()
將自己定義的blob通過torch.tensor轉化為pytorch網絡接受的數據形式。