Computing cosine similarity with faiss


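faiss is Facebook's open-source similarity search library. It provides efficient similarity search and clustering of dense vectors, scales to billions of vectors, and is one of the most mature approximate nearest neighbor (ANN) search libraries available.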

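faiss does not compute cosine similarity directly; it provides Euclidean (L2) distance and inner product. By the definition of cosine similarity, the inner product of two L2-normalized vectors equals their cosine similarity, so to compute cosine similarity with faiss you first L2-normalize the inputs.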
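A quick NumPy check of this identity on random vectors (illustrative data, not from the original article). It also shows a related fact: for unit vectors the squared Euclidean distance equals 2 - 2*cos, so ranking by smallest L2 distance gives the same order as ranking by largest cosine.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(128)
y = rng.standard_normal(128)

# Cosine similarity from its definition: x.y / (||x|| * ||y||)
cos_direct = x @ y / (np.linalg.norm(x) * np.linalg.norm(y))

# L2-normalize first, then take a plain inner product (what IndexFlatIP computes)
x_hat = x / np.linalg.norm(x)
y_hat = y / np.linalg.norm(y)
cos_via_dot = x_hat @ y_hat

# On unit vectors, squared Euclidean distance is 2 - 2*cos
sq_l2 = np.sum((x_hat - y_hat) ** 2)

print(cos_direct, cos_via_dot, sq_l2, 2 - 2 * cos_via_dot)
```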

  • Installation

    Follow the official installation guide: https://github.com/facebookresearch/faiss/blob/main/INSTALL.md

    # CPU-only version
    $ conda install -c pytorch faiss-cpu
    $ pip install faiss-cpu
    
    # GPU(+CPU) version
    $ conda install -c pytorch faiss-gpu
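    $ pip install faiss-gpu    # unofficial PyPI wheel (conda is the supported channel); newer builds are named per CUDA version, e.g. faiss-gpu-cu12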
     
    
  • The conventional way to compute cosine similarity

    The usual approach is sklearn's cosine_similarity. It L2-normalizes the vectors internally, so the input does not have to be normalized beforehand. Code:

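    ## Compute cosine similarity
    import numpy as np
    from sklearn.metrics.pairwise import cosine_similarity
    from sklearn import preprocessing

    def get_cos_result(embeding_library, persons, embeding_search):
        # simi[i][j] = cosine similarity between query i and library vector j
        simi = cosine_similarity(embeding_search, embeding_library)
        # index of the most similar library vector for each query
        best_idx = np.argmax(simi, axis=1)
        search_speaker = [[persons[j], simi[i][j]] for i, j in enumerate(best_idx)]
        return search_speaker

    ## L2-normalize the input (optional here: cosine_similarity already normalizes internally)
    def l2_normal(embeding):
        return preprocessing.normalize(embeding)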
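To make explicit what cosine_similarity computes, here is a NumPy-only sketch of the same top-1 lookup. The name numpy_cos_top1 is hypothetical (not part of sklearn), and it assumes no all-zero rows, which sklearn would handle but this sketch would divide by zero on. The second call shows why pre-normalizing is optional: scaling the inputs by a positive constant does not change the result.

```python
import numpy as np

def numpy_cos_top1(embeding_library, persons, embeding_search):
    """Hypothetical NumPy-only counterpart of get_cos_result (top-1 per query)."""
    # Row-wise L2 normalization, like sklearn.preprocessing.normalize (assumes no zero rows)
    lib = embeding_library / np.linalg.norm(embeding_library, axis=1, keepdims=True)
    qry = embeding_search / np.linalg.norm(embeding_search, axis=1, keepdims=True)
    simi = qry @ lib.T  # simi[i, j] = cosine(query i, library j)
    best_idx = np.argmax(simi, axis=1)
    return [[persons[j], float(simi[i, j])] for i, j in enumerate(best_idx)]

library = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
persons = ['A', 'B', 'C']
query = np.array([[3.0, 2.9]])

print(numpy_cos_top1(library, persons, query))
# Multiplying the query or the library by a positive constant changes nothing
print(numpy_cos_top1(library * 10, persons, query * 0.5))
```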
    
  • Exact search with faiss

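    faiss has no built-in cosine similarity, only inner product and Euclidean distance, so the inputs must be L2-normalized before computing cosine similarity. Code: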

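    import time
    import numpy as np
    import faiss
    from faiss import normalize_L2

    def faiss_precise_search(embeding_library, persons, embeding_search, topk=1):
        ## faiss requires C-contiguous np.float32 input; np.array(...) also makes a copy,
        ## so normalize_L2 (which works in place) does not modify the caller's arrays
        embeding_library = np.array(embeding_library, dtype=np.float32, order='C')
        embeding_search = np.array(embeding_search, dtype=np.float32, order='C')
        ## You could also normalize with the sklearn helper above
        normalize_L2(embeding_search)
        normalize_L2(embeding_library)
        # faiss.IndexFlatIP ranks by inner product; faiss.IndexFlatL2 by Euclidean distance.
        # A flat index does exhaustive (exact) search and needs no training.
        index = faiss.IndexFlatIP(embeding_library.shape[1])
        index.add(embeding_library)
        st = time.time()
        distance, idx = index.search(embeding_search, topk)
        print('precise search:', time.time() - st)
        combined_results = []
        for p in range(len(distance)):
            # faiss pads missing results with id -1; cosine scores can legitimately be
            # negative, so filter on the id rather than on the score
            results = [[persons[i], float(s)] for i, s in zip(idx[p], distance[p]) if i >= 0]
            combined_results.append(results)
        return combined_results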
    
  • Fast (approximate) search with faiss

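    faiss offers several ways to speed up search. A common one is the inverted file index (IVF), similar in spirit to the inverted index Elasticsearch (ES) uses for fast search. k-means is first run on the library vectors to build cluster centers; at query time the nearest cluster centers are found, and only the vectors in those clusters are compared to find similar vectors. Because unprobed clusters are skipped, the result is approximate: a true nearest neighbor sitting in an unprobed cluster is missed. Two hyperparameters control this: the number of cluster centers, num_cells (nlist in faiss), and the number of centers scanned per query, num_cells_in_search (nprobe). Code: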

    def faiss_fast_search(embeding_library, persons, embeding_search, topk=1):
        embeding_library = np.array(embeding_library, dtype=np.float32, order='C')
        embeding_search = np.array(embeding_search, dtype=np.float32, order='C')
        normalize_L2(embeding_search)
        normalize_L2(embeding_library)
        d = embeding_library.shape[1]
        num_cells = 50            # number of k-means cluster centers (nlist)
        num_cells_in_search = 5   # number of nearest centers scanned per query (nprobe)
        nlist = min(num_cells, len(persons))
        # Declare the coarse quantizer that assigns vectors to cluster centers
        quantizer = faiss.IndexFlatIP(d)
        # faiss.METRIC_INNER_PRODUCT ranks by inner product; faiss.METRIC_L2 by Euclidean distance
        index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_INNER_PRODUCT)
        assert not index.is_trained
        index.train(embeding_library)   # runs k-means to find the cluster centers
        index.add(embeding_library)
        index.nprobe = min(num_cells_in_search, nlist)
        st = time.time()
        distance, idx = index.search(embeding_search, topk)
        print('fast search:', time.time() - st)
        combined_results = []
        for p in range(len(distance)):
            # filter padded results (id -1), not negative scores
            results = [[persons[i], float(s)] for i, s in zip(idx[p], distance[p]) if i >= 0]
            combined_results.append(results)
        return combined_results
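The IVF idea described above can be illustrated without faiss. The following is a toy NumPy sketch (spherical k-means plus probing; not faiss's actual implementation, and the names train_ivf and ivf_search are hypothetical). It shows how nlist and nprobe interact: when nprobe equals nlist every vector is scanned, so the result matches exact search.

```python
import numpy as np

def l2_normalize(x):
    return x / np.linalg.norm(x, axis=1, keepdims=True)

def train_ivf(data, nlist, n_iter=10, seed=0):
    """Toy spherical k-means: returns unit-length centroids and the inverted lists."""
    rng = np.random.default_rng(seed)
    centroids = data[rng.choice(len(data), nlist, replace=False)].copy()
    for _ in range(n_iter):
        assign = np.argmax(data @ centroids.T, axis=1)  # nearest centroid by inner product
        for c in range(nlist):
            members = data[assign == c]
            if len(members):  # keep the old centroid if a cluster empties
                centroids[c] = members.sum(axis=0)
        centroids = l2_normalize(centroids)
    assign = np.argmax(data @ centroids.T, axis=1)
    # Inverted lists: for each centroid, the ids of the vectors assigned to it
    inverted_lists = [np.flatnonzero(assign == c) for c in range(nlist)]
    return centroids, inverted_lists

def ivf_search(query, data, centroids, inverted_lists, nprobe, topk=1):
    """Scan only the nprobe lists whose centroids are most similar to the query."""
    probe = np.argsort(-(centroids @ query))[:nprobe]
    candidates = np.concatenate([inverted_lists[c] for c in probe])
    scores = data[candidates] @ query
    order = np.argsort(-scores)[:topk]
    return candidates[order], scores[order]

rng = np.random.default_rng(1)
data = l2_normalize(rng.standard_normal((2000, 32)))
query = l2_normalize(rng.standard_normal((1, 32)))[0]

centroids, lists = train_ivf(data, nlist=20)
exact_id = int(np.argmax(data @ query))
ids_all, _ = ivf_search(query, data, centroids, lists, nprobe=20)  # scans every list
ids_few, _ = ivf_search(query, data, centroids, lists, nprobe=2)   # approximate
print(exact_id, ids_all, ids_few)
```

Raising nprobe trades speed for recall; the faiss parameters num_cells and num_cells_in_search play the roles of nlist and nprobe here.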
    
  • Complete code

    # -*- coding: utf-8 -*-
    import time

    import numpy as np
    import faiss
    from faiss import normalize_L2
    from sklearn.metrics.pairwise import cosine_similarity
    from sklearn import preprocessing

    def l2_normal(embeding):
        return preprocessing.normalize(embeding)

    def to_float32(x):
        # faiss needs C-contiguous float32; this also copies, so the in-place
        # normalize_L2 does not modify the caller's array
        return np.array(x, dtype=np.float32, order='C')

    def collect_results(persons, distance, idx):
        # faiss pads missing hits with id -1; filter on the id, not the (possibly negative) score
        return [[[persons[i], float(s)] for i, s in zip(idx[p], distance[p]) if i >= 0]
                for p in range(len(distance))]

    def get_cos_result(embeding_library, persons, embeding_search):
        simi = cosine_similarity(embeding_search, embeding_library)
        best_idx = np.argmax(simi, axis=1)
        return [[persons[j], simi[i][j]] for i, j in enumerate(best_idx)]

    def faiss_precise_search(embeding_library, persons, embeding_search, topk=1):
        embeding_library = to_float32(embeding_library)
        embeding_search = to_float32(embeding_search)
        normalize_L2(embeding_search)
        normalize_L2(embeding_library)
        # faiss.IndexFlatIP is inner product; faiss.IndexFlatL2 is Euclidean distance
        index = faiss.IndexFlatIP(embeding_library.shape[1])
        index.add(embeding_library)
        st = time.time()
        distance, idx = index.search(embeding_search, topk)
        print('precise search:', time.time() - st)
        return collect_results(persons, distance, idx)

    def faiss_fast_search(embeding_library, persons, embeding_search, topk=1):
        embeding_library = to_float32(embeding_library)
        embeding_search = to_float32(embeding_search)
        normalize_L2(embeding_search)
        normalize_L2(embeding_library)
        d = embeding_library.shape[1]
        num_cells = 500            # nlist
        num_cells_in_search = 10   # nprobe
        nlist = min(num_cells, len(persons))
        quantizer = faiss.IndexFlatIP(d)
        # faiss.METRIC_INNER_PRODUCT is inner product; faiss.METRIC_L2 is Euclidean distance
        index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_INNER_PRODUCT)
        assert not index.is_trained
        index.train(embeding_library)
        index.add(embeding_library)
        index.nprobe = min(num_cells_in_search, nlist)
        st = time.time()
        distance, idx = index.search(embeding_search, topk)
        print('fast search:', time.time() - st)
        return collect_results(persons, distance, idx)

    if __name__ == '__main__':
        d = 512
        n_library = 100000
        n_search = 1
        embeding_library = np.random.random((n_library, d)).astype(np.float32)
        persons = [f'Speak{i + 1}' for i in range(n_library)]
        embeding_search = np.random.random((n_search, d)).astype(np.float32)
        print(faiss_fast_search(embeding_library, persons, embeding_search))
        print(faiss_precise_search(embeding_library, persons, embeding_search))
        st = time.time()
        print(get_cos_result(embeding_library, persons, embeding_search))
        print('sklearn search:', time.time() - st)
    

