Java操作ElasticSearch,實現SimHash比較文章相似度


最近工作中要求實現相似文本查詢的功能,我於是決定用SimHash實現。

常規思路通常分為以下四步:

1、實現SimHash算法。

2、保存文章時,同時保存SimHash為倒排索引。

3、入庫時或使用定時任務,在倒排索引中找到碰撞的SimHash,保存為結果表。

4、需要查詢一篇文章的相似文章時,根據文章ID,查詢結果表,找到相似文章。

 

不過這裡有個小問題:如果一篇文章多次入庫導致SimHash發生變化,或者文章被刪除,結果表可能很難及時更新。

同時ES剛好很擅長查詢與維護倒排索引,所以我想能不能直接交給ES幫我維護SimHash的倒排索引,從而跳過使用結果表呢?

那麼以上邏輯可以簡化為3步:

1、實現SimHash算法。

2、保存文章時,同時在ES中保存SimHash字段(和正文其它字段一起)。

3、需要查詢一篇文章的相似文章時,先根據文章ID查到其SimHash值,再去ES查詢匹配的其它文章ID。由於ES查詢只負責按SimHash分段做等值匹配、找出候選文章,這裡還需要在服務層再做一次漢明距離的過濾。

 

說幹就幹。以下是我的實現代碼,在網上已有算法的基礎上做了一些修改,權當拋磚引玉,如有做得不好的地方還請大家指出。

 

首先添加依賴:使用HanLP進行分詞,使用Jsoup去除正文中的HTML標籤。

        <!-- HanLP (portable build): word segmentation and stop-word dictionary used by SimHashService -->
        <dependency>
            <groupId>com.hankcs</groupId>
            <artifactId>hanlp</artifactId>
            <version>portable-1.8.1</version>
        </dependency>

        <!-- Jsoup: strips HTML tags from the article body before it is segmented and hashed -->
        <dependency>
            <groupId>org.jsoup</groupId>
            <artifactId>jsoup</artifactId>
            <version>1.13.1</version>
        </dependency>

 

接下來是SimHash的核心類。這裡直接寫死了64位SimHash,判重閾值(漢明距離)為3,該閾值在後面的查詢邏輯中使用:

package com.springboot.text;

import com.hankcs.hanlp.HanLP;
import com.hankcs.hanlp.dictionary.stopword.CoreStopWordDictionary;
import com.hankcs.hanlp.seg.common.Term;
import com.springboot.commonUtil.StringUtils;

import java.math.BigInteger;
import java.util.List;

/**
 * SimHash computation and comparison helpers.
 *
 * <p>A SimHash is handled in two forms: a 64-character string of '0'/'1' characters (the first
 * character is the most significant bit), or four {@code Short} parts A, B, C, D holding the
 * consecutive 16-bit slices of that string, A being the highest.
 */
public class SimHashService {

    // Kept public for existing callers. The word hash itself uses long arithmetic, which is
    // bit-for-bit equivalent to "(hash * 1000003) xor ch, masked to 64 bits" on BigInteger.
    public static final BigInteger BIGINT_0 = BigInteger.valueOf(0);
    public static final BigInteger BIGINT_1 = BigInteger.valueOf(1);
    public static final BigInteger BIGINT_2 = BigInteger.valueOf(2);
    public static final BigInteger BIGINT_1000003 = BigInteger.valueOf(1000003);
    public static final BigInteger BIGINT_2E64M1 = BIGINT_2.pow(64).subtract(BIGINT_1);

    /** Width of a SimHash in bits. */
    private static final int SIMHASH_BITS = 64;
    /** Width of one Short part in bits. */
    private static final int PART_BITS = 16;
    /** Number of Short parts in one SimHash. */
    private static final int PART_COUNT = SIMHASH_BITS / PART_BITS;
    /** Multiplier used by the word hash (same value as {@link #BIGINT_1000003}). */
    private static final long WORD_HASH_MULTIPLIER = 1000003L;

    /**
     * Computes the SimHash of a piece of text.
     *
     * <p>Warning: changing this method, or HanLP's segmentation results (e.g. adding stop words),
     * changes the computed SimHash values, so previously stored hashes would no longer match.
     *
     * @param text text to hash; HTML markup is stripped first
     * @return the SimHash as a 64-character '0'/'1' string, or {@code null} if {@code text} is
     *     null or too short to produce a meaningful hash
     */
    public static String get(String text) {
        if (text == null) {
            return null;
        }
        // Presumably equivalent to Jsoup.parse(text).text(), per the original note.
        text = StringUtils.removeHtml(text);
        int sumWeight = 0;
        int maxWeight = 0;
        // bits[i] accumulates the weighted votes for hash bit (63 - i), so index 0 is the MSB.
        int[] bits = new int[SIMHASH_BITS];
        List<Term> termList = HanLP.segment(text);
        for (Term term : termList) {
            String word = term.word;
            String nature = term.nature.toString();
            if (nature.startsWith("w") || CoreStopWordDictionary.contains(word)) {
                // Skip punctuation (HanLP natures starting with "w") and stop words.
                continue;
            }
            int wordWeight = getWordWeight(word);
            if (wordWeight == 0) {
                continue;
            }
            long wordHash = getWordHash(word);
            sumWeight += wordWeight;
            maxWeight = Math.max(maxWeight, wordWeight);
            // Add the weight where the word hash has a 1 bit, subtract it where it has a 0 bit.
            for (int i = 0; i < SIMHASH_BITS; i++) {
                if (((wordHash >>> (SIMHASH_BITS - 1 - i)) & 1L) != 0) {
                    bits[i] += wordWeight;
                } else {
                    bits[i] -= wordWeight;
                }
            }
        }
        if (3 * maxWeight >= sumWeight || sumWeight < 20) {
            // Too little text for a well-mixed hash: refuse to return one, otherwise many documents
            // would collide and queries would become slow. For now at least three "big" words are
            // required before a result is returned.
            return null;
        }

        // Collapse the per-bit vote counts into a 0/1 string.
        char[] simHash = new char[SIMHASH_BITS];
        for (int i = 0; i < SIMHASH_BITS; i++) {
            simHash[i] = bits[i] > 0 ? '1' : '0';
        }
        return new String(simHash);
    }

    /**
     * Hashes a single word to 64 bits.
     *
     * <p>Warning: changing this method changes every computed SimHash.
     *
     * @param word the word to hash
     * @return the 64-bit hash (all 64 bits are significant, so it may be negative as a signed
     *     long), or 0 for a blank word
     */
    private static long getWordHash(String word) {
        if (StringUtils.isBlank(word)) {
            return 0L;
        }
        // Tuning found a left shift of about 11-12 bits works best: when the words are mostly
        // two-character Chinese words it avoids an obvious bias in the high hash bits, whereas a
        // much larger shift makes the low hash bits depend only on the word's last character.
        long hash = ((long) word.charAt(0)) << 12;
        for (int i = 0; i < word.length(); i++) {
            // long multiplication wraps modulo 2^64, so the result is already masked to 64 bits.
            hash = (hash * WORD_HASH_MULTIPLIER) ^ word.charAt(i);
        }
        return hash ^ word.length();
    }

    /**
     * Returns the weight a word contributes to the SimHash.
     *
     * <p>Warning: changing this method changes every computed SimHash.
     *
     * @param word the word to weigh
     * @return 0 for a blank word, 1 for a single character, otherwise 8/16 (two / three or more
     *     characters) for words starting at or above U+3040 and 2/4 for other words
     */
    private static int getWordWeight(String word) {
        if (StringUtils.isBlank(word)) {
            return 0;
        }
        int length = word.length();
        if (length == 1) {
            // A single-character word's hash has too few significant bits (about 40), so its
            // weight must stay very low; otherwise the high SimHash bits tend to become all 0.
            return 1;
        }
        // Words whose first char is at or above U+3040 (Hiragana onwards, which includes the CJK
        // ideographs) weigh more than other words of the same length.
        boolean highCodePoint = word.charAt(0) >= 0x3040;
        if (length == 2) {
            return highCodePoint ? 8 : 2;
        }
        return highCodePoint ? 16 : 4;
    }

