LMDB數據庫加速Pytorch文件讀取速度


問題背景

訓練深度學習模型往往需要大規模的數據集,這些數據集往往無法直接一次性加載到計算機的內存中,通常需要分批加載。數據的I/O很可能成為訓練深度網絡模型的瓶頸,因此數據的讀取速度對於大規模的數據集(幾十G甚至上千G)是非常關鍵的。例如:https://discuss.pytorch.org/t/whats-the-best-way-to-load-large-data/2977

采用數據庫能夠大大提升數據的讀寫速度,例如caffe支持從lmdb、leveldb文件讀取訓練樣本。

lmdb和leveldb的使用方式差不多,Leveldb lmdb性能對比。但是,數據集轉換為LMDB或leveldb之后文件會變大(數據以二進制形式保存),即采用空間換取時間效率。

caffe先支持leveldb,后支持lmdb。lmdb讀取的效率更高,而且支持不同程序同時讀取,而leveldb只允許一個程序讀取。這一點在使用同樣的數據跑不同的配置程序時很重要。

lmdb

參考:https://zhuanlan.zhihu.com/p/70359311

LMDB 全稱為 Lightning Memory-Mapped Database,就是非常快的內存映射型數據庫,LMDB使用內存映射文件,可以提供更好的輸入/輸出性能,對於用於神經網絡的大型數據集( 比如 ImageNet ),可以將其存儲在 LMDB 中。

因為最開始 Caffe 就是使用的這個數據庫,所以網上的大多數關於 LMDB 的教程都通過 Caffe 實現的,對於不了解 Caffe 的同學很不友好,所以本篇文章只講解 LMDB。

LMDB屬於key-value數據庫,而不是關系型數據庫( 比如 MySQL ),LMDB提供 key-value 存儲,其中每個鍵值對都是我們數據集中的一個樣本。LMDB的主要作用是提供數據管理,可以將各種各樣的原始數據轉換為統一的key-value存儲。

LMDB效率高的一個關鍵原因是它是基於內存映射的,這意味着它返回指向鍵和值的內存地址的指針,而不需要像大多數其他數據庫那樣復制內存中的任何內容。

LMDB不僅可以用來存放訓練和測試用的數據集,還可以存放神經網絡提取出的特征數據。如果數據的結構很簡單,就是大量的矩陣和向量,而且數據之間沒有什么關聯,數據內沒有復雜的對象結構,那么就可以選擇LMDB這個簡單的數據庫來存放數據。

LMDB的文件結構很簡單,一個文件夾,里面是一個數據文件和一個鎖文件。數據隨意復制,隨意傳輸。它的訪問簡單,不需要單獨的數據管理進程。只要在訪問代碼里引用LMDB庫,訪問時給文件路徑即可。

用LMDB數據庫來存放圖像數據,而不是直接讀取原始圖像數據的原因:

  • 數據類型多種多樣,比如:二進制文件、文本文件、編碼后的圖像文件jpeg、png等,不可能用一套代碼實現所有類型的輸入數據讀取,因此通過LMDB數據庫,轉換為統一數據格式可以簡化數據讀取層的實現。
  • lmdb具有極高的存取速度,大大減少了系統訪問大量小文件時的磁盤IO的時間開銷。LMDB將整個數據集都放在一個文件里,避免了文件系統尋址的開銷,你的存儲介質有多快,就能訪問多快,不會因為文件多而導致時間長。LMDB使用了內存映射的方式訪問文件,這使得文件內尋址的開銷大幅度降低。

LMDB 的基本函數

  • env = lmdb.open():創建 lmdb 環境
  • txn = env.begin():建立事務
  • txn.put(key, value):進行插入和修改
  • txn.delete(key):進行刪除
  • txn.get(key):進行查詢
  • txn.cursor():進行遍歷
  • txn.commit():提交更改

創建一個 lmdb 環境:

1 import lmdb
2 
3 env = lmdb.open('D:/desktop/lmdb', map_size=10*1024**2)

 

指定存放生成的lmdb數據庫的文件夾路徑,如果沒有該文件夾則自動創建。

map_size 指定創建的新數據庫所需磁盤空間的最小值,1099511627776B=1T。可以在這里進行 存儲單位換算

會在指定路徑下創建 data.mdb 和 lock.mdb 兩個文件,一是個數據文件,一個是鎖文件。

 

修改數據庫內容:

 1 # 創建一個事務Transaction對象
 2 txn = env.begin(write=True)
 3 
 4 # insert/modify
 5 # txn.put(key, value)
 6 txn.put(str(1).encode(), "Alice".encode()) # .encode()編碼為字節bytes格式
 7 txn.put(str(2).encode(), "Bob".encode())
 8 txn.put(str(3).encode(), "Jack".encode())
 9 
10 # delete
11 # txn.delete(key)
12 txn.delete(str(1).encode())
13 
14 # 提交待處理的事務
15 txn.commit()

先創建一個事務(transaction) 對象 txn,所有的操作都必須經過這個事務對象。因為我們要對數據庫進行寫入操作,所以將 write 參數置為 True,默認其為 False

使用 .put(key, value) 對數據庫進行插入和修改操作,傳入的參數為鍵值對。

值得注意的是,需要在鍵值字符串后加 .encode() 改變其編碼格式,將 str 轉換為 bytes 格式,否則會報該錯誤:TypeError: Won't implicitly convert Unicode to bytes; use .encode()。在后面使用 .decode() 對其進行解碼得到原數據。

使用 .delete(key) 刪除指定鍵值對。

 對LMDB的讀寫操作在事務中執行,需要使用 commit 方法提交待處理的事務。

 

查詢數據庫內容:

1 # 數據庫查詢
2 txn = env.begin() # 每個commit()之后都需要使用begin()方法更新txn得到最新數據庫
3 
4 print(txn.get(str(2).encode()))
5 
6 for key, value in txn.cursor():
7     print(key, value)
8 
9 env.close

每次 commit() 之后都要用 env.begin() 更新 txn(得到最新的lmdb數據庫)。

使用 .get(key) 查詢數據庫中的單條記錄。

使用 .cursor() 遍歷數據庫中的所有記錄,其返回一個可迭代對象,相當於關系數據庫中的游標,每讀取一次,游標下移一位。

 

也可以想文件一樣使用 with 語法:

1 # 可以像文件一樣使用with語法
2 with env.begin() as txn:
3     print(txn.get(str(2).encode()))
4 
5     for key, value in txn.cursor():
6         print(key, value)
7 env.close

 

完整的demo如下:

 1 import lmdb
 2 import os, sys
 3 
 4 def initialize(lmdb_dir, map_size):
 5     # map_size: bytes
 6     env = lmdb.open(lmdb_dir, map_size)
 7     return env
 8 
 9 def insert(env, key, value):
10     txn = env.begin(write=True)
11     txn.put(str(key).encode(), value.encode())
12     txn.commit()
13 
14 def delete(env, key):
15     txn = env.begin(write=True)
16     txn.delete(str(key).encode())
17     txn.commit()
18 
19 def update(env, key, value):
20     txn = env.begin(write=True)
21     txn.put(str(key).encode(), value.encode())
22     txn.commit()
23 
24 def search(env, key):
25     txn = env.begin()
26     value = txn.get(str(key).encode())
27     return value
28 
29 def display(env):
30     txn = env.begin()
31     cursor = txn.cursor()
32     for key, value in cursor:
33         print(key, value)
34 
35 
36 if __name__ == '__main__':
37     path = 'D:/desktop/lmdb'
38     env = initialize(path, 10*1024*1024)
39 
40     print("Insert 3 records.")
41     insert(env, 1, "Alice")
42     insert(env, 2, "Bob")
43     insert(env, 3, "Peter")
44     display(env)
45 
46     print("Delete the record where key = 1")
47     delete(env, 1)
48     display(env)
49 
50     print("Update the record where key = 3")
51     update(env, 3, "Mark")
52     display(env)
53 
54     print("Get the value whose key = 3")
55     name = search(env, 3)
56     print(name)
57 
58     # 最后需要關閉lmdb數據庫
59     env.close()
View Code

圖片數據示例

在圖像深度學習訓練中我們一般都會把大量原始數據集轉化為lmdb格式以方便后續的網絡訓練。因此我們也需要對該數據集進行lmdb格式轉化。

