faiss是Facebook開源的相似性搜索庫,為稠密向量提供高效相似度搜索和聚類,支持十億級別向量的搜索,是目前最為成熟的近似近鄰搜索庫
faiss不直接提供余弦距離計算,而是提供了歐式距離和點積,利用余弦距離公式,經過L2正則后的向量點積結果即為余弦距離,所以利用faiss計算余弦距離需要先對輸入進行L2正則化
-
安裝
參照官方開源安裝https://github.com/facebookresearch/faiss/blob/main/INSTALL.md
# CPU-only version $ conda install -c pytorch faiss-cpu $ pip install faiss-cpu # GPU(+CPU) version $ conda install -c pytorch faiss-gpu $ pip install faiss-cpu
-
常規計算余弦距離方式
常規一般使用sklearn包的cosine_similarity計算余弦距離,因為該包自動對向量進行L2正則,所以不要求輸入必須為正則結果,代碼如下:
## 計算余弦距離 from sklearn.metrics.pairwise import cosine_similarity from sklearn import preprocessing def get_cos_result(embeding_library, persons, embeding_search): simi = cosine_similarity(embeding_search, embeding_library) max_argmin = np.argmax(simi,axis=1) search_speaker = [[persons[id],simi[i][id]] for i, id in enumerate(max_argmin)] return search_speaker ## 對輸入進行正則化,可以不用正則 def l2_normal(embeding): return preprocessing.normalize(embeding)
-
faiss的精確搜索
faiss並不提供計算與余弦距離,只提供了點積計算和歐式距離,所以在計算余弦距離時,需要對輸入進行L2正則,代碼如下:
import faiss from faiss import normalize_L2 def faiss_precise_search(embeding_library, persons, embeding_search,topk=1): ## 這里也可以使用上文的sklearn的包進行正則 normalize_L2(embeding_search) normalize_L2(embeding_library) # faiss.IndexFlatIP是內積 ;faiss.indexFlatL2是歐式距離 quantizer = faiss.IndexFlatIP(embeding_library.shape[1]) index = quantizer ## 要保證輸入為np.float32格式 index.add(embeding_library.astype(np.float32)) library = {'persons': persons, 'index': index} st = time.time() distance,idx = library['index'].search(embeding_search,topk) print('precise search:',time.time()-st) combined_results = [] for p in range(len(distance)): results = [[library["persons"][i], s] for i, s in zip(idx[p], distance[p]) if s >= 0][0] combined_results.append(results) return combined_results
-
faiss快速搜索
faiss提供了多種快速搜索的方式,這里介紹常用的一種加速搜索的方式:倒排索引,這種方式與ES快速搜索的方式類似,需要先使用k-means建立聚類中心,通過查詢最近的聚類中心,然后比較聚類中所有向量得到相似向量,這里需要兩個超參數,一個是聚類中心num_cells,一個是查找聚類中心的個數num_cells_in_search,具體代碼如下
def faiss_fast_search(embeding_library, persons, embeding_search,topk=1): normalize_L2(embeding_search) normalize_L2(embeding_library) d = embeding_library.shape[1] num_cells = 50 num_cells_in_search = 5 # 聲明量化器 quantizer = faiss.IndexFlatIP(embeding_library.shape[1]) # faiss.METRIC_INNER_PRODUCT計算內積 faiss.METRIC_L2j計算歐式距離 index = faiss.IndexIVFFlat(quantizer, d,min(num_cells, len(persons)),faiss.METRIC_INNER_PRODUCT) assert not index.is_trained index.train(embeding_library.astype(np.float32)) index.add(embeding_library.astype(np.float32)) index.nprobe = min(num_cells_in_search,len(persons)) library = {'persons': persons, 'index': index} st = time.time() distance, idx = library['index'].search(embeding_search, topk) print('fast search:',time.time()-st) combined_results = [] for p in range(len(distance)): results = [[library["persons"][i], s] for i, s in zip(idx[p], distance[p]) if s >= 0][0] combined_results.append(results) return combined_results
-
整體代碼
# -*- coding: utf-8 -*- import faiss from faiss import normalize_L2 from sklearn.metrics.pairwise import cosine_similarity from sklearn import preprocessing import numpy as np import time def l2_normal(embeding): return preprocessing.normalize(embeding) def get_cos_result(embeding_search, persons, embeding_library): simi = cosine_similarity(embeding_search, embeding_library) max_argmin = np.argmax(simi,axis=1) search_speaker = [[persons[id],simi[i][id]] for i, id in enumerate(max_argmin)] return search_speaker def faiss_precise_search(embeding_library, persons, embeding_search): normalize_L2(embeding_search) normalize_L2(embeding_library) # faiss.IndexFlatIP是內積 ;faiss.indexFlatL2是歐式距離 quantizer = faiss.IndexFlatIP(embeding_library.shape[1]) index = quantizer index.add(embeding_library.astype(np.float32)) library = {'persons': persons, 'index': index} st = time.time() distance,idx = library['index'].search(embeding_search,1) print('precise search:',time.time()-st) combined_results = [] for p in range(len(distance)): results = [[library["persons"][i], s] for i, s in zip(idx[p], distance[p]) if s >= 0][0] combined_results.append(results) return combined_results def faiss_fast_search(embeding_library, persons, embeding_search,topk=1): normalize_L2(embeding_search) normalize_L2(embeding_library) num_cells = 500 num_cells_in_search = 10 quantizer = faiss.IndexFlatIP(embeding_library.shape[1]) index = faiss.IndexIVFFlat(quantizer, embeding_library.shape[1],min(num_cells, len(persons)),faiss.METRIC_INNER_PRODUCT) #faiss.METRIC_INNER_PRODUCT計算內積 faiss.METRIC_L2j計算歐式距離 assert not index.is_trained index.train(embeding_library.astype(np.float32)) index.add(embeding_library.astype(np.float32)) index.nprobe = min(num_cells_in_search,len(persons)) library = {'persons': persons, 'index': index} st = time.time() distance, idx = library['index'].search(embeding_search, topk) print('fast search:',time.time()-st) combined_results = [] for p in range(len(distance)): results = [[library["persons"][i], s] for i, s in zip(idx[p], distance[p]) if s >= 0][0] combined_results.append(results) return combined_results if __name__ == '__main__': d = 512 n_library = 100000 n_search = 1 embeding_library = np.random.random((n_library, d)).astype(np.float32) persons = ['Speak' + "%0d" % (i + 1) for i in range(n_library)] embeding_search = np.random.random((n_search, d)).astype(np.float32) print(faiss_fast_search(embeding_library, persons, embeding_search)) print(faiss_precise_search(embeding_library, persons, embeding_search)) st = time.time() print(get_cos_result(embeding_search, persons, embeding_library)) en1 = time.time() print(en1-st)