背景
在深度學習的時候,如果你的batch size調的很大,或者你每次獲取一個batch需要許多的預操作,那么pytorch的Dataloader獲取一個batch就會花費較多的時間,那么訓練的時候就會出現GPU等CPU的情況,訓練的效率就會下降。
為了應對這種情況,Tensorflow有TFrecord,但是Pytorch沒有對應的數據格式,在查詢各類資料之后,我決定使用LMDB這個數據庫
LMDB是一種數據庫,可以實現多進程訪問,訪問簡單,而且不需要把全部文件讀入內存,總而言之就是速度很快
方法
我們想用LMDB來讀取的話,首先就需要將我們原始的數據集轉換為LMDB的數據格式,然后在訓練的時候讀取這個文件就行了。首先我們來實現將原始數據集轉換為LMDB的過程:
轉換為LMDB格式
相必大家在尋找一個更加快速的Dataloader的時候,已經寫好了Pytorch常規的Dataloader,我們這里就可以利用上這個已有的Dataloader
首先打開一個lmdb數據庫文件,如果之前有用過其他文件數據庫的話,會發現這有點相似~
db = lmdb.open(lmdb_path, subdir=isdir,
map_size=1099511627776 * 2, readonly=False,
meminit=False, map_async=True)
然后准備向其中寫入數據:
txn = db.begin(write=True)
然后就是將圖片放到lmdb中
txn.put(u'{}'.format(idx).encode('ascii'), dumps_pyarrow(
(pic.numpy())
))
最后
txn.commit()
總體的操作與數據庫一樣,總體就是打開數據庫,放入數據,最后commit
最后放出完整的方法:
def data2lmdb(dataloader, db_path='data_lmdb', name="train", write_frequency=50):
"""
Args:
dataloader: the general dataloader of the dataset, e.g torch.utils.data.DataLoader
db_path: the path you want to save the lmdb file
name: train or test
write_frequency: Write once every ? rounds
Returns:
None
"""
if not os.path.exists(db_path):
os.makedirs(db_path)
lmdb_path = os.path.join(db_path, "%s.lmdb" % name)
isdir = os.path.isdir(lmdb_path)
print("Generate LMDB to %s" % lmdb_path)
db = lmdb.open(lmdb_path, subdir=isdir,
map_size=1099511627776 * 2, readonly=False,
meminit=False, map_async=True)
txn = db.begin(write=True)
for idx, data in enumerate(dataloader):
# get data from dataloader
if name == 'train':
pic = data
elif name == 'test':
pic = data
else:
raise 'unexpect name :{}'.format(name)
# put data to lmdb dataset
# {idx, (in_LDRs, in_HDRs, ref_HDRs)}
txn.put(u'{}'.format(idx).encode('ascii'), dumps_pyarrow(
(pic.numpy())
))
if idx % write_frequency == 0:
print("[%d/%d]" % (idx, len(dataloader)))
txn.commit()
txn = db.begin(write=True)
# finish iterating through dataset
txn.commit()
keys = [u'{}'.format(k).encode('ascii') for k in range(idx + 1)]
with db.begin(write=True) as txn:
txn.put(b'__keys__', dumps_pyarrow(keys))
txn.put(b'__len__', dumps_pyarrow(len(keys)))
print("Flushing database ...")
db.sync()
db.close()
上面的函數適合我的使用場景,如果你需要搬去用的話,需要修改dataloader輸出的地方,根據自己的dataloader來寫
讀取LMDB文件
這里放上IMDB的Dataloader
class ImageFolderLMDBTest(data.Dataset):
def __init__(self, db_path):
self.db_path = db_path
self.env = lmdb.open(db_path, subdir=os.path.isdir(db_path),
readonly=True, lock=False,
readahead=False, meminit=False)
with self.env.begin(write=False) as txn:
# self.length = txn.stat()['entries'] - 1
self.length = pa.deserialize(txn.get(b'__len__'))
self.keys = pa.deserialize(txn.get(b'__keys__'))
def __getitem__(self, index):
env = self.env
with env.begin(write=False) as txn:
byteflow = txn.get(self.keys[index])
unpacked = pa.deserialize(byteflow)
# load image
# 這里寫你寫入lmdb時的數據,上面的我寫入了pic,這里展開就還是pic
pic = unpacked
return pic
def to_tensor(self, img):
img_t = torch.from_numpy(img.copy())
if isinstance(img_t, torch.ByteTensor):
return img_t.float().div(255)
else:
return img_t
def __len__(self):
return self.length
def __repr__(self):
return self.__class__.__name__ + ' (' + self.db_path + ')'
這個Dataloader的結構和一般的Dataloader一樣,需要注意的我也寫在了注釋里面,其實就是根據你寫入的東西不同,從LMDB里取出的東西也不一樣
后記
在用上這個LMDB文件格式之后,數據的讀取速度也快了很多,GPU也終於不會歇着了🤣