Trie樹主要應用在信息檢索領域,非常高效。今天我們講Double Array Trie,請先把Trie樹忘掉,把信息檢索忘掉,我們來講一個確定有限自動機(deterministic finite automaton ,DFA)的故事。所謂“確定有限自動機”是指給定一個狀態和一個變量時,它能跳轉到的下一個狀態也就確定下來了,同時狀態是有限的。請注意這里出現兩個名詞,一個是“狀態”,一個是“變量”,下文會舉例說明這兩個名詞的含義。
舉個例子,假設我們一共有10個漢字,每個漢字就是一個“變量”。我們為每個漢字編個序號。
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
啊 |
阿 |
埃 |
根 |
膠 |
拉 |
及 |
廷 |
伯 |
人 |
表1. “變量”的編號
這10個漢字一共可以構成6個詞語:啊,埃及,阿膠,阿根廷,阿拉伯,阿拉伯人。
這里的每個詞以及它的任意前綴都是一個“狀態”,“狀態”一共有10個:啊,阿,埃,阿根,阿根廷,阿膠,阿拉,阿拉伯,阿拉伯人,埃及
我們把DFA圖畫出來:
圖1. DFA,同時也是Trie樹
在圖中每個節點代表一個“狀態”,每條邊代表一個“變量”,並且我們把變量的編號也標在了圖中。
下面我們構造兩個int數組:base和check,它們的長度始終是一樣的。數組的長度定多少並沒有嚴格的規定,反正隨着詞語的插入,數組肯定是要擴容的。說到數組擴容,大家可以看一下java中HashMap的擴容策略,每次擴容數組的長度都會變為2的整次冪。HashMap中有這么一個精妙的函數:
//給定一個整數,返回大於等於這個數的2的整次冪 static int tableSizeFor(int cap) { int n = cap - 1; n |= n >>> 1; n |= n >>> 2; n |= n >>> 4; n |= n >>> 8; n |= n >>> 16; return (n < 0) ? 1 : n + 1; }
回到今天的正題,我們不妨把double array的初始長度就定得大一些。兩數組元素初始值均為0。
double array的初始狀態:
下標 |
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
base |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
check |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
state |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
把詞添加到詞典的過程就給base和check數組中各元素賦值的過程。下面我們層次遍歷圖1所示的Trie樹。
step1.
第一層上取到3個“狀態”:啊,阿,埃。把這3個狀態按照其對應的變量的編號(查表1)放到state數組中。
下標 |
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
base |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
check |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
state |
啊 |
阿 |
埃 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
step2.
當存在狀態轉移時,有
check[t]=s base[s]+c=t
其中s和t代表某個狀態在數組中的下標,c代表變量的編號。
此時層次遍歷來到了圖1所示DFA的第二層,我們看到“阿”的子節點有“阿根”、“阿膠”、“阿拉”,已知狀態“阿”的下標是2,變量“根”、“膠”、“拉”的編號依次是4、5、6,下面我們要給base[2]賦值:從小到大遍歷所有的正整數,直到發現某個數正整k滿足base[k+4]=base[k+5]=base[k+6]=check[k+4]=check[k+5]=check[k+6]=0。得到k=1,那么就把1賦給base[2],同時也確定了狀態“阿根”、“阿膠”、“阿拉”的下標依次是k+4、k+5、k+6,即5、6、7,而且check[5]=check[6]=check[7]=2。
同理,“埃”的子節點是“埃及”,狀態“埃”的下標是3,變量“及”的編號是7,此時有check[1+7]=base[1+7]=0,所以base[3]=1,狀態“埃及”的下標是8,check[8]=3。
遍歷完DFA的第二層后得到下表:
下標 |
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
base |
0 |
1 |
1 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
check |
0 |
0 |
0 |
0 |
2 |
2 |
2 |
3 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
state |
啊 |
阿 |
埃 |
|
阿根 |
阿膠 |
阿拉 |
埃及 |
|
|
|
|
|
|
|
|
|
|
|
step3.
重復step2,層次遍歷完整查詢樹之后,得到:
下標 |
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
base |
0 |
1 |
1 |
0 |
1 |
0 |
1 |
0 |
0 |
1 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
check |
0 |
0 |
0 |
0 |
2 |
2 |
2 |
3 |
5 |
7 |
10 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
state |
啊 |
阿 |
埃 |
|
阿根 |
阿膠 |
阿拉 |
埃及 |
阿根廷 |
阿拉伯 |
阿拉伯人 |
|
|
|
|
|
|
|
|
step4.
最后遍歷一次DFA,當某個節點已經是一個詞的結尾時,按下列方法修改其base值。
if(base[i]==0) base[i]=-i else base[i]=-base[i]
得到:
下標 |
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
base |
-1 |
1 |
1 |
0 |
1 |
-6 |
1 |
-8 |
-9 |
-1 |
-11 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
check |
0 |
0 |
0 |
0 |
2 |
2 |
2 |
3 |
5 |
7 |
10 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
state |
啊 |
阿 |
埃 |
|
阿根 |
阿膠 |
阿拉 |
埃及 |
阿根廷 |
阿拉伯 |
阿拉伯人 |
|
|
|
|
|
|
|
|
double array建好之后,如果詞典中又動態地添加了一個新詞,比如“阿拉根”,那么“阿拉”的所有子孫節點在double array中的位置要重新分配。
圖2. 動態添加一個詞
首先,把“阿拉伯”和“阿拉伯人”對應的base、check值清0,把“阿拉伯”和“阿拉伯人”從state數組中刪除掉,把“阿拉”的base值清0。
下標 |
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
base |
-1 |
1 |
1 |
0 |
1 |
-6 |
0 |
-8 |
-9 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
check |
0 |
0 |
0 |
0 |
2 |
2 |
2 |
3 |
5 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
0 |
state |
啊 |
阿 |
埃 |
|
阿根 |
阿膠 |
阿拉 |
埃及 |
阿根廷 |
|
|
|
|
|
|
|
|
|
|
然后,按照上面step2所述的方法把“阿拉伯”、“阿拉根”插入到double array中。變量“根”、“伯”的編號是4和9,滿足base[k+4]=base[k+9]=check[k+4]=check[k+9]=0的最小的k是6,所以base[7]=6,“阿拉伯”和“阿拉根”對應的下標是10和15。同理把“阿拉伯人”插入到double array中。
下標 |
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
base |
-1 |
1 |
1 |
0 |
1 |
-6 |
6 |
-8 |
-9 |
0 |
0 |
0 |
0 |
0 |
1 |
0 |
0 |
0 |
0 |
check |
0 |
0 |
0 |
0 |
2 |
2 |
2 |
3 |
5 |
7 |
15 |
0 |
0 |
0 |
7 |
0 |
0 |
0 |
0 |
state |
啊 |
阿 |
埃 |
|
阿根 |
阿膠 |
阿拉 |
埃及 |
阿根廷 |
阿拉根 |
阿拉伯人 |
|
|
|
阿拉伯 |
|
|
|
|
最后,遍歷圖2所示的DFA,當某個節點已經是一個詞的結尾時按照step4中的方法修改其base值。
下標 |
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
base |
-1 |
1 |
1 |
0 |
1 |
-6 |
6 |
-8 |
-9 |
-10 |
-11 |
0 |
0 |
0 |
-1 |
0 |
0 |
0 |
0 |
check |
0 |
0 |
0 |
0 |
2 |
2 |
2 |
3 |
5 |
7 |
15 |
0 |
0 |
0 |
7 |
0 |
0 |
0 |
0 |
state |
啊 |
阿 |
埃 |
|
阿根 |
阿膠 |
阿拉 |
埃及 |
阿根廷 |
阿拉根 |
阿拉伯人 |
|
|
|
阿拉伯 |
|
|
|
|
double array建好之后,如何查詢一個詞是否在詞典中呢?
比如要查“阿膠及”,每個字的編號是已知的,我們畫出狀態轉移圖。
變量“阿”的編號是2,base[2]=1,變量“膠”的編號是5,base[2]+5=6,我們檢查一下check[6]是否等於2。check[6]確實等於2,則繼續看下一個狀態轉移。同時我們發現base[6]是負數,這說明“阿膠”已經是一個完整的詞了。
繼續看下一個狀態轉移,base[6]=-6,負數取其相反數,base[6]=6,變量“及”的編號是7,base[6]+7=13,我們檢查一下check[13]是否等於6,發現不滿足,則“阿膠及”不是一個詞,甚至都是不是任意一個詞的前綴。
github上一個日本人貢獻了他的java版的Darts(Darts本來是一種Double Array Trie的C++實現),代碼如下:
import java.io.BufferedInputStream; import java.io.BufferedOutputStream; import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.List; /** * DoubleArrayTrie在構建雙數組的過程中也借助於一棵傳統的Trie樹,但這棵Trie樹並沒有被保存下來, * 如果要查找以prefix為前綴的所有詞不適合用DoubleArrayTrie,應該用傳統的Trie樹。 * * @author zhangchaoyang * */ public class DoubleArrayTrie { private final static int BUF_SIZE = 16384;// 2^14,java采用unicode編碼表示所有字符,每個字符固定用兩個字節表示。考慮到每個字節的符號位都是0,所以又可以節省兩個bit private final static int UNIT_SIZE = 8; // size of int + int private static class Node { int code;// 字符的unicode編碼 int depth;// 在Trie樹中的深度 int left;// int right;// }; private int check[]; private int base[]; private boolean used[]; private int size; private int allocSize;// base數組當前的長度 private List<String> key;// 所有的詞 private int keySize; private int length[]; private int value[]; private int progress; private int nextCheckPos; int error_; // 擴充base和check數組 private int resize(int newSize) { int[] base2 = new int[newSize]; int[] check2 = new int[newSize]; boolean used2[] = new boolean[newSize]; if (allocSize > 0) { System.arraycopy(base, 0, base2, 0, allocSize);// 如果allocSize超過了base2的長度,會拋出異常 System.arraycopy(check, 0, check2, 0, allocSize); System.arraycopy(used, 0, used2, 0, allocSize); } base = base2; check = check2; used = used2; return allocSize = newSize; } private int fetch(Node parent, List<Node> siblings) { if (error_ < 0) return 0; int prev = 0; for (int i = parent.left; i < parent.right; i++) { if ((length != null ? length[i] : key.get(i).length()) < parent.depth) continue; String tmp = key.get(i); int cur = 0; if ((length != null ? length[i] : tmp.length()) != parent.depth) cur = (int) tmp.charAt(parent.depth) + 1; if (prev > cur) { error_ = -3; return 0; } if (cur != prev || siblings.size() == 0) { Node tmp_node = new Node(); tmp_node.depth = parent.depth + 1; tmp_node.code = cur; tmp_node.left = i; if (siblings.size() != 0) siblings.get(siblings.size() - 1).right = i; siblings.add(tmp_node); } prev = cur; } if (siblings.size() != 0) siblings.get(siblings.size() - 1).right = parent.right; return siblings.size(); } private int insert(List<Node> siblings) { if (error_ < 0) return 0; int begin = 0; int pos = ((siblings.get(0).code + 1 > nextCheckPos) ? siblings.get(0).code + 1 : nextCheckPos) - 1; int nonzero_num = 0; int first = 0; if (allocSize <= pos) resize(pos + 1); outer: while (true) { pos++; if (allocSize <= pos) resize(pos + 1); if (check[pos] != 0) { nonzero_num++; continue; } else if (first == 0) { nextCheckPos = pos; first = 1; } begin = pos - siblings.get(0).code; if (allocSize <= (begin + siblings.get(siblings.size() - 1).code)) { // progress can be zero double l = (1.05 > 1.0 * keySize / (progress + 1)) ? 1.05 : 1.0 * keySize / (progress + 1); resize((int) (allocSize * l)); } if (used[begin]) continue; for (int i = 1; i < siblings.size(); i++) if (check[begin + siblings.get(i).code] != 0) continue outer; break; } // -- Simple heuristics -- // if the percentage of non-empty contents in check between the // index // 'next_check_pos' and 'check' is greater than some constant value // (e.g. 0.9), // new 'next_check_pos' index is written by 'check'. if (1.0 * nonzero_num / (pos - nextCheckPos + 1) >= 0.95) nextCheckPos = pos; used[begin] = true; size = (size > begin + siblings.get(siblings.size() - 1).code + 1) ? size : begin + siblings.get(siblings.size() - 1).code + 1; for (int i = 0; i < siblings.size(); i++) check[begin + siblings.get(i).code] = begin; for (int i = 0; i < siblings.size(); i++) { List<Node> new_siblings = new ArrayList<Node>(); if (fetch(siblings.get(i), new_siblings) == 0) { base[begin + siblings.get(i).code] = (value != null) ? (-value[siblings .get(i).left] - 1) : (-siblings.get(i).left - 1); if (value != null && (-value[siblings.get(i).left] - 1) >= 0) { error_ = -2; return 0; } progress++; // if (progress_func_) (*progress_func_) (progress, // keySize); } else { int h = insert(new_siblings); base[begin + siblings.get(i).code] = h; } } return begin; } public DoubleArrayTrie() { check = null; base = null; used = null; size = 0; allocSize = 0; // no_delete_ = false; error_ = 0; } // no deconstructor // set_result omitted // the search methods returns (the list of) the value(s) instead // of (the list of) the pair(s) of value(s) and length(s) // set_array omitted // array omitted void clear() { // if (! no_delete_) check = null; base = null; used = null; allocSize = 0; size = 0; // no_delete_ = false; } public int getUnitSize() { return UNIT_SIZE; } public int getSize() { return size; } public int getTotalSize() { return size * UNIT_SIZE; } public int getNonzeroSize() { int result = 0; for (int i = 0; i < size; i++) if (check[i] != 0) result++; return result; } public int build(List<String> key) { return build(key, null, null, key.size()); } public int build(List<String> _key, int _length[], int _value[], int _keySize) { if (_keySize > _key.size() || _key == null) return 0; // progress_func_ = progress_func; key = _key; length = _length; keySize = _keySize; value = _value; progress = 0; resize(65536 * 32); base[0] = 1; nextCheckPos = 0; Node root_node = new Node(); root_node.left = 0; root_node.right = keySize; root_node.depth = 0; List<Node> siblings = new ArrayList<Node>(); fetch(root_node, siblings); insert(siblings); // size += (1 << 8 * 2) + 1; // ??? // if (size >= allocSize) resize (size); used = null; key = null; return error_; } public void open(String fileName) throws IOException { File file = new File(fileName); size = (int) file.length() / UNIT_SIZE; check = new int[size]; base = new int[size]; DataInputStream is = null; try { is = new DataInputStream(new BufferedInputStream( new FileInputStream(file), BUF_SIZE)); for (int i = 0; i < size; i++) { base[i] = is.readInt(); check[i] = is.readInt(); } } finally { if (is != null) is.close(); } } public void save(String fileName) throws IOException { DataOutputStream out = null; try { out = new DataOutputStream(new BufferedOutputStream( new FileOutputStream(fileName))); for (int i = 0; i < size; i++) { out.writeInt(base[i]); out.writeInt(check[i]); } out.close(); } finally { if (out != null) out.close(); } } public int exactMatchSearch(String key) { return exactMatchSearch(key, 0, 0, 0); } public int exactMatchSearch(String key, int pos, int len, int nodePos) { if (len <= 0) len = key.length(); if (nodePos <= 0) nodePos = 0; int result = -1; char[] keyChars = key.toCharArray(); int b = base[nodePos]; int p; for (int i = pos; i < len; i++) { p = b + (int) (keyChars[i]) + 1; if (b == check[p]) b = base[p]; else return result; } p = b; int n = base[p]; if (b == check[p] && n < 0) { result = -n - 1; } return result; } public List<Integer> commonPrefixSearch(String key) { return commonPrefixSearch(key, 0, 0, 0); } public List<Integer> commonPrefixSearch(String key, int pos, int len, int nodePos) { if (len <= 0) len = key.length(); if (nodePos <= 0) nodePos = 0; List<Integer> result = new ArrayList<Integer>(); char[] keyChars = key.toCharArray(); int b = base[nodePos]; int n; int p; for (int i = pos; i < len; i++) { p = b; n = base[p]; if (b == check[p] && n < 0) { result.add(-n - 1); } p = b + (int) (keyChars[i]) + 1; if (b == check[p]) b = base[p]; else return result; } p = b; n = base[p]; if (b == check[p] && n < 0) { result.add(-n - 1); } return result; } // debug public void dump() { for (int i = 0; i < size; i++) { System.err.println("i: " + i + " [" + base[i] + ", " + check[i] + "]"); } } }
public class TestDoubleArrayTrie { /** * 檢索key的前綴命中了詞典中的哪些詞<br> * key的前綴有多個,所以有可能命中詞典中的多個詞 */ @Test public void testPrefixMatch() { DoubleArrayTrie adt = new DoubleArrayTrie(); List<String> list = new ArrayList<String>(); list.add("阿膠"); list.add("阿拉伯"); list.add("阿拉伯人"); list.add("埃及"); // 所有詞必須先排序 Collections.sort(list); // 構建DoubleArrayTrie adt.build(list); String key = "阿拉伯人"; // 檢索key的前綴命中了詞典中的哪些詞 List<Integer> rect = adt.commonPrefixSearch(key); for (int index : rect) { System.out.println("前綴 " + list.get(index) + " matched"); } System.out.println("================="); } /** * 檢索key是否完全命中了詞典中的某個詞 */ @Test public void testFullMatch() { DoubleArrayTrie adt = new DoubleArrayTrie(); List<String> list = new ArrayList<String>(); list.add("阿膠"); list.add("阿拉伯"); list.add("阿拉伯人"); list.add("埃及"); // 所有詞必須先排序 Collections.sort(list); // 構建DoubleArrayTrie adt.build(list); String key = "阿拉"; // 檢索key是否完全命中了詞典中的某個詞 int index = adt.exactMatchSearch(key); if (index >= 0) { System.out.println(key + " match " + list.get(index)); } else { System.out.println(key + " not match any term"); } key = "阿拉伯"; index = adt.exactMatchSearch(key); if (index >= 0) { System.out.println(key + " match " + list.get(index)); } else { System.out.println(key + " not match any term"); } key = "阿拉伯人"; index = adt.exactMatchSearch(key); if (index >= 0) { System.out.println(key + " match " + list.get(index)); } else { System.out.println(key + " not match any term"); } System.out.println("================="); } }