python利用Trie(前綴樹)實現搜索引擎中關鍵字輸入提示(學習Hash Trie和Double-array Trie)
主要包括兩部分內容:
(1)利用python中的dict實現Trie;
(2)按照darts-java的方法做python的實現Double-array Trie
比較:
(1)的實現相對簡單,但在詞典較大時,時間復雜度較高
(2)Double-array Trie是Trie高效實現,時間復雜度達到O(n),但是實現相對較難
最近遇到一個問題,希望對地名檢索時,根據用戶的輸入,實時推薦用戶可能檢索的候選地名,並根據實時熱度進行排序。這可以以歸納為一個Trie(前綴樹)問題。
Trie在自然語言處理中非常常用,可以實現文本的快速分詞、詞頻統計、字符串查詢和模糊匹配、字符串排序、關鍵輸入提示、關鍵字糾錯等場景中。
這些問題都可以在單詞樹/前綴樹/Trie來解決,關於Trie的介紹看【小白詳解 Trie 樹】這篇文章就夠了
一、Hash實現Trie(python中的dict)
github上有Trie實現關鍵字,實現Trie樹的新增、刪除、查找,並根據熱度CACHED_THREHOLD在node節點對后綴進行緩存,以便提高對高頻詞的檢索效率。本人在其代碼上做了注解。
並對其進行了測試,測試的數據包括了兩列,包括關鍵詞和頻次。
【code】
#!/usr/bin/env python # encoding: utf-8 """ @date: 20131001 @version: 0.2 @author: wklken@yeah.net @desc: 搜索下拉提示,基於后台提供數據,建立數據結構(前綴樹),用戶輸入query前綴時,可以提示對應query前綴補全 @update: 20131001 基本結構,新增,搜索等基本功能 20131005 增加緩存功能,當緩存打開,用戶搜索某個前綴超過一定次數時,進行緩存,減少搜索時間 20140309 修改代碼,降低內存占用 @TODO: test case 加入拼音的話,導致內存占用翻倍增長,要考慮下如何優化節點,共用內存 """ #這是實現cache的一種方式,也可以使用redis/memcached在外部做緩存 #https://github.com/wklken/suggestion/blob/master/easymap/suggest.py #一旦打開,search時會對每個節點做cache,當增加刪除節點時,其路徑上的cache會被清除,搜索時間降低了一個數量級 #代價:內存消耗, 不需要時可以關閉,或者通過CACHED_THREHOLD調整緩存數量 #開啟 #CACHED = True #關閉 CACHED = False #注意,CACHED_SIZE >= search中的limit,保證search從緩存能獲取到足夠多的結果 CACHED_SIZE = 10 #被搜索超過多少次后才加入緩存 CACHED_THREHOLD = 10 ############### start ###################### class Node(dict): def __init__(self, key, is_leaf=False, weight=0, kwargs=None): """ @param key: 節點字符 @param is_leaf: 是否葉子節點 @param weight: 節點權重, 某個詞最后一個字節點代表其權重,其余中間節點權重為0,無意義 @param kwargs: 可傳入其他任意參數,用於某些特殊用途 """ self.key = key self.is_leaf = is_leaf self.weight = weight #緩存,存的是node指針 self.cache = [] #節點前綴搜索次數,可以用於搜索query數據分析 self.search_count = 0 #其他節點無關僅和內容相關的參數 if kwargs: for key, value in kwargs.iteritems(): setattr(self, key, value) def __str__(self): return '<Node key:%s is_leaf:%s weight:%s Subnodes: %s>' % (self.key, self.is_leaf, self.weight, self.items()) def add_subnode(self, node): """ 添加子節點 :param node: 子節點對象 """ self.update({node.key: node}) def get_subnode(self, key): """ 獲取子節點 :param key: 子節點key :return: Node對象 """ return self.get(key) def has_subnode(self): """ 判斷是否存在子節點 :return: bool """ return len(self) > 0 def get_top_node(self, prefix): """ 獲取一個前綴的最后一個節點(補全所有后綴的頂部節點) :param prefix: 字符轉前綴 :return: Node對象 """ top = self for k in prefix: top = top.get_subnode(k) if top is None: return None return top def depth_walk(node): """ 遞歸,深度優先遍歷一個節點,返回每個節點所代表的key以及所有關鍵字節點(葉節點) @param node: Node對象 """ result = [] if node.is_leaf: #result.append(('', node)) if len(node) >0:#修改,避免該前綴剛好是關鍵字時搜索不到 result.append((node.key[:-1], node)) node.is_leaf=False depth_walk(node) else: return [('', node)] if node.has_subnode(): for k in node.iterkeys(): s = depth_walk(node.get(k)) #print k , s[0][0] result.extend([(k + subkey, snode) for subkey, snode in s]) return result #else: #print node.key #return [('', node)] def search(node, prefix, limit=None, is_case_sensitive=False): """ 搜索一個前綴下的所有單詞列表 遞歸 @param node: 根節點 @param prefix: 前綴 @param limit: 返回提示的數量 @param is_case_sensitive: 是否大小寫敏感 @return: [(key, node)], 包含提示關鍵字和對應葉子節點的元組列表 """ if not is_case_sensitive: prefix = prefix.lower() node = node.get_top_node(prefix) #print 'len(node):' ,len(node) #如果找不到前綴節點,代表匹配失敗,返回空 if node is None: return [] #搜索次數遞增 node.search_count += 1 if CACHED and node.cache: return node.cache[:limit] if limit is not None else node.cache #print depth_walk(node) result = [(prefix + subkey, pnode) for subkey, pnode in depth_walk(node)] result.sort(key=lambda x: x[1].weight, reverse=True) if CACHED and node.search_count >= CACHED_THREHOLD: node.cache = result[:CACHED_SIZE] #print len(result) return result[:limit] if limit is not None else result #TODO: 做成可以傳遞任意參數的,不需要每次都改 2013-10-13 done def add(node, keyword, weight=0, **kwargs): """ 加入一個單詞到樹 @param node: 根節點 @param keyword: 關鍵詞,前綴 @param weight: 權重 @param kwargs: 其他任意存儲屬性 """ one_node = node index = 0 last_index = len(keyword) - 1 for c in keyword: if c not in one_node: if index != last_index: one_node.add_subnode(Node(c, weight=weight)) else: one_node.add_subnode(Node(c, is_leaf=True, weight=weight, kwargs=kwargs)) one_node = one_node.get_subnode(c) else: one_node = one_node.get_subnode(c) if CACHED: one_node.cache = [] if index == last_index: one_node.is_leaf = True one_node.weight = weight for key, value in kwargs: setattr(one_node, key, value) index += 1 def delete(node, keyword, judge_leaf=False): """ 從樹中刪除一個單詞 @param node: 根節點 @param keyword: 關鍵詞,前綴 @param judge_leaf: 是否判定葉節點,遞歸用,外部調用使用默認值 """ # 空關鍵詞,傳入參數有問題,或者遞歸調用到了根節點,直接返回 if not keyword: return top_node = node.get_top_node(keyword) if top_node is None: return #清理緩存 if CACHED: top_node.cache = [] #遞歸往上,遇到節點是某個關鍵詞節點時,要退出 if judge_leaf: if top_node.is_leaf: return #非遞歸,調用delete else: if not top_node.is_leaf: return if top_node.has_subnode(): #存在子節點,去除其標志 done top_node.is_leaf = False return else: #不存在子節點,逐層檢查刪除節點 this_node = top_node prefix = keyword[:-1] top_node = node.get_top_node(prefix) del top_node[this_node.key] delete(node, prefix, judge_leaf=True) ############################## # 增補功能 讀數據文件建立樹 # ############################## def build(file_path, is_case_sensitive=False): """ 從文件構建數據結構, 文件必須utf-8編碼,可變更 @param file_path: 數據文件路徑,數據文件默認兩列,格式“關鍵詞\t權重" @param is_case_sensitive: 是否大小寫敏感 """ node = Node("") f = open(file_path) for line in f: line = line.strip() if not isinstance(line,unicode): line = line.decode('utf-8') parts = line.split('\t') name = parts[0] if not is_case_sensitive: name = name.lower() add(node, name, int(parts[1])) f.close() return node import time if __name__ == '__main__': #print '============ test1 ===============' #n = Node("") #default weight=0, 后面的參數可以任意加,搜索返回結果再從node中將放入對應的值取出,這里放入一個othervalue值 #add(n, u'he', othervalue="v-he") #add(n, u'her', weight=0, othervalue="v-her") #add(n, u'hero', weight=10, othervalue="v-hero") #add(n, u'hera', weight=3, othervalue="v-hera") #delete(n, u'hero') #print "search h: " #for key, node in search(n, u'h'): #print key, node, node.othervalue, id(node) #print key, node.weight #print "serch her: " #for key, node in search(n, u'her'): #print key, node, node.othervalue, id(node) #print key, node.weight start= time.clock() print '============ test2 ===============' tree = build("./shanxinpoi.txt", is_case_sensitive=False) print len(tree),'time:',time.clock()-start startline=time.clock() print u'search 秦嶺' for key, node in search(tree, u'秦嶺', limit=10): print key, node.weight print time.clock()-startline
二、Trie的Double-array Trie實現
Trie的Double-array Trie的實現參考【小白詳解 Trie 樹】和【雙數組Trie樹(DoubleArrayTrie)Java實現】
在看代碼之前提醒幾點:
(1)Comero有根據komiya-atsushi/darts-java,進行了Double-array Trie的python實現,komiya-atsushi的實現巧妙使用了文字的的編碼,以文字的編碼(一個漢字三個字符,每個字符0-256)作為【小白詳解 Trie 樹】中的字符編碼。
(2)代碼中不需要構造真正的Trie樹,直接用字符串,構造對應node,因為words是排過序的,這樣避免Trie樹在構建過程中頻繁從根節點開始重構
(3)實現中使用了了base[s]+c=t & check[t]=base[s],而非【小白詳解 Trie 樹】中的base[s]+c=t & check[t]=s
(4)komiya-atsushi實現Trie的構建、從詞典文件創建,以及對構建Trie的本地化(保存base和check,下次打開不用再重新構建)
(5)本文就改了Comero中的bug,並對代碼進行了注解。並參照dingyaguang117/DoubleArrayTrie(java)中的代碼實現了輸入提示FindAllWords方法。
(6)本文實現的FindAllWords輸入提示方法沒有用到詞頻信息,但是實現也不難
【code】
# -*- coding:utf-8 -*- # base # https://linux.thai.net/~thep/datrie/datrie.html # http://jorbe.sinaapp.com/2014/05/11/datrie/ # http://www.hankcs.com/program/java/%E5%8F%8C%E6%95%B0%E7%BB%84trie%E6%A0%91doublearraytriejava%E5%AE%9E%E7%8E%B0.html # (komiya-atsushi/darts-java | 先建立Trie樹,再構造DAT,為siblings先找到合適的空間) # https://blog.csdn.net/kissmile/article/details/47417277 # http://nark.cc/p/?p=1480 #https://github.com/midnight2104/midnight2104.github.io/blob/58b5664b3e16968dd24ac5b1b3f99dc21133b8c4/_posts/2018-8-8-%E5%8F%8C%E6%95%B0%E7%BB%84Trie%E6%A0%91(DoubleArrayTrie).md # 不需要構造真正的Trie樹,直接用字符串,構造對應node,因為words是排過序的 # todo : error info # todo : performance test # todo : resize # warning: code=0表示葉子節點可能會有隱患(正常詞匯的情況下是ok的) # 修正: 由於想要回溯字符串的效果,葉子節點和base不能重合(這樣葉子節點可以繼續記錄其他值比如頻率),葉子節點code: 0->-1 # 但是如此的話,葉子節點可能會與正常節點沖突? 找begin的使用應該是考慮到的? #from __future__ import print_function class DATrie(object): class Node(object): def __init__(self, code, depth, left, right): self.code = code self.depth = depth self.left = left self.right = right def __init__(self): self.MAX_SIZE = 2097152 # 65536 * 32 self.base = [0] * self.MAX_SIZE self.check = [-1] * self.MAX_SIZE # -1 表示空 self.used = [False] * self.MAX_SIZE self.nextCheckPos = 0 # 詳細 見后面->當數組某段使用率達到某個值時記錄下可用點,以便下次不再使用 self.size = 0 # 記錄總共用到的空間 # 需要改變size的時候調用,這里只能用於build之前。cuz沒有打算復制數據. def resize(self, size): self.MAX_SIZE = size self.base = [0] * self.MAX_SIZE self.check = [-1] * self.MAX_SIZE self.used = [False] * self.MAX_SIZE # 先決條件是self.words ordered 且沒有重復 # siblings至少會有一個 def fetch(self, parent): ###獲取parent的孩子,存放在siblings中,並記錄下其左右截至 depth = parent.depth siblings = [] # size == parent.right-parent.left i = parent.left while i < parent.right: #遍歷所有子節點,right-left+1個單詞 s = self.words[i][depth:] #詞的后半部分 if s == '': siblings.append( self.Node(code=-1, depth=depth+1, left=i, right=i+1)) # 葉子節點 else: c = ord(s[0]) #字符串中每個漢字占用3個字符(code,實際也就當成符碼),將每個字符轉為數字 ,樹實際是用這些數字構建的 #print type(s[0]),c if siblings == [] or siblings[-1].code != c: siblings.append( self.Node(code=c, depth=depth+1, left=i, right=i+1)) # 新建節點 else: # siblings[-1].code == c siblings[-1].right += 1 #已經是排過序的可以直接計數+1 i += 1 # siblings return siblings # 在insert之前,認為可以先排序詞匯,對base的分配檢查應該是有利的 # 先構建樹,再構建DAT,再銷毀樹 def build(self, words): words = sorted(list(set(words))) # 去重排序 #for word in words:print word.decode('utf-8') self.words = words # todo: 銷毀_root _root = self.Node(code=0, depth=0, left=0, right=len(self.words)) #增加第一個節點 self.base[0] = 1 siblings = self.fetch(_root) #for ii in words: print ii.decode('utf-8') #print 'siblings len',len(siblings) #for i in siblings: print i.code self.insert(siblings, 0) #插入根節點的第一層孩子 # while False: # 利用隊列來實現非遞歸構造 # pass del self.words print("DATrie builded.") def insert(self, siblings, parent_base_idx): """ parent_base_idx為父節點base index, siblings為其子節點們 """ # 暫時按komiya-atsushi/darts-java的方案 # 總的來講是從0開始分配beigin] #self.used[parent_base_idx] = True begin = 0 pos = max(siblings[0].code + 1, self.nextCheckPos) - 1 #從第一個孩子的字符碼位置開始找,因為排過序,前面的都已經使用 nonzero_num = 0 # 非零統計 first = 0 begin_ok_flag = False # 找合適的begin while not begin_ok_flag: pos += 1 if pos >= self.MAX_SIZE: raise Exception("no room, may be resize it.") if self.check[pos] != -1 or self.used[pos]: # check——check數組,used——占用標記,表明pos位置已經占用 nonzero_num += 1 # 已被使用 continue elif first == 0: self.nextCheckPos = pos # 第一個可以使用的位置,記錄?僅執行一遍 first = 1 begin = pos - siblings[0].code # 第一個孩子節點對應的begin if begin + siblings[-1].code >= self.MAX_SIZE: raise Exception("no room, may be resize it.") if self.used[begin]: #該位置已經占用 continue if len(siblings) == 1: #只有一個節點 begin_ok_flag = True break for sibling in siblings[1:]: if self.check[begin + sibling.code] == -1 and self.used[begin + sibling.code] is False: #對於sibling,begin位置可用 begin_ok_flag = True else: begin_ok_flag = False #用一個不可用,則begin不可用 break # 得到合適的begin # -- Simple heuristics -- # if the percentage of non-empty contents in check between the # index 'next_check_pos' and 'check' is greater than some constant value # (e.g. 0.9), new 'next_check_pos' index is written by 'check'. #從位置 next_check_pos 開始到 pos 間,如果已占用的空間在95%以上,下次插入節點時,直接從 pos 位置處開始查找成功獲得這一層節點的begin之后得到,影響下一次執行insert時的查找效率 if (nonzero_num / (pos - self.nextCheckPos + 1)) >= 0.95: self.nextCheckPos = pos self.used[begin] = True # base[begin] 記錄 parent chr -- 這樣就可以從節點回溯得到字符串 # 想要可以回溯的話,就不能在字符串末尾節點記錄值了,或者給葉子節點找個0以外的值? 0->-1 #self.base[begin] = parent_base_idx #【*】 #print 'begin:',begin,self.base[begin] if self.size < begin + siblings[-1].code + 1: self.size = begin + siblings[-1].code + 1 for sibling in siblings: #更新所有子節點的check base[s]+c=t & check[t]=s self.check[begin + sibling.code] = begin for sibling in siblings: # 由於是遞歸的情況,需要先處理完check # darts-java 還考慮到葉子節點有值的情況,暫時不考慮(需要記錄的話,記錄在葉子節點上) if sibling.code == -1: self.base[begin + sibling.code] = -1 * sibling.left - 1 else: new_sibings = self.fetch(sibling) h = self.insert(new_sibings, begin + sibling.code) #插入孫子節點,begin + sibling.code為子節點的位置 self.base[begin + sibling.code] = h #更新base所有子節點位置的轉移基數為[其孩子最合適的begin] return begin def search(self, word): """ 查找單詞是否存在 """ p = 0 # root if word == '': return False for c in word: c = ord(c) next = abs(self.base[p]) + c # print(c, next, self.base[next], self.check[next]) if next > self.MAX_SIZE: # 一定不存在 return False # print(self.base[self.base[p]]) if self.check[next] != abs(self.base[p]): return False p = next # print('*'*10+'\n', 0, p, self.base[self.base[p]], self.check[self.base[p]]) # 由於code=0,實際上是base[leaf_node->base+leaf_node.code],這個負的值本身沒什么用 # 修正:left code = -1 if self.base[self.base[p] - 1] < 0 and self.base[p] == self.check[self.base[p] - 1] : #print p return True else: # 不是詞尾 return False def common_prefix_search(self, content): """ 公共前綴匹配 """ # 用了 darts-java 寫法,再仔細看一下 result = [] b = self.base[0] # 從root開始 p = 0 n = 0 tmp_str = "" for c in content: c = ord(c) p = b n = self.base[p - 1] # for iden leaf if b == self.check[p - 1] and n < 0: result.append(tmp_str) tmp_str += chr(c) #print(tmp_str ) p = b + c # cur node if b == self.check[p]: b = self.base[p] # next base else: # no next node return result # 判斷最后一個node p = b n = self.base[p - 1] if b == self.check[p - 1] and n < 0: result.append(tmp_str) return result def Find_Last_Base_index(self, word): b = self.base[0] # 從root開始 p = 0 #n = 0 #print len(word) tmp_str = "" for c in word: c = ord(c) p = b p = b + c # cur node, p is new base position, b is the old if b == self.check[p]: tmp_str += chr(c) b = self.base[p] # next base else: # no next node return -1 #print '====', p, self.base[p], tmp_str.decode('utf-8') return p def GetAllChildWord(self,index): result = [] #result.append("") # print self.base[self.base[index]-1],'++++' if self.base[self.base[index]-1] <= 0 and self.base[index] == self.check[self.base[index] - 1]: result.append("") #return result for i in range(0,256): #print(chr(i)) if self.check[self.base[index]+i]==self.base[index]: #print self.base[index],(chr(i)),i for s in self.GetAllChildWord(self.base[index]+i): #print s result.append( chr(i)+s) return result def FindAllWords(self, word): result = [] last_index=self.Find_Last_Base_index(word) if last_index==-1: return result for end in self.GetAllChildWord(last_index): result.append(word+end) return result def get_string(self, chr_id): """ 從某個節點返回整個字符串, todo:改為私有 """ if self.check[chr_id] == -1: raise Exception("不存在該字符。") child = chr_id s = [] while 0 != child: base = self.check[child] print(base, child) label = chr(child - base) s.append(label) print(label) child = self.base[base] return "".join(s[::-1]) def get_use_rate(self): """ 空間使用率 """ return self.size / self.MAX_SIZE if __name__ == '__main__': words = ["一舉","一舉一動",'11', "一舉成名", "一舉成名天下知","洛陽市西工區中州中路","人民東路2號","中州東", "洛陽市","洛陽","洛神1","洛神賦","萬科","萬達3","萬科翡翠","萬達廣場", "洛川","洛川蘋果","商洛","商洛市","商朝","商業","商業模","商業模式", "萬能", "萬能膠"] #for word in words:print [word] #一個漢字的占用3個字符, words=[] for line in open('1000.txt').readlines(): # #print line.strip().decode('utf-8') words.append(line.strip()) datrie = DATrie() datrie.build(words) #for line in open('1000.txt').readlines(): # print(datrie.search(line.strip()),end=' ') #print('-'*10) #print(datrie.search("景華路")) #print('-'*10) #print(datrie.search("景華路號")) # print('-'*10) #for item in datrie.common_prefix_search("商業模式"): print(item.decode('utf-8')) #for item in datrie.common_prefix_search("商業模式"):print item.decode('utf-8') # print(datrie.common_prefix_search("一舉成名天下知")) #print(datrie.base[:1000]) # print('-'*10) # print(datrie.get_string(21520)) #index=datrie.Find_Last_Base_index("商業") #print(index),'-=-=-=' #print datrie.search("商業"),datrie.search("商業"),datrie.search("商業模式") #print index, datrie.check[datrie.base[index]+230],datrie.base[index] for ii in datrie.FindAllWords('中州中路'):print ii.decode('utf-8') #print(datrie.Find_Last_Base_index("一舉")[2].decode('utf-8')) #print()
測試數據是洛陽地址1000.txt
最后歡迎參與討論。
參考:
小白詳解Trie樹:https://segmentfault.com/a/1190000008877595
Hash實現Trie(python中的dict)(源碼):https://github.com/wklken/suggestion/blob/master/easymap/suggest.py
雙數組Trie樹(DoubleArrayTrie)Java實現(主要理解):http://www.hankcs.com/program/java/%E5%8F%8C%E6%95%B0%E7%BB%84trie%E6%A0%91doublearraytriejava%E5%AE%9E%E7%8E%B0.html
Comero對DoubleArrayTrie的python實現(源碼):https://github.com/helmz/toy_algorithms_in_python/blob/master/double_array_trie.py
DoubleArrayTrie樹的Tail壓縮,java實現(源碼):https://github.com/dingyaguang117/DoubleArrayTrie/blob/master/src/DoubleArrayTrie.java#L348