最近工作中要求實現相似文本查詢的功能,我於是決定用SimHash實現。
常規思路通常分為以下四步:
1、實現SimHash算法。
2、保存文章時,同時保存SimHash為倒排索引。
3、入庫時或使用定時任務,在倒排索引中找到碰撞的SimHash,保存為結果表。
4、需要查詢一篇文章的相似文章時,根據文章ID,查詢結果表,找到相似文章。
不過這裡有個小問題:如果一篇多次入庫的文章的SimHash發生變化,或者文章被刪除,結果表往往很難及時更新。
同時ES剛好很擅長查詢與維護倒排索引,所以我想能不能直接交給ES幫我維護SimHash的倒排索引,從而跳過使用結果表呢?
那麼以上邏輯可以簡化為3步:
1、實現SimHash算法。
2、保存文章時,同時在ES中保存SimHash字段(和正文其它字段一起)。
3、需要查詢一篇文章的相似文章時,根據文章ID查到SimHash值,再去ES查詢匹配的其它文章ID,不過這裡需要在服務層做一次漢明距離的過濾。
說幹就幹,以下是我的實現代碼,基於網上已有的算法進行了一些修改,算是給大家拋磚引玉,如果有做得不好的地方還請大家指出。
首先添加依賴,使用HanLP分詞,Jsoup提供正文HTML標簽去除服務。
<dependency> <groupId>com.hankcs</groupId> <artifactId>hanlp</artifactId> <version>portable-1.8.1</version> </dependency> <dependency> <groupId>org.jsoup</groupId> <artifactId>jsoup</artifactId> <version>1.13.1</version> </dependency>
接下來是SimHash的核心類。這裡直接寫死為64位SimHash,判重閾值為3(即漢明距離不超過3視為相似):
package com.springboot.text;

import com.hankcs.hanlp.HanLP;
import com.hankcs.hanlp.dictionary.stopword.CoreStopWordDictionary;
import com.hankcs.hanlp.seg.common.Term;
import com.springboot.commonUtil.StringUtils;

import java.math.BigInteger;
import java.util.List;

/**
 * SimHash computation helpers.
 *
 * <p>A SimHash is represented either as a 64-character string of '0'/'1' (most significant
 * bit first) or as four 16-bit {@code Short} segments (A = highest bits, D = lowest bits).
 * The segment form is what gets stored in the search index.
 */
public class SimHashService {

    public static final BigInteger BIGINT_0 = BigInteger.valueOf(0);
    public static final BigInteger BIGINT_1 = BigInteger.valueOf(1);
    public static final BigInteger BIGINT_2 = BigInteger.valueOf(2);
    public static final BigInteger BIGINT_1000003 = BigInteger.valueOf(1000003);
    public static final BigInteger BIGINT_2E64M1 = BIGINT_2.pow(64).subtract(BIGINT_1);

    /** Total number of bits in a SimHash. */
    private static final int HASH_BITS = 64;
    /** Number of bits in one Short segment. */
    private static final int SEGMENT_BITS = 16;
    /** Number of Short segments that make up one SimHash. */
    private static final int SEGMENT_COUNT = HASH_BITS / SEGMENT_BITS;
    /** Mask selecting the low 16 bits of an int. */
    private static final int SEGMENT_MASK = 0xFFFF;

    /**
     * Computes the SimHash of a piece of text.
     *
     * <p>WARNING: changing this method, or changing the HanLP segmentation result (for example
     * by adding stop words), changes the SimHash values that are produced.
     *
     * @param text the text to hash; may contain HTML, which is stripped first
     * @return the SimHash as a 64-character '0'/'1' string, or {@code null} when {@code text}
     *         is {@code null} or too short to produce a meaningful hash
     */
    public static String get(String text) {
        if (text == null) {
            return null;
        }
        // Per the original author's note, removeHtml is equivalent to Jsoup.parse(text).text().
        text = StringUtils.removeHtml(text);

        int sumWeight = 0;
        int maxWeight = 0;
        int[] bits = new int[HASH_BITS];
        List<Term> termList = HanLP.segment(text);
        for (Term term : termList) {
            String word = term.word;
            String nature = term.nature.toString();
            if (nature.startsWith("w") || CoreStopWordDictionary.contains(word)) {
                // Skip punctuation (natures starting with "w") and stop words.
                continue;
            }
            int wordWeight = getWordWeight(word);
            if (wordWeight == 0) {
                continue;
            }
            BigInteger wordHash = getWordHash(word);
            sumWeight += wordWeight;
            maxWeight = Math.max(maxWeight, wordWeight);
            // Vote per bit position: a 1 bit in the word hash adds the word's weight,
            // a 0 bit subtracts it. Index 0 corresponds to the most significant bit.
            for (int i = 0; i < HASH_BITS; i++) {
                if (wordHash.testBit(HASH_BITS - 1 - i)) {
                    bits[i] += wordWeight;
                } else {
                    bits[i] -= wordWeight;
                }
            }
        }
        if (3 * maxWeight >= sumWeight || sumWeight < 20) {
            // The text is too short to hash well; refuse to return a result, otherwise too many
            // documents would collide and slow down queries. For now at least three
            // heavily-weighted words are required.
            return null;
        }
        // Collapse the per-bit votes into a '0'/'1' string.
        StringBuilder simHashBuilder = new StringBuilder(HASH_BITS);
        for (int vote : bits) {
            simHashBuilder.append(vote > 0 ? '1' : '0');
        }
        return simHashBuilder.toString();
    }

    /**
     * Computes the 64-bit hash of a single word.
     *
     * <p>WARNING: changing this method changes the SimHash values that are produced.
     *
     * @param word the word to hash
     * @return a non-negative hash below 2^64; zero for a blank word
     */
    private static BigInteger getWordHash(String word) {
        if (StringUtils.isBlank(word)) {
            return BIGINT_0;
        }
        char[] sourceArray = word.toCharArray();
        // Tuning by the original author found a left shift of about 11-12 bits to work best:
        // when most words are two-character Chinese words it avoids an obvious bias in the
        // high bits, while a much larger shift would make the low bits depend only on the
        // last character of the word.
        BigInteger hash = BigInteger.valueOf(((long) sourceArray[0]) << 12);
        for (char ch : sourceArray) {
            BigInteger chInt = BigInteger.valueOf(ch);
            hash = hash.multiply(BIGINT_1000003).xor(chInt).and(BIGINT_2E64M1);
        }
        hash = hash.xor(BigInteger.valueOf(word.length()));
        return hash;
    }

    /**
     * Returns the weight of a single word.
     *
     * <p>WARNING: changing this method changes the SimHash values that are produced.
     *
     * @param word the word to weigh
     * @return the weight; 0 for a blank word
     */
    private static int getWordWeight(String word) {
        if (StringUtils.isBlank(word)) {
            return 0;
        }
        int length = word.length();
        if (length == 1) {
            // A single-character word only fills about 40 hash bits, so its weight must stay
            // very low; otherwise the high bits of the SimHash tend to end up all zero.
            return 1;
        }
        // 0x3040 is the start of the Hiragana block; words starting at or above it
        // (presumably CJK text) are weighted four times heavier than other words.
        boolean wideScript = word.charAt(0) >= 0x3040;
        if (wideScript) {
            return length == 2 ? 8 : 16;
        }
        return length == 2 ? 2 : 4;
    }

    /**
     * Extracts one 16-bit segment of a SimHash string as a {@code Short}.
     *
     * @param simHash the SimHash as a 64-character '0'/'1' string
     * @param part    the segment index, 0 (highest bits) to 3 (lowest bits)
     * @return the segment value, or {@code null} when {@code simHash} is {@code null}, is too
     *         short to contain the segment, or {@code part} is out of range
     * @throws NumberFormatException if the segment contains characters other than '0'/'1'
     */
    public static Short toShort(String simHash, int part) {
        if (simHash == null || part < 0 || part >= SEGMENT_COUNT) {
            return null;
        }
        int startBit = part * SEGMENT_BITS;
        int endBit = startBit + SEGMENT_BITS;
        if (simHash.length() < endBit) {
            return null;
        }
        // Parse as int first: values with the top bit set do not fit a signed short parse.
        return (short) Integer.parseInt(simHash.substring(startBit, endBit), 2);
    }

    /**
     * Joins four {@code Short} segments into a SimHash string.
     *
     * @param simHashA segment A, the highest bits
     * @param simHashB segment B
     * @param simHashC segment C
     * @param simHashD segment D, the lowest bits
     * @return the SimHash as a 64-character '0'/'1' string, or {@code null} when any segment
     *         is {@code null}
     */
    public static String toSimHash(Short simHashA, Short simHashB, Short simHashC, Short simHashD) {
        if (simHashA == null || simHashB == null || simHashC == null || simHashD == null) {
            return null;
        }
        return toSimHash(simHashA) + toSimHash(simHashB) + toSimHash(simHashC) + toSimHash(simHashD);
    }

    /**
     * Converts one {@code Short} segment into its 16-character '0'/'1' form.
     *
     * @param simHashX the segment to convert
     * @return the zero-padded 16-character binary string, or {@code null} when
     *         {@code simHashX} is {@code null}
     */
    public static String toSimHash(Short simHashX) {
        if (simHashX == null) {
            return null;
        }
        // Setting bit 16 forces a 17-character result; dropping the leading '1' leaves the
        // low 16 bits with their leading zeros intact.
        int withSentinel = (simHashX & SEGMENT_MASK) | (1 << SEGMENT_BITS);
        return Integer.toBinaryString(withSentinel).substring(1);
    }

    /**
     * Computes the Hamming distance between a SimHash string and a SimHash given as four
     * {@code Short} segments.
     *
     * @param simHash  SimHash X as a 64-character '0'/'1' string
     * @param simHashA SimHash Y segment A, the highest bits
     * @param simHashB SimHash Y segment B
     * @param simHashC SimHash Y segment C
     * @param simHashD SimHash Y segment D, the lowest bits
     * @return the Hamming distance (0-64), or -1 when any argument is {@code null} or
     *         {@code simHash} is too short
     */
    public static int hammingDistance(String simHash, Short simHashA, Short simHashB, Short simHashC, Short simHashD) {
        if (simHash == null || simHashA == null || simHashB == null || simHashC == null || simHashD == null) {
            return -1;
        }
        Short[] segments = {simHashA, simHashB, simHashC, simHashD};
        int total = 0;
        for (int part = 0; part < segments.length; part++) {
            int partial = hammingDistance(toShort(simHash, part), segments[part]);
            if (partial < 0) {
                // Never fold the "invalid" sentinel into the sum.
                return -1;
            }
            total += partial;
        }
        return total;
    }

    /**
     * Computes the Hamming distance between two {@code Short} segments.
     *
     * @param simHashX the first segment
     * @param simHashY the second segment
     * @return the number of differing bits (0-16), or -1 when either argument is {@code null}
     */
    public static int hammingDistance(Short simHashX, Short simHashY) {
        if (simHashX == null || simHashY == null) {
            return -1;
        }
        // Mask to 16 bits so sign extension of negative shorts does not add spurious bits.
        return Integer.bitCount((simHashX ^ simHashY) & SEGMENT_MASK);
    }
}
ES索引中需要新增4個SimHash相關的字段:
"simHashA": { "type": "short" }, "simHashB": { "type": "short" }, "simHashC": { "type": "short" }, "simHashD": { "type": "short" }
最後是ES查詢邏輯:根據傳入的SimHash,先使用ES找到至少一組SimHash相等的文檔(64位被分成4組,而閾值為3,根據抽屜原理,漢明距離不超過3的兩個SimHash必然至少有一組完全相同),然後在Java代碼中計算完整的漢明距離,過濾掉不滿足閾值的文檔。
/** * 根據SimHash,查詢相似的文章。 * * @param indexNames 需要查詢的索引名稱(允許多個) * @param simHashA simHashA的值 * @param simHashB simHashB的值 * @param simHashC simHashC的值 * @param simHashD simHashD的值 * @return 返回相似文章RowKey列表。 */ public List<String> searchBySimHash(String indexNames, Short simHashA, short simHashB, short simHashC, short simHashD) { String simHash = SimHashService.toSimHash(simHashA, simHashB, simHashC, simHashD); return searchBySimHash(indexNames, simHash); } /** * 根據SimHash,查詢相似的文章。 * * @param indexNames 需要查詢的索引名稱(允許多個) * @param simHash 需要查詢的SimHash (格式:64位二進制字符串) * @return 返回相似文章RowKey列表。 */ public List<String> searchBySimHash(String indexNames, String simHash) { List<String> resultList = new ArrayList<>(); if (simHash == null) { return resultList; } try { String scrollId = ""; while (true) { if (scrollId == null) { break; } SearchResponse response = null; if (scrollId.isEmpty()) { // 首次請求,正常查詢 SearchRequest request = new SearchRequest(indexNames.split(",")); BoolQueryBuilder bqBuilder = QueryBuilders.boolQuery(); bqBuilder.should(QueryBuilders.termQuery("simHashA", SimHashService.toShort(simHash, 0))); bqBuilder.should(QueryBuilders.termQuery("simHashB", SimHashService.toShort(simHash, 1))); bqBuilder.should(QueryBuilders.termQuery("simHashC", SimHashService.toShort(simHash, 2))); bqBuilder.should(QueryBuilders.termQuery("simHashD", SimHashService.toShort(simHash, 3))); SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().size(10000); sourceBuilder.query(bqBuilder); sourceBuilder.from(0); sourceBuilder.size(10000); sourceBuilder.timeout(TimeValue.timeValueSeconds(60)); sourceBuilder.fetchSource(new String[]{"hId", "simHashA", "simHashB", "simHashC", "simHashD"}, new String[]{}); sourceBuilder.sort("publishDate", SortOrder.DESC); request.source(sourceBuilder); request.scroll(TimeValue.timeValueSeconds(60)); response = client.search(request, RequestOptions.DEFAULT); } else { // 之后請求,走游標查詢 SearchScrollRequest searchScrollRequest = new 
SearchScrollRequest(scrollId).scroll(TimeValue.timeValueMinutes(10)); response = client.scroll(searchScrollRequest, RequestOptions.DEFAULT); } if (response != null && response.getHits().getHits().length > 0) { // 查到的記錄必然有一組simHashX與輸入相同,但需要合並確認總數是否小於閾值 // 很可能有幾萬的命中,但最終過濾完只剩下幾條數據,最終留下ID for (SearchHit hit : response.getHits().getHits()) { Map<String, Object> sourceAsMap = hit.getSourceAsMap(); String hId = String.valueOf(sourceAsMap.get("hId")); Short simHashA = Short.parseShort(sourceAsMap.get("simHashA").toString()); Short simHashB = Short.parseShort(sourceAsMap.get("simHashB").toString()); Short simHashC = Short.parseShort(sourceAsMap.get("simHashC").toString()); Short simHashD = Short.parseShort(sourceAsMap.get("simHashD").toString()); int hammingDistance = SimHashService.hammingDistance(simHash, simHashA, simHashB, simHashC, simHashD); if (hammingDistance < 4) { System.out.println(hammingDistance + "\t" + hId); resultList.add(sourceAsMap.get("hId").toString()); } } scrollId = response.getScrollId(); } else { break; } } } catch (IOException e) { e.printStackTrace(); } return resultList; }
目前在ES單節點保存90萬條數據(其中10萬條含有SimHash字段)的情況下,查詢延遲約為0.2秒。
總之我把我的思路分享給大家,可能我代碼寫的比較爛,還請大家指點。