將圖片和對應的文本標簽存放到lmdb數據庫:

 1 import lmdb
 2 
 3 image_path = './cat.jpg'
 4 label = 'cat'
 5 
 6 env = lmdb.open('lmdb_dir')
 7 cache = {}  # 存儲鍵值對
 8 
 9 with open(image_path, 'rb') as f:
10     # 讀取圖像文件的二進制格式數據
11     image_bin = f.read()
12 
13 # 用兩個鍵值對表示一個數據樣本
14 cache['image_000'] = image_bin
15 cache['label_000'] = label
16 
17 with env.begin(write=True) as txn:
18     for k, v in cache.items():
19         if isinstance(v, bytes):
20             # 圖片類型為bytes
21             txn.put(k.encode(), v)
22         else:
23             # 標簽類型為str, 轉為bytes
24             txn.put(k.encode(), v.encode())  # 編碼
25 
26 env.close()
View Code

這里需要獲取圖像文件的二進制格式數據,然后用兩個鍵值對保存一個數據樣本,即分開保存圖片和其標簽。

然后分別將圖像和標簽寫入到lmdb數據庫中,和上面例子一樣都需要將鍵值轉換為 bytes 格式。因為此處讀取的圖片格式本身就為 bytes,所以不需要轉換,標簽格式為 str,寫入數據庫之前需要先進行編碼將其轉換為 bytes

 

從lmdb數據庫中讀取圖片數據:

 1 import cv2
 2 import lmdb
 3 import numpy as np
 4 
 5 env = lmdb.open('lmdb_dir')
 6 
 7 with env.begin(write=False) as txn:
 8     # 獲取圖像數據
 9     image_bin = txn.get('image_000'.encode())
10     label = txn.get('label_000'.encode()).decode()  # 解碼
11 
12     # 將二進制文件轉為十進制文件(一維數組)
13     image_buf = np.frombuffer(image_bin, dtype=np.uint8)
14     # 將數據轉換(解碼)成圖像格式
15     # cv2.IMREAD_GRAYSCALE為灰度圖,cv2.IMREAD_COLOR為彩色圖
16     img = cv2.imdecode(image_buf, cv2.IMREAD_COLOR)
17     cv2.imshow('image', img)
18     cv2.waitKey(0)
View Code

先通過 lmdb.open() 獲取之前創建的lmdb數據庫。

這里通過鍵得到圖片和其標簽,因為寫入數據庫之前進行了編碼,所以這里需要先解碼。

  • 標簽通過 .decode() 進行解碼重新得到字符串格式。
  • 讀取到的圖片數據為二進制格式,所以先使用 np.frombuffer() 將其轉換為十進制格式的文件,這是一維數組。然后可以使用 cv2.imdecode() 將其轉換為灰度圖(二維數組)或者彩色圖(三維數組)。

 

leveldb

leveldb的使用與lmdb差不多,然而LevelDB 是單進程的服務。

https://www.jianshu.com/p/66496c8726a1

https://github.com/liquidconv/py4db

https://github.com/google/leveldb

 1 #!/usr/bin/env python
 2 
 3 import leveldb
 4 import os, sys
 5 
 6 def initialize():
 7     db = leveldb.LevelDB("students");
 8     return db;
 9 
10 def insert(db, sid, name):
11     db.Put(str(sid), name);
12 
13 def delete(db, sid):
14     db.Delete(str(sid));
15 
16 def update(db, sid, name):
17     db.Put(str(sid), name);
18 
19 def search(db, sid):
20     name = db.Get(str(sid));
21     return name;
22 
23 def display(db):
24     for key, value in db.RangeIter():
25         print (key, value);
26 
27 db = initialize();
28 
29 print "Insert 3 records."
30 insert(db, 1, "Alice");
31 insert(db, 2, "Bob");
32 insert(db, 3, "Peter");
33 display(db);
34 
35 print "Delete the record where sid = 1."
36 delete(db, 1);
37 display(db);
38 
39 print "Update the record where sid = 3."
40 update(db, 3, "Mark");
41 display(db);
42 
43 print "Get the name of student whose sid = 3."
44 name = search(db, 3);
45 print name;
View Code

 

 pytorch從lmdb中加載數據

這里給出一種pytorch從lmdb中加載數據的參考示例,來自:https://discuss.pytorch.org/t/whats-the-best-way-to-load-large-data/2977

需要指出的是,pytorch的Dataset並不支持lmdb的迭代器。Dataset通過__getitem__(index)方法通過index獲取一個樣本,因此無法整合lmdb的cursor進行遍歷,只能通過

with self.data_env.begin() as f: data = f.get(key)的方式,即每次打開一個事務txn,這會降低讀取速度。如果設置shuffle=False則可以利用cursor按順序遍歷。

 1 from __future__ import print_function
 2 import torch.utils.data as data
 3 # import h5py
 4 import numpy as np
 5 import lmdb
 6 
 7 
 8 class onlineHCCR(data.Dataset):
 9     def __init__(self, train=True):
10         # self.root = root
11         self.train = train
12 
13         if self.train:
14             datalmdb_path = 'traindata_lmdb'
15             labellmdb_path = 'trainlabel_lmdb'
16             self.data_env = lmdb.open(datalmdb_path, readonly=True)
17             self.label_env = lmdb.open(labellmdb_path, readonly=True)
18 
19         else:
20             datalmdb_path = 'testdata_lmdb'
21             labellmdb_path = 'testlabel_lmdb'
22             self.data_env = lmdb.open(datalmdb_path, readonly=True)
23             self.label_env = lmdb.open(labellmdb_path, readonly=True)
24 
25 
26     def __getitem__(self, index):
27 
28         Data = []
29         Target = []
30 
31         if self.train:
32             with self.data_env.begin() as f:
33                 key = '{:08}'.format(index)
34                 data = f.get(key)
35                 flat_data = np.fromstring(data, dtype=float)
36                 data = flat_data.reshape(150, 6).astype('float32')
37                 Data = data
38 
39             with self.label_env.begin() as f:
40                 key = '{:08}'.format(index)
41                 data = f.get(key)
42                 label = np.fromstring(data, dtype=int)
43                 Target = label[0]
44 
45         else:
46 
47             with self.data_env.begin() as f:
48                 key = '{:08}'.format(index)
49                 data = f.get(key)
50                 flat_data = np.fromstring(data, dtype=float)
51                 data = flat_data.reshape(150, 6).astype('float32')
52                 Data = data
53 
54             with self.label_env.begin() as f:
55                 key = '{:08}'.format(index)
56                 data = f.get(key)
57                 label = np.fromstring(data, dtype=int)
58                 Target = label[0]
59 
60         return Data, Target
61         
62 
63     def __len__(self):
64         if self.train:
65             return 2693931
66         else:
67             return 224589

 

 另一個示例:

 1 # https://github.com/pytorch/vision/blob/master/torchvision/datasets/lsun.py#L19-L20
 2 
 3 class LSUNClass(VisionDataset):
 4     def __init__(self, root, transform=None, target_transform=None):
 5         import lmdb
 6         super(LSUNClass, self).__init__(root, transform=transform, target_transform=target_transform)
 7 
 8         self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False)
 9         with self.env.begin(write=False) as txn:
10             self.length = txn.stat()['entries']
11         cache_file = '_cache_' + ''.join(c for c in root if c in string.ascii_letters)
12         if os.path.isfile(cache_file):
13             self.keys = pickle.load(open(cache_file, "rb"))
14         else:
15             with self.env.begin(write=False) as txn:
16                 self.keys = [key for key, _ in txn.cursor()]
17             pickle.dump(self.keys, open(cache_file, "wb"))
18 
19     def __getitem__(self, index):
20         img, target = None, None
21         env = self.env
22         with env.begin(write=False) as txn:
23             imgbuf = txn.get(self.keys[index])
24 
25         buf = io.BytesIO()
26         buf.write(imgbuf)
27         buf.seek(0)
28         img = Image.open(buf).convert('RGB')
29 
30         if self.transform is not None:
31             img = self.transform(img)
32 
33         if self.target_transform is not None:
34             target = self.target_transform(target)
35 
36         return img, target
37 
38     def __len__(self):
39         return self.length

 

 

參考:

lmdb 數據庫

Python操作SQLite/MySQL/LMDB/LevelDB

https://github.com/liquidconv/py4db

https://discuss.pytorch.org/t/whats-the-best-way-to-load-large-data/2977

https://www.programcreek.com/python/example/106501/lmdb.open

https://realpython.com/storing-images-in-python/

 https://www.cnblogs.com/skyfsm/p/10345305.html

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM