問題背景
訓練深度學習模型往往需要大規模的數據集,這些數據集往往無法直接一次性加載到計算機的內存中,通常需要分批加載。數據的I/O很可能成為訓練深度網絡模型的瓶頸,因此數據的讀取速度對於大規模的數據集(幾十G甚至上千G)是非常關鍵的。例如:https://discuss.pytorch.org/t/whats-the-best-way-to-load-large-data/2977
采用數據庫能夠大大提升數據的讀寫速度,例如caffe支持從lmdb、leveldb文件讀取訓練樣本。
lmdb和leveldb的使用方式差不多,Leveldb lmdb性能對比。但是,數據集轉換為LMDB或leveldb之后文件會變大(數據以二進制形式保存),即采用空間換取時間效率。
caffe先支持leveldb,后支持lmdb。lmdb讀取的效率更高,而且支持不同程序同時讀取,而leveldb只允許一個程序讀取。這一點在使用同樣的數據跑不同的配置程序時很重要。
lmdb
參考:https://zhuanlan.zhihu.com/p/70359311
LMDB 全稱為 Lightning Memory-Mapped Database,就是非常快的內存映射型數據庫,LMDB使用內存映射文件,可以提供更好的輸入/輸出性能,對於用於神經網絡的大型數據集( 比如 ImageNet ),可以將其存儲在 LMDB 中。
因為最開始 Caffe 就是使用的這個數據庫,所以網上的大多數關於 LMDB 的教程都通過 Caffe 實現的,對於不了解 Caffe 的同學很不友好,所以本篇文章只講解 LMDB。
LMDB屬於key-value數據庫,而不是關系型數據庫( 比如 MySQL ),LMDB提供 key-value 存儲,其中每個鍵值對都是我們數據集中的一個樣本。LMDB的主要作用是提供數據管理,可以將各種各樣的原始數據轉換為統一的key-value存儲。
LMDB效率高的一個關鍵原因是它是基於內存映射的,這意味着它返回指向鍵和值的內存地址的指針,而不需要像大多數其他數據庫那樣復制內存中的任何內容。
LMDB不僅可以用來存放訓練和測試用的數據集,還可以存放神經網絡提取出的特征數據。如果數據的結構很簡單,就是大量的矩陣和向量,而且數據之間沒有什么關聯,數據內沒有復雜的對象結構,那么就可以選擇LMDB這個簡單的數據庫來存放數據。
LMDB的文件結構很簡單,一個文件夾,里面是一個數據文件和一個鎖文件。數據隨意復制,隨意傳輸。它的訪問簡單,不需要單獨的數據管理進程。只要在訪問代碼里引用LMDB庫,訪問時給文件路徑即可。
用LMDB數據庫來存放圖像數據,而不是直接讀取原始圖像數據的原因:
- 數據類型多種多樣,比如:二進制文件、文本文件、編碼后的圖像文件jpeg、png等,不可能用一套代碼實現所有類型的輸入數據讀取,因此通過LMDB數據庫,轉換為統一數據格式可以簡化數據讀取層的實現。
- lmdb具有極高的存取速度,大大減少了系統訪問大量小文件時的磁盤IO的時間開銷。LMDB將整個數據集都放在一個文件里,避免了文件系統尋址的開銷,你的存儲介質有多快,就能訪問多快,不會因為文件多而導致時間長。LMDB使用了內存映射的方式訪問文件,這使得文件內尋址的開銷大幅度降低。
LMDB 的基本函數
env = lmdb.open()
:創建 lmdb 環境txn = env.begin()
:建立事務txn.put(key, value)
:進行插入和修改txn.delete(key)
:進行刪除txn.get(key)
:進行查詢txn.cursor()
:進行遍歷txn.commit()
:提交更改
創建一個 lmdb 環境:
1 import lmdb 2 3 env = lmdb.open('D:/desktop/lmdb', map_size=10*1024**2)
指定存放生成的lmdb數據庫的文件夾路徑,如果沒有該文件夾則自動創建。
map_size
指定創建的新數據庫所需磁盤空間的最小值,1099511627776B=1T。可以在這里進行 存儲單位換算。
會在指定路徑下創建 data.mdb
和 lock.mdb
兩個文件,一是個數據文件,一個是鎖文件。
修改數據庫內容:
1 # 創建一個事務Transaction對象 2 txn = env.begin(write=True) 3 4 # insert/modify 5 # txn.put(key, value) 6 txn.put(str(1).encode(), "Alice".encode()) # .encode()編碼為字節bytes格式 7 txn.put(str(2).encode(), "Bob".encode()) 8 txn.put(str(3).encode(), "Jack".encode()) 9 10 # delete 11 # txn.delete(key) 12 txn.delete(str(1).encode()) 13 14 # 提交待處理的事務 15 txn.commit()
先創建一個事務(transaction) 對象 txn
,所有的操作都必須經過這個事務對象。因為我們要對數據庫進行寫入操作,所以將 write
參數置為 True
,默認其為 False
。
使用 .put(key, value)
對數據庫進行插入和修改操作,傳入的參數為鍵值對。
值得注意的是,需要在鍵值字符串后加 .encode()
改變其編碼格式,將 str
轉換為 bytes
格式,否則會報該錯誤:TypeError: Won't implicitly convert Unicode to bytes; use .encode()
。在后面使用 .decode()
對其進行解碼得到原數據。
使用 .delete(key)
刪除指定鍵值對。
對LMDB的讀寫操作在事務中執行,需要使用 commit
方法提交待處理的事務。
查詢數據庫內容:
1 # 數據庫查詢 2 txn = env.begin() # 每個commit()之后都需要使用begin()方法更新txn得到最新數據庫 3 4 print(txn.get(str(2).encode())) 5 6 for key, value in txn.cursor(): 7 print(key, value) 8 9 env.close
每次 commit()
之后都要用 env.begin()
更新 txn(得到最新的lmdb數據庫)。
使用 .get(key)
查詢數據庫中的單條記錄。
使用 .cursor() 遍歷數據庫中的所有記錄,其返回一個可迭代對象,相當於關系數據庫中的游標,每讀取一次,游標下移一位。
也可以想文件一樣使用 with
語法:
1 # 可以像文件一樣使用with語法 2 with env.begin() as txn: 3 print(txn.get(str(2).encode())) 4 5 for key, value in txn.cursor(): 6 print(key, value) 7 env.close
完整的demo如下:

1 import lmdb 2 import os, sys 3 4 def initialize(lmdb_dir, map_size): 5 # map_size: bytes 6 env = lmdb.open(lmdb_dir, map_size) 7 return env 8 9 def insert(env, key, value): 10 txn = env.begin(write=True) 11 txn.put(str(key).encode(), value.encode()) 12 txn.commit() 13 14 def delete(env, key): 15 txn = env.begin(write=True) 16 txn.delete(str(key).encode()) 17 txn.commit() 18 19 def update(env, key, value): 20 txn = env.begin(write=True) 21 txn.put(str(key).encode(), value.encode()) 22 txn.commit() 23 24 def search(env, key): 25 txn = env.begin() 26 value = txn.get(str(key).encode()) 27 return value 28 29 def display(env): 30 txn = env.begin() 31 cursor = txn.cursor() 32 for key, value in cursor: 33 print(key, value) 34 35 36 if __name__ == '__main__': 37 path = 'D:/desktop/lmdb' 38 env = initialize(path, 10*1024*1024) 39 40 print("Insert 3 records.") 41 insert(env, 1, "Alice") 42 insert(env, 2, "Bob") 43 insert(env, 3, "Peter") 44 display(env) 45 46 print("Delete the record where key = 1") 47 delete(env, 1) 48 display(env) 49 50 print("Update the record where key = 3") 51 update(env, 3, "Mark") 52 display(env) 53 54 print("Get the value whose key = 3") 55 name = search(env, 3) 56 print(name) 57 58 # 最后需要關閉lmdb數據庫 59 env.close()
圖片數據示例
在圖像深度學習訓練中我們一般都會把大量原始數據集轉化為lmdb格式以方便后續的網絡訓練。因此我們也需要對該數據集進行lmdb格式轉化。
將圖片和對應的文本標簽存放到lmdb數據庫:

1 import lmdb 2 3 image_path = './cat.jpg' 4 label = 'cat' 5 6 env = lmdb.open('lmdb_dir') 7 cache = {} # 存儲鍵值對 8 9 with open(image_path, 'rb') as f: 10 # 讀取圖像文件的二進制格式數據 11 image_bin = f.read() 12 13 # 用兩個鍵值對表示一個數據樣本 14 cache['image_000'] = image_bin 15 cache['label_000'] = label 16 17 with env.begin(write=True) as txn: 18 for k, v in cache.items(): 19 if isinstance(v, bytes): 20 # 圖片類型為bytes 21 txn.put(k.encode(), v) 22 else: 23 # 標簽類型為str, 轉為bytes 24 txn.put(k.encode(), v.encode()) # 編碼 25 26 env.close()
這里需要獲取圖像文件的二進制格式數據,然后用兩個鍵值對保存一個數據樣本,即分開保存圖片和其標簽。
然后分別將圖像和標簽寫入到lmdb數據庫中,和上面例子一樣都需要將鍵值轉換為 bytes
格式。因為此處讀取的圖片格式本身就為 bytes
,所以不需要轉換,標簽格式為 str
,寫入數據庫之前需要先進行編碼將其轉換為 bytes
。
從lmdb數據庫中讀取圖片數據:

1 import cv2 2 import lmdb 3 import numpy as np 4 5 env = lmdb.open('lmdb_dir') 6 7 with env.begin(write=False) as txn: 8 # 獲取圖像數據 9 image_bin = txn.get('image_000'.encode()) 10 label = txn.get('label_000'.encode()).decode() # 解碼 11 12 # 將二進制文件轉為十進制文件(一維數組) 13 image_buf = np.frombuffer(image_bin, dtype=np.uint8) 14 # 將數據轉換(解碼)成圖像格式 15 # cv2.IMREAD_GRAYSCALE為灰度圖,cv2.IMREAD_COLOR為彩色圖 16 img = cv2.imdecode(image_buf, cv2.IMREAD_COLOR) 17 cv2.imshow('image', img) 18 cv2.waitKey(0)
先通過 lmdb.open()
獲取之前創建的lmdb數據庫。
這里通過鍵得到圖片和其標簽,因為寫入數據庫之前進行了編碼,所以這里需要先解碼。
- 標簽通過
.decode()
進行解碼重新得到字符串格式。 - 讀取到的圖片數據為二進制格式,所以先使用
np.frombuffer()
將其轉換為十進制格式的文件,這是一維數組。然后可以使用cv2.imdecode()
將其轉換為灰度圖(二維數組)或者彩色圖(三維數組)。
leveldb
leveldb的使用與lmdb差不多,然而LevelDB 是單進程的服務。
https://www.jianshu.com/p/66496c8726a1
https://github.com/liquidconv/py4db
https://github.com/google/leveldb

1 #!/usr/bin/env python 2 3 import leveldb 4 import os, sys 5 6 def initialize(): 7 db = leveldb.LevelDB("students"); 8 return db; 9 10 def insert(db, sid, name): 11 db.Put(str(sid), name); 12 13 def delete(db, sid): 14 db.Delete(str(sid)); 15 16 def update(db, sid, name): 17 db.Put(str(sid), name); 18 19 def search(db, sid): 20 name = db.Get(str(sid)); 21 return name; 22 23 def display(db): 24 for key, value in db.RangeIter(): 25 print (key, value); 26 27 db = initialize(); 28 29 print "Insert 3 records." 30 insert(db, 1, "Alice"); 31 insert(db, 2, "Bob"); 32 insert(db, 3, "Peter"); 33 display(db); 34 35 print "Delete the record where sid = 1." 36 delete(db, 1); 37 display(db); 38 39 print "Update the record where sid = 3." 40 update(db, 3, "Mark"); 41 display(db); 42 43 print "Get the name of student whose sid = 3." 44 name = search(db, 3); 45 print name;
pytorch從lmdb中加載數據
這里給出一種pytorch從lmdb中加載數據的參考示例,來自:https://discuss.pytorch.org/t/whats-the-best-way-to-load-large-data/2977
需要指出的是,pytorch的Dataset並不支持lmdb的迭代器。Dataset通過__getitem__(index)方法通過index獲取一個樣本,因此無法整合lmdb的cursor進行遍歷,只能通過
with self.data_env.begin() as f: data = f.get(key)的方式,即每次打開一個事務txn,這會降低讀取速度。如果設置shuffle=False則可以利用cursor按順序遍歷。
1 from __future__ import print_function 2 import torch.utils.data as data 3 # import h5py 4 import numpy as np 5 import lmdb 6 7 8 class onlineHCCR(data.Dataset): 9 def __init__(self, train=True): 10 # self.root = root 11 self.train = train 12 13 if self.train: 14 datalmdb_path = 'traindata_lmdb' 15 labellmdb_path = 'trainlabel_lmdb' 16 self.data_env = lmdb.open(datalmdb_path, readonly=True) 17 self.label_env = lmdb.open(labellmdb_path, readonly=True) 18 19 else: 20 datalmdb_path = 'testdata_lmdb' 21 labellmdb_path = 'testlabel_lmdb' 22 self.data_env = lmdb.open(datalmdb_path, readonly=True) 23 self.label_env = lmdb.open(labellmdb_path, readonly=True) 24 25 26 def __getitem__(self, index): 27 28 Data = [] 29 Target = [] 30 31 if self.train: 32 with self.data_env.begin() as f: 33 key = '{:08}'.format(index) 34 data = f.get(key) 35 flat_data = np.fromstring(data, dtype=float) 36 data = flat_data.reshape(150, 6).astype('float32') 37 Data = data 38 39 with self.label_env.begin() as f: 40 key = '{:08}'.format(index) 41 data = f.get(key) 42 label = np.fromstring(data, dtype=int) 43 Target = label[0] 44 45 else: 46 47 with self.data_env.begin() as f: 48 key = '{:08}'.format(index) 49 data = f.get(key) 50 flat_data = np.fromstring(data, dtype=float) 51 data = flat_data.reshape(150, 6).astype('float32') 52 Data = data 53 54 with self.label_env.begin() as f: 55 key = '{:08}'.format(index) 56 data = f.get(key) 57 label = np.fromstring(data, dtype=int) 58 Target = label[0] 59 60 return Data, Target 61 62 63 def __len__(self): 64 if self.train: 65 return 2693931 66 else: 67 return 224589
另一個示例:
1 # https://github.com/pytorch/vision/blob/master/torchvision/datasets/lsun.py#L19-L20 2 3 class LSUNClass(VisionDataset): 4 def __init__(self, root, transform=None, target_transform=None): 5 import lmdb 6 super(LSUNClass, self).__init__(root, transform=transform, target_transform=target_transform) 7 8 self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False) 9 with self.env.begin(write=False) as txn: 10 self.length = txn.stat()['entries'] 11 cache_file = '_cache_' + ''.join(c for c in root if c in string.ascii_letters) 12 if os.path.isfile(cache_file): 13 self.keys = pickle.load(open(cache_file, "rb")) 14 else: 15 with self.env.begin(write=False) as txn: 16 self.keys = [key for key, _ in txn.cursor()] 17 pickle.dump(self.keys, open(cache_file, "wb")) 18 19 def __getitem__(self, index): 20 img, target = None, None 21 env = self.env 22 with env.begin(write=False) as txn: 23 imgbuf = txn.get(self.keys[index]) 24 25 buf = io.BytesIO() 26 buf.write(imgbuf) 27 buf.seek(0) 28 img = Image.open(buf).convert('RGB') 29 30 if self.transform is not None: 31 img = self.transform(img) 32 33 if self.target_transform is not None: 34 target = self.target_transform(target) 35 36 return img, target 37 38 def __len__(self): 39 return self.length
參考:
Python操作SQLite/MySQL/LMDB/LevelDB
https://github.com/liquidconv/py4db
https://discuss.pytorch.org/t/whats-the-best-way-to-load-large-data/2977
https://www.programcreek.com/python/example/106501/lmdb.open
https://realpython.com/storing-images-in-python/
https://www.cnblogs.com/skyfsm/p/10345305.html