    /**
     * Extracts one 16-bit part of a SimHash string as a {@code Short}.
     *
     * @param simHash the SimHash as a 64-character '0'/'1' string
     * @param part    index of the part, 0 (highest bits) to 3 (lowest bits)
     * @return the part, reinterpreted as a signed short; {@code null} if {@code simHash} is null
     *     or {@code part} is out of range
     * @throws StringIndexOutOfBoundsException if {@code simHash} is too short for the part
     * @throws NumberFormatException           if the part contains characters other than 0/1
     */
    public static Short toShort(String simHash, int part) {
        if (simHash == null || part < 0 || part >= PART_COUNT) {
            return null;
        }
        int startBit = part * PART_BITS;
        // Parse as int first: 16 binary digits can exceed Short.MAX_VALUE; shortValue() then keeps
        // the low 16 bits, i.e. the two's-complement reinterpretation.
        return Integer.valueOf(simHash.substring(startBit, startBit + PART_BITS), 2).shortValue();
    }

    /**
     * Joins four {@code Short} parts into a 64-character SimHash string.
     *
     * @param simHashA highest 16 bits
     * @param simHashB second 16 bits
     * @param simHashC third 16 bits
     * @param simHashD lowest 16 bits
     * @return the SimHash as a 64-character '0'/'1' string
     * @throws NullPointerException if any part is null
     */
    public static String toSimHash(Short simHashA, Short simHashB, Short simHashC, Short simHashD) {
        return toSimHash(simHashA) + toSimHash(simHashB) + toSimHash(simHashC) + toSimHash(simHashD);
    }

    /**
     * Converts one {@code Short} part into a 16-character '0'/'1' string.
     *
     * @param simHashX the part to convert
     * @return its 16 bits, zero-padded on the left
     * @throws NullPointerException if {@code simHashX} is null
     */
    public static String toSimHash(Short simHashX) {
        // Setting bit 16 guarantees exactly 17 binary digits; dropping the leading '1' leaves the
        // low 16 bits zero-padded on the left.
        return Integer.toBinaryString((simHashX & 0xFFFF) | 0x10000).substring(1);
    }

    /**
     * Computes the Hamming distance between a SimHash string (X) and a SimHash given as four
     * {@code Short} parts (Y).
     *
     * @param simHash  SimHash X as a 64-character '0'/'1' string
     * @param simHashA highest 16 bits of Y
     * @param simHashB second 16 bits of Y
     * @param simHashC third 16 bits of Y
     * @param simHashD lowest 16 bits of Y
     * @return the number of differing bits (0-64), or -1 if any argument is null; callers must
     *     not treat -1 as "similar"
     */
    public static int hammingDistance(String simHash, Short simHashA, Short simHashB, Short simHashC, Short simHashD) {
        if (simHash == null || simHashA == null || simHashB == null || simHashC == null || simHashD == null) {
            return -1;
        }
        Short[] otherParts = {simHashA, simHashB, simHashC, simHashD};
        int distance = 0;
        for (int part = 0; part < PART_COUNT; part++) {
            distance += hammingDistance(toShort(simHash, part), otherParts[part]);
        }
        return distance;
    }

    /**
     * Computes the Hamming distance between two {@code Short} SimHash parts.
     *
     * @param simHashX first part
     * @param simHashY second part
     * @return the number of differing bits (0-16), or -1 if either argument is null
     */
    public static int hammingDistance(Short simHashX, Short simHashY) {
        if (simHashX == null || simHashY == null) {
            return -1;
        }
        // Mask to 16 bits: both shorts are sign-extended when promoted to int.
        return Integer.bitCount((simHashX ^ simHashY) & 0xFFFF);
    }

}

 

ES索引中需要新增4個SimHash相關的字段,每個字段保存64位SimHash中的16位(以有符號short存儲):

      "simHashA": {
        "type": "short"
      },
      "simHashB": {
        "type": "short"
      },
      "simHashC": {
        "type": "short"
      },
      "simHashD": {
        "type": "short"
      }

 

最後是ES查詢邏輯:根據傳入的SimHash,先使用ES找到至少有一組simHashX相等的文檔,然後在Java代碼中計算完整64位的漢明距離,保留距離不超過3的文檔。由鴿巢原理可知,漢明距離不超過3的兩個SimHash最多只有3段不同,4段中至少有一段完全相同,因此第一步的ES查詢不會漏掉符合條件的文章。

    /**
     * Finds articles similar to the given SimHash, supplied as four 16-bit parts.
     *
     * @param indexNames comma-separated names of the indices to query
     * @param simHashA   highest 16 bits of the SimHash; {@code null} yields an empty result
     * @param simHashB   second 16 bits
     * @param simHashC   third 16 bits
     * @param simHashD   lowest 16 bits
     * @return RowKeys ({@code hId} values) of the similar articles; empty if {@code simHashA}
     *     is null
     */
    public List<String> searchBySimHash(String indexNames, Short simHashA, short simHashB, short simHashC, short simHashD) {
        if (simHashA == null) {
            // simHashA is the only boxed part; converting a null part would throw
            // NullPointerException. Match the String overload, which returns an empty list for a
            // null SimHash.
            return new ArrayList<>();
        }
        String simHash = SimHashService.toSimHash(simHashA, simHashB, simHashC, simHashD);
        return searchBySimHash(indexNames, simHash);
    }

    /**
     * Finds articles whose SimHash is within Hamming distance 3 of the given SimHash.
     *
     * <p>Elasticsearch is asked for every document that shares at least one of the four 16-bit
     * parts ({@code simHashA}..{@code simHashD}) with the input. Two 64-bit hashes that differ in
     * at most 3 bits can differ in at most 3 of the 4 parts, so by the pigeonhole principle this
     * candidate set contains every match. The exact 64-bit distance is then checked here. For the
     * same reason the threshold cannot be raised above 3 without changing the query.
     *
     * <p>Candidates are read through a scroll, whose server-side context is cleared before the
     * method returns. The query excludes nothing, so a document indexed with exactly this SimHash
     * (for example the article being compared) is returned as well.
     *
     * <p>Best-effort: on an {@link IOException} the stack trace is printed and the ids collected
     * so far are returned.
     *
     * @param indexNames comma-separated names of the indices to query
     * @param simHash    SimHash to search for, as a 64-character '0'/'1' string
     * @return RowKeys ({@code hId} values) of the similar articles; empty if {@code simHash} is
     *     null or not a 64-character binary string
     */
    public List<String> searchBySimHash(String indexNames, String simHash) {
        List<String> resultList = new ArrayList<>();
        if (!isBinarySimHash(simHash)) {
            return resultList;
        }
        String scrollId = null;
        try {
            SearchResponse response = client.search(buildSimHashCandidateRequest(indexNames, simHash), RequestOptions.DEFAULT);
            while (response != null) {
                if (response.getScrollId() != null) {
                    // Track the latest scroll id so the context can be released in finally.
                    scrollId = response.getScrollId();
                }
                SearchHit[] hits = response.getHits().getHits();
                if (hits.length == 0) {
                    break;
                }
                // Candidates can number in the tens of thousands; usually only a few survive.
                for (SearchHit hit : hits) {
                    addIfSimilar(hit, simHash, resultList);
                }
                if (response.getScrollId() == null) {
                    break;
                }
                SearchScrollRequest scrollRequest = new SearchScrollRequest(scrollId).scroll(TimeValue.timeValueMinutes(10));
                response = client.scroll(scrollRequest, RequestOptions.DEFAULT);
            }
        } catch (IOException e) {
            // Best-effort: keep returning whatever was collected before the failure.
            e.printStackTrace();
        } finally {
            clearSimHashScroll(scrollId);
        }
        return resultList;
    }

    /** Returns true if {@code simHash} is exactly 64 characters, each '0' or '1'. */
    private static boolean isBinarySimHash(String simHash) {
        if (simHash == null || simHash.length() != 64) {
            return false;
        }
        for (int i = 0; i < simHash.length(); i++) {
            char c = simHash.charAt(i);
            if (c != '0' && c != '1') {
                return false;
            }
        }
        return true;
    }

    /** Builds the initial scroll search: any document sharing at least one 16-bit SimHash part. */
    private static SearchRequest buildSimHashCandidateRequest(String indexNames, String simHash) {
        // A bool query with only "should" clauses matches documents satisfying at least one.
        BoolQueryBuilder bqBuilder = QueryBuilders.boolQuery();
        bqBuilder.should(QueryBuilders.termQuery("simHashA", SimHashService.toShort(simHash, 0)));
        bqBuilder.should(QueryBuilders.termQuery("simHashB", SimHashService.toShort(simHash, 1)));
        bqBuilder.should(QueryBuilders.termQuery("simHashC", SimHashService.toShort(simHash, 2)));
        bqBuilder.should(QueryBuilders.termQuery("simHashD", SimHashService.toShort(simHash, 3)));
        SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
        sourceBuilder.query(bqBuilder);
        sourceBuilder.from(0);
        sourceBuilder.size(10000);
        sourceBuilder.timeout(TimeValue.timeValueSeconds(60));
        // Only the id and the four hash parts are needed for the distance check.
        sourceBuilder.fetchSource(new String[]{"hId", "simHashA", "simHashB", "simHashC", "simHashD"}, new String[]{});
        sourceBuilder.sort("publishDate", SortOrder.DESC);
        SearchRequest request = new SearchRequest(indexNames.split(","));
        request.source(sourceBuilder);
        request.scroll(TimeValue.timeValueSeconds(60));
        return request;
    }

    /** Adds the hit's {@code hId} to {@code resultList} if its SimHash is within distance 3. */
    private static void addIfSimilar(SearchHit hit, String simHash, List<String> resultList) {
        Map<String, Object> sourceAsMap = hit.getSourceAsMap();
        Object hId = sourceAsMap.get("hId");
        Short simHashA = readSimHashPart(sourceAsMap, "simHashA");
        Short simHashB = readSimHashPart(sourceAsMap, "simHashB");
        Short simHashC = readSimHashPart(sourceAsMap, "simHashC");
        Short simHashD = readSimHashPart(sourceAsMap, "simHashD");
        if (hId == null) {
            return;
        }
        int hammingDistance = SimHashService.hammingDistance(simHash, simHashA, simHashB, simHashC, simHashD);
        // hammingDistance returns -1 when a part is missing; that must not count as a match.
        if (hammingDistance >= 0 && hammingDistance <= 3) {
            resultList.add(hId.toString());
        }
    }

    /** Reads one stored SimHash part from a hit's source, or {@code null} if the field is absent. */
    private static Short readSimHashPart(Map<String, Object> sourceAsMap, String field) {
        Object value = sourceAsMap.get(field);
        return value == null ? null : Short.valueOf(Short.parseShort(value.toString()));
    }

    /** Releases the server-side scroll context; failures are only printed, never propagated. */
    private void clearSimHashScroll(String scrollId) {
        if (scrollId == null) {
            return;
        }
        org.elasticsearch.action.search.ClearScrollRequest clearScrollRequest =
                new org.elasticsearch.action.search.ClearScrollRequest();
        clearScrollRequest.addScrollId(scrollId);
        try {
            client.clearScroll(clearScrollRequest, RequestOptions.DEFAULT);
        } catch (IOException | RuntimeException e) {
            // Runs in a finally block: an already-expired scroll must not replace the result.
            e.printStackTrace();
        }
    }

 

目前在ES單節點保存90萬條數據(其中10萬含有SimHash字段)的查詢延遲大約在0.2秒左右。

總之把我的思路分享給大家,代碼可能寫得比較爛,還請大家指點。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM