PyTorch使用LMDB數據庫加速文件讀取
原始文檔:https://www.yuque.com/lart/ugkv9f/hbnym1
對於數據庫的了解較少,文章中大部分的介紹主要來自於各種博客和LMDB的文檔,但是文檔中的介紹,默認是已經了解了數據庫的許多知識,這導致目前只能囫圇吞棗,待之后仔細了解后再重新補充內容。
背景介紹
文章https://blog.csdn.net/jyl1999xxxx/article/details/53942824中介紹了使用LMDB的原因:
Caffe使用LMDB來存放訓練/測試用的數據集,以及使用網絡提取出的feature(為了方便,以下還是統稱數據集)。數據集的結構很簡單,就是大量的矩陣/向量數據平鋪開來。數據之間沒有什么關聯,數據內沒有復雜的對象結構,就是向量和矩陣。既然數據並不復雜,Caffe就選擇了LMDB這個簡單的數據庫來存放數據。
LMDB的全稱是Lightning Memory-Mapped Database,閃電般的內存映射數據庫。它文件結構簡單,一個文件夾,里面一個數據文件,一個鎖文件。數據隨意復制,隨意傳輸。它的訪問簡單,不需要運行單獨的數據庫管理進程,只要在訪問數據的代碼里引用LMDB庫,訪問時給文件路徑即可。
圖像數據集歸根究底從圖像文件而來。引入數據庫存放數據集,是為了減少IO開銷。讀取大量小文件的開銷是非常大的,尤其是在機械硬盤上。LMDB的整個數據庫放在一個文件里,避免了文件系統尋址的開銷。LMDB使用內存映射的方式訪問文件,使得文件內尋址的開銷非常小,使用指針運算就能實現。數據庫單文件還能減少數據集復制/傳輸過程的開銷。一個幾萬,幾十萬文件的數據集,不管是直接復制,還是打包再解包,過程都無比漫長而痛苦。LMDB數據庫只有一個文件,你的介質有多塊,就能復制多快,不會因為文件多而慢如蝸牛。
在文章http://shuokay.com/2018/05/14/python-lmdb/中類似提到:
為什么要把圖像數據轉換成大的二進制文件?
簡單來說,是因為讀寫小文件的速度太慢。那么,不禁要問,圖像數據也是二進制文件,單個大的二進制文件例如 LMDB 文件也是二進制文件,為什么單個圖像讀寫速度就慢了呢?這里分兩種情況解釋。
- 機械硬盤的情況:機械硬盤的每次讀寫啟動時間比較長,例如磁頭的尋道時間占比很高,因此,如果單個小文件讀寫,尤其是隨機讀寫單個小文件的時候,這個尋道時間占比就會很高,最后導致大量讀寫小文件的時候時間會很浪費;
- NFS 的情況:在 NFS 的場景下,系統的一次讀寫首先要進行上百次的網絡通訊,並且這個通訊次數和文件的大小無關。因此,如果是讀寫小文件,這個網絡通訊時間占據了整個讀寫時間的大部分。
固態硬盤的情況下應該也會有一些類似的開銷,目前沒有研究過。
總而言之,使用LMDB可以為我們的數據讀取進行加速。
具體操作
LMDB主要類
pip install imdb
lmdb.Environment
lmdb.open() 這個方法實際上是 class lmdb.Environment(path, map_size=10485760, subdir=True, readonly=False, metasync=True, sync=True, map_async=False, mode=493, create=True, readahead=True, writemap=False, meminit=True, max_readers=126, max_dbs=0, max_spare_txns=1, lock=True) 的一個別名(shortcut),二者是等價的。關於這個類:https://lmdb.readthedocs.io/en/release/#environment-class
這是數據庫環境的結構。 一個環境可能包含多個數據庫,所有數據庫都駐留在同一共享內存映射和基礎磁盤文件中。要寫入環境,必須創建事務(Transaction)。 允許同時進行一次寫入事務,但是即使存在寫入事務,讀取事務的數量也沒有限制。
幾個重要的實例方法:
- begin(db=None, parent=None, write=False, buffers=False): 可以調用事務類
lmdb.Transaction - open_db(key=None, txn=None, reverse_key=False, dupsort=False, create=True, integerkey=False, integerdup=False, dupfixed=False): 打開一個數據庫,返回一個不透明的句柄。重復
Environment.open_db()調用相同的名稱將返回相同的句柄。作為一個特殊情況,主數據庫總是開放的。命名數據庫是通過在主數據庫中存儲一個特殊的描述符來實現的。環境中的所有數據庫共享相同的文件。因為描述符存在於主數據庫中,所以如果已經存在與數據庫名稱匹配的key,創建命名數據庫的嘗試將失敗。此外,查找和枚舉可以看到key。如果主數據庫keyspace與命名數據庫使用的名稱沖突,則將主數據庫的內容移動到另一個命名數據庫。
>>> env = lmdb.open('/tmp/test', max_dbs=2)
>>> with env.begin(write=True) as txn
... txn.put('somename', 'somedata')
>>> # Error: database cannot share name of existing key!
>>> subdb = env.open_db('somename')
lmdb.Transaction
這和事務對象有關。
class lmdb.Transaction(env, db=None, parent=None, write=False, buffers=False) 。
關於這個類的參數:https://lmdb.readthedocs.io/en/release/#transaction-class
所有操作都需要事務句柄,事務可以是只讀或讀寫的。寫事務可能不會跨越線程。事務對象實現了上下文管理器協議,因此即使面對未處理的異常,也可以可靠地釋放事務:
# Transaction aborts correctly:
with env.begin(write=True) as txn:
crash()
# Transaction commits automatically:
with env.begin(write=True) as txn:
txn.put('a', 'b')
這個類的實例包含着很多有用的操作方法。
- abort(): 中止掛起的事務。重復調用
abort()在之前成功的commit()或abort()后或者在相關環境關閉后是沒有效果的。 - commit(): 提交掛起的事務。
- cursor(db=None): Shortcut for
lmdb.Cursor(db, self) - delete(key, value='', db=None): Delete a key from the database.
- key: The key to delete.
- value:如果數據庫是以
dupsort = True打開的,並且value不是空的bytestring,則刪除僅與此(key, value)對匹配的元素,否則該key的所有值都將被刪除。 - Returns
Trueif at least one key was deleted.
- drop(db, delete=True): 刪除命名數據庫中的所有鍵,並可選地刪除命名數據庫本身。刪除命名數據庫會導致其不可用,並使現有cursors無效。
- get(key, default=None, db=None): 獲取匹配鍵的第一個值,如果鍵不存在,則返回默認值。cursor必須用於獲取
dupsort = True數據庫中的key的所有值。 - id(): 返回事務的ID。這將返回與此事務相關聯的標識符。對於只讀事務,這對應於正在讀取的快照; 並發讀取器通常具有相同的事務ID。
- pop(key, db=None): 使用臨時cursor調用
Cursor.pop()。- db: 要操作的命名數據庫。如果未指定,默認為事務構造函數被給定的數據庫。
- put(key, value, dupdata=True, overwrite=True, append=False, db=None): 存儲一條記錄(record),如果記錄被寫入,則返回
True,否則返回False,以指示key已經存在並且overwrite = False。成功后,cursor位於新記錄上。- key: Bytestring key to store.
- value: Bytestring value to store.
- dupdata: 如果
True,並且數據庫是用dupsort = True打開的,如果給定key已經存在,則添加鍵值對作為副本。否則覆蓋任何現有匹配的key。 - overwrite: If
False, do not overwrite any existing matching key. - append: 如果為
True,則將對附加到數據庫末尾,而不首先比較其順序。附加不大於現有最高key的key將導致損壞。 - db: 要操作的命名數據庫。如果未指定,默認為事務構造函數被給定的數據庫。
- replace(key, value, db=None): 使用臨時cursor調用
Cursor.replace(). - db: Named database to operate on. If unspecified, defaults to the database given to the Transaction constructor.
- stat(db): Return statistics like
Environment.stat(), except for a single DBI.dbmust be a database handle returned byopen_db().
Imdb.Cursor
class lmdb.Cursor(db, txn) 是用於在數據庫中導航(navigate)的結構。
- db: Database to navigate.
- txn: Transaction to navigate.
As a convenience, Transaction.cursor() can be used to quickly return a cursor:
>>> env = lmdb.open('/tmp/foo')
>>> child_db = env.open_db('child_db')
>>> with env.begin() as txn:
... cursor = txn.cursor() # Cursor on main database.
... cursor2 = txn.cursor(child_db) # Cursor on child database.
游標以未定位的狀態開始。如果在這種狀態下使用 iternext() 或 iterprev() ,那么迭代將分別從開始處和結束處開始。迭代器直接使用游標定位,這意味着在同一游標上存在多個迭代器時會產生奇怪的行為。
從Python綁定的角度來看,一旦任何掃描或查找方法(例如
next()、prev_nodup()、set_range())返回False或引發異常,游標將返回未定位狀態。這主要是為了確保在面對任何錯誤條件時語義的安全性和一致性。
當游標返回到未定位的狀態時,它的key()和value()返回空字符串,表示沒有活動的位置,盡管在內部,LMDB游標可能仍然有一個有效的位置。
這可能會導致在迭代dupsort=True數據庫的key時出現一些令人吃驚的行為,因為iternext_dup()等方法將導致游標顯示為未定位,盡管它返回False只是為了表明當前鍵沒有更多的值。在這種情況下,簡單地調用next()將導致在下一個可用鍵處繼續迭代。
This behaviour may change in future.
Iterator methods such as iternext() and iterprev() accept keys and values arguments. If both are True , then the value of item() is yielded on each iteration. If only keys is True , key() is yielded, otherwise only value() is yielded.
在迭代之前,游標可能定位在數據庫中的任何位置
>>> with env.begin() as txn:
... cursor = txn.cursor()
... if not cursor.set_range('5'): # Position at first key >= '5'.
... print('Not found!')
... else:
... for key, value in cursor: # Iterate from first key >= '5'.
... print((key, value))
不需要迭代來導航,有時會導致丑陋或低效的代碼。在迭代順序不明顯的情況下,或者與正在讀取的數據相關的情況下,使用 set_key() 、 set_range() 、 key() 、 value() 和 item() 可能是更好的選擇。
>>> # Record the path from a child to the root of a tree.
>>> path = ['child14123']
>>> while path[-1] != 'root':
... assert cursor.set_key(path[-1]), \
... 'Tree is broken! Path: %s' % (path,)
... path.append(cursor.value())
幾個實例方法:
- set_key(key): Seek exactly to key, returning
Trueon success orFalseif the exact key was not found. 對於set_key(),空字節串是錯誤的。對於使用dupsort=True打開的數據庫,移動到鍵的第一個值(復制)。 - set_range(key): Seek to the first
keygreater than or equal tokey, returningTrueon success, orFalseto indicatekeywas past end of database. Behaves likefirst()if key is the empty bytestring. 對於使用dupsort=True打開的數據庫,移動到鍵的第一個值(復制)。 - get(key, default=None): Equivalent to
set_key(), exceptvalue()is returned when key is found, otherwise default. - item(): Return the current
(key, value)pair. - key(): Return the current key.
- value(): Return the current value.
操作流程
概況地講,操作LMDB的流程是:
- 通過
env = lmdb.open()打開環境 - 通過
txn = env.begin()建立事務 - 通過
txn.put(key, value)進行插入和修改 - 通過
txn.delete(key)進行刪除 - 通過
txn.get(key)進行查詢 - 通過
txn.cursor()進行遍歷 - 通過
txn.commit()提交更改
這里要注意:
put和delete后一定注意要commit,不然根本沒有存進去- 每一次
commit后,需要再定義一次txn=env.begin(write=True)
來自https://github.com/kophy/py4db的代碼:
#!/usr/bin/env python
import lmdb
import os, sys
def initialize():
env = lmdb.open("students");
return env;
def insert(env, sid, name):
txn = env.begin(write = True);
txn.put(str(sid), name);
txn.commit();
def delete(env, sid):
txn = env.begin(write = True);
txn.delete(str(sid));
txn.commit();
def update(env, sid, name):
txn = env.begin(write = True);
txn.put(str(sid), name);
txn.commit();
def search(env, sid):
txn = env.begin();
name = txn.get(str(sid));
return name;
def display(env):
txn = env.begin();
cur = txn.cursor();
for key, value in cur:
print (key, value);
env = initialize();
print "Insert 3 records."
insert(env, 1, "Alice");
insert(env, 2, "Bob");
insert(env, 3, "Peter");
display(env);
print "Delete the record where sid = 1."
delete(env, 1);
display(env);
print "Update the record where sid = 3."
update(env, 3, "Mark");
display(env);
print "Get the name of student whose sid = 3."
name = search(env, 3);
print name;
env.close();
os.system("rm -r students");
創建圖像數據集
這里主要借鑒自https://github.com/open-mmlab/mmsr/blob/master/codes/data_scripts/create_lmdb.py的代碼。
改寫為:
import glob
import os
import pickle
import sys
import cv2
import lmdb
import numpy as np
from tqdm import tqdm
def main(mode):
proj_root = '/home/lart/coding/TIFNet'
datasets_root = '/home/lart/Datasets/'
lmdb_path = os.path.join(proj_root, 'datasets/ECSSD.lmdb')
data_path = os.path.join(datasets_root, 'RGBSaliency', 'ECSSD/Image')
if mode == 'creating':
opt = {
'name': 'TrainSet',
'img_folder': data_path,
'lmdb_save_path': lmdb_path,
'commit_interval': 100, # After commit_interval images, lmdb commits
'num_workers': 8,
}
general_image_folder(opt)
elif mode == 'testing':
test_lmdb(lmdb_path, index=1)
def general_image_folder(opt):
"""
Create lmdb for general image folders
If all the images have the same resolution, it will only store one copy of resolution info.
Otherwise, it will store every resolution info.
"""
img_folder = opt['img_folder']
lmdb_save_path = opt['lmdb_save_path']
meta_info = {'name': opt['name']}
if not lmdb_save_path.endswith('.lmdb'):
raise ValueError("lmdb_save_path must end with 'lmdb'.")
if os.path.exists(lmdb_save_path):
print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
sys.exit(1)
# read all the image paths to a list
print('Reading image path list ...')
all_img_list = sorted(glob.glob(os.path.join(img_folder, '*')))
# cache the filename, 這里的文件名必須是ascii字符
keys = []
for img_path in all_img_list:
keys.append(os.path.basename(img_path))
# create lmdb environment
# 估算大概的映射空間大小
data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes
print('data size per image is: ', data_size_per_img)
data_size = data_size_per_img * len(all_img_list)
env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
# map_size:
# Maximum size database may grow to; used to size the memory mapping. If database grows larger
# than map_size, an exception will be raised and the user must close and reopen Environment.
# write data to lmdb
txn = env.begin(write=True)
resolutions = []
tqdm_iter = tqdm(enumerate(zip(all_img_list, keys)), total=len(all_img_list), leave=False)
for idx, (path, key) in tqdm_iter:
tqdm_iter.set_description('Write {}'.format(key))
key_byte = key.encode('ascii')
data = cv2.imread(path, cv2.IMREAD_UNCHANGED)
if data.ndim == 2:
H, W = data.shape
C = 1
else:
H, W, C = data.shape
resolutions.append('{:d}_{:d}_{:d}'.format(C, H, W))
txn.put(key_byte, data)
if (idx + 1) % opt['commit_interval'] == 0:
txn.commit()
# commit 之后需要再次 begin
txn = env.begin(write=True)
txn.commit()
env.close()
print('Finish writing lmdb.')
# create meta information
# check whether all the images are the same size
assert len(keys) == len(resolutions)
if len(set(resolutions)) <= 1:
meta_info['resolution'] = [resolutions[0]]
meta_info['keys'] = keys
print('All images have the same resolution. Simplify the meta info.')
else:
meta_info['resolution'] = resolutions
meta_info['keys'] = keys
print('Not all images have the same resolution. Save meta info for each image.')
pickle.dump(meta_info, open(os.path.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
print('Finish creating lmdb meta info.')
def test_lmdb(dataroot, index=1):
env = lmdb.open(dataroot, readonly=True, lock=False, readahead=False, meminit=False)
meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'), "rb"))
print('Name: ', meta_info['name'])
print('Resolution: ', meta_info['resolution'])
print('# keys: ', len(meta_info['keys']))
# read one image
key = meta_info['keys'][index]
print('Reading {} for test.'.format(key))
with env.begin(write=False) as txn:
buf = txn.get(key.encode('ascii'))
img_flat = np.frombuffer(buf, dtype=np.uint8)
C, H, W = [int(s) for s in meta_info['resolution'][index].split('_')]
img = img_flat.reshape(H, W, C)
cv2.namedWindow('Test')
cv2.imshow('Test', img)
cv2.waitKeyEx()
if __name__ == "__main__":
# mode = creating or testing
main(mode='creating')
配合DataLoader
這里僅對訓練集進行LMDB處理,測試機依舊使用的原始的讀取圖片的方式。
import os
import pickle
import lmdb
import numpy as np
from PIL import Image
from prefetch_generator import BackgroundGenerator
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from utils import joint_transforms
def _get_paths_from_lmdb(dataroot):
"""get image path list from lmdb meta info"""
meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'),
'rb'))
paths = meta_info['keys']
sizes = meta_info['resolution']
if len(sizes) == 1:
sizes = sizes * len(paths)
return paths, sizes
def _read_img_lmdb(env, key, size):
"""read image from lmdb with key (w/ and w/o fixed size)
size: (C, H, W) tuple"""
with env.begin(write=False) as txn:
buf = txn.get(key.encode('ascii'))
img_flat = np.frombuffer(buf, dtype=np.uint8)
C, H, W = size
img = img_flat.reshape(H, W, C)
return img
def _make_dataset(root, prefix=('.jpg', '.png')):
img_path = os.path.join(root, 'Image')
gt_path = os.path.join(root, 'Mask')
img_list = [
os.path.splitext(f)[0] for f in os.listdir(gt_path)
if f.endswith(prefix[1])
]
return [(os.path.join(img_path, img_name + prefix[0]),
os.path.join(gt_path, img_name + prefix[1]))
for img_name in img_list]
class TestImageFolder(Dataset):
def __init__(self, root, in_size, prefix):
self.imgs = _make_dataset(root, prefix=prefix)
self.test_img_trainsform = transforms.Compose([
# 輸入的如果是一個tuple,則按照數據縮放,但是如果是一個數字,則按比例縮放到短邊等於該值
transforms.Resize((in_size, in_size)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
def __getitem__(self, index):
img_path, gt_path = self.imgs[index]
img = Image.open(img_path).convert('RGB')
img_name = (img_path.split(os.sep)[-1]).split('.')[0]
img = self.test_img_trainsform(img)
return img, img_name
def __len__(self):
return len(self.imgs)
class TrainImageFolder(Dataset):
def __init__(self, root, in_size, scale=1.5, use_bigt=False):
self.use_bigt = use_bigt
self.in_size = in_size
self.root = root
self.train_joint_transform = joint_transforms.Compose([
joint_transforms.JointResize(in_size),
joint_transforms.RandomHorizontallyFlip(),
joint_transforms.RandomRotate(10)
])
self.train_img_transform = transforms.Compose([
transforms.ColorJitter(0.1, 0.1, 0.1),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225]) # 處理的是Tensor
])
# ToTensor 操作會將 PIL.Image 或形狀為 H×W×D,數值范圍為 [0, 255] 的 np.ndarray 轉換為形狀為 D×H×W,
# 數值范圍為 [0.0, 1.0] 的 torch.Tensor。
self.train_target_transform = transforms.ToTensor()
self.gt_root = '/home/lart/coding/TIFNet/datasets/DUTSTR/DUTSTR_GT.lmdb'
self.img_root = '/home/lart/coding/TIFNet/datasets/DUTSTR/DUTSTR_IMG.lmdb'
self.paths_gt, self.sizes_gt = _get_paths_from_lmdb(self.gt_root)
self.paths_img, self.sizes_img = _get_paths_from_lmdb(self.img_root)
self.gt_env = lmdb.open(self.gt_root, readonly=True, lock=False, readahead=False,
meminit=False)
self.img_env = lmdb.open(self.img_root, readonly=True, lock=False, readahead=False,
meminit=False)
def __getitem__(self, index):
gt_path = self.paths_gt[index]
img_path = self.paths_img[index]
gt_resolution = [int(s) for s in self.sizes_gt[index].split('_')]
img_resolution = [int(s) for s in self.sizes_img[index].split('_')]
img_gt = _read_img_lmdb(self.gt_env, gt_path, gt_resolution)
img_img = _read_img_lmdb(self.img_env, img_path, img_resolution)
if img_img.shape[-1] != 3:
img_img = np.repeat(img_img, repeats=3, axis=-1)
img_img = img_img[:, :, [2, 1, 0]] # bgr => rgb
img_gt = np.squeeze(img_gt, axis=2)
gt = Image.fromarray(img_gt, mode='L')
img = Image.fromarray(img_img, mode='RGB')
img, gt = self.train_joint_transform(img, gt)
gt = self.train_target_transform(gt)
img = self.train_img_transform(img)
if self.use_bigt:
gt = gt.ge(0.5).float() # 二值化
img_name = self.paths_img[index]
return img, gt, img_name
def __len__(self):
return len(self.paths_img)
class DataLoaderX(DataLoader):
def __iter__(self):
return BackgroundGenerator(super(DataLoaderX, self).__iter__())
