Double Array Trie


Trie樹主要應用在信息檢索領域,非常高效。今天我們講Double Array Trie,請先把Trie樹忘掉,把信息檢索忘掉,我們來講一個確定有限自動機(deterministic finite automaton ,DFA)的故事。所謂“確定有限自動機”是指給定一個狀態和一個變量時,它能跳轉到的下一個狀態也就確定下來了,同時狀態是有限的。請注意這里出現兩個名詞,一個是“狀態”,一個是“變量”,下文會舉例說明這兩個名詞的含義。

舉個例子,假設我們一共有10個漢字,每個漢字就是一個“變量”。我們為每個漢字編個序號。

 

1

2

3

4

5

6

7

8

9

10

             表1. “變量”的編號

這10個漢字一共可以構成6個詞語:啊,埃及,阿膠,阿根廷,阿拉伯,阿拉伯人。         

這里的每個詞以及它的任意前綴都是一個“狀態”,“狀態”一共有10個:啊,阿,埃,阿根,阿根廷,阿膠,阿拉,阿拉伯,阿拉伯人,埃及

我們把DFA圖畫出來:

        圖1. DFA,同時也是Trie樹

在圖中每個節點代表一個“狀態”,每條邊代表一個“變量”,並且我們把變量的編號也標在了圖中。

下面我們構造兩個int數組:base和check,它們的長度始終是一樣的。數組的長度定多少並沒有嚴格的規定,反正隨着詞語的插入,數組肯定是要擴容的。說到數組擴容,大家可以看一下java中HashMap的擴容策略,每次擴容數組的長度都會變為2的整次冪。HashMap中有這么一個精妙的函數:

//給定一個整數,返回大於等於這個數的2的整次冪
static int tableSizeFor(int cap) {
        int n = cap - 1;
        n |= n >>> 1;
        n |= n >>> 2;
        n |= n >>> 4;
        n |= n >>> 8;
        n |= n >>> 16;
        return (n < 0) ? 1 :  n + 1;
}

回到今天的正題,我們不妨把double array的初始長度就定得大一些。兩數組元素初始值均為0。

double array的初始狀態:

下標

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

base

0

0

0

0

0

0

0

0

0

0

0

0

0

0

0

0

0

0

0

check

0

0

0

0

0

0

0

0

0

0

0

0

0

0

0

0

0

0

0

state

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

把詞添加到詞典的過程就給base和check數組中各元素賦值的過程。下面我們層次遍歷圖1所示的Trie樹。

step1.

第一層上取到3個“狀態”:啊,阿,埃。把這3個狀態按照其對應的變量的編號(查表1)放到state數組中。

下標

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

base

0

0

0

0

0

0

0

0

0

0

0

0

0

0

0

0

0

0

0

check

0

0

0

0

0

0

0

0

0

0

0

0

0

0

0

0

0

0

0

state

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

step2.

當存在狀態轉移時,有

check[t]=s
base[s]+c=t

其中s和t代表某個狀態在數組中的下標,c代表變量的編號。

此時層次遍歷來到了圖1所示DFA的第二層,我們看到“阿”的子節點有“阿根”、“阿膠”、“阿拉”,已知狀態“阿”的下標是2,變量“根”、“膠”、“拉”的編號依次是4、5、6,下面我們要給base[2]賦值:從小到大遍歷所有的正整數,直到發現某個數正整k滿足base[k+4]=base[k+5]=base[k+6]=check[k+4]=check[k+5]=check[k+6]=0。得到k=1,那么就把1賦給base[2],同時也確定了狀態“阿根”、“阿膠”、“阿拉”的下標依次是k+4、k+5、k+6,即5、6、7,而且check[5]=check[6]=check[7]=2。

同理,“埃”的子節點是“埃及”,狀態“埃”的下標是3,變量“及”的編號是7,此時有check[1+7]=base[1+7]=0,所以base[3]=1,狀態“埃及”的下標是8,check[8]=3。

遍歷完DFA的第二層后得到下表:

下標

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

base

0

1

1

0

0

0

0

0

0

0

0

0

0

0

0

0

0

0

0

check

0

0

0

0

2

2

2

3

0

0

0

0

0

0

0

0

0

0

0

state

 

阿根

阿膠

阿拉

埃及

 

 

 

 

 

 

 

 

 

 

 

step3.

重復step2,層次遍歷完整查詢樹之后,得到:

下標

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

base

0

1

1

0

1

0

1

0

0

1

0

0

0

0

0

0

0

0

0

check

0

0

0

0

2

2

2

3

5

7

10

0

0

0

0

0

0

0

0

state

 

阿根

阿膠

阿拉

埃及

阿根廷

阿拉伯

阿拉伯人

 

 

 

 

 

 

 

 

step4.

最后遍歷一次DFA,當某個節點已經是一個詞的結尾時,按下列方法修改其base值。

if(base[i]==0)
    base[i]=-i
else
    base[i]=-base[i]

得到:

下標

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

base

-1

1

1

0

1

-6

1

-8

-9

-1

-11

0

0

0

0

0

0

0

0

check

0

0

0

0

2

2

2

3

5

7

10

0

0

0

0

0

0

0

0

state

 

阿根

阿膠

阿拉

埃及

阿根廷

阿拉伯

阿拉伯人

 

 

 

 

 

 

 

 

double array建好之后,如果詞典中又動態地添加了一個新詞,比如“阿拉根”,那么“阿拉”的所有子孫節點在double array中的位置要重新分配。

 

圖2. 動態添加一個詞

首先,把“阿拉伯”和“阿拉伯人”對應的base、check值清0,把“阿拉伯”和“阿拉伯人”從state數組中刪除掉,把“阿拉”的base值清0。

下標

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

base

-1

1

1

0

1

-6

0

-8

-9

0

0

0

0

0

0

0

0

0

0

check

0

0

0

0

2

2

2

3

5

0

0

0

0

0

0

0

0

0

0

state

 

阿根

阿膠

阿拉

埃及

阿根廷

 

 

 

 

 

 

 

 

 

 

然后,按照上面step2所述的方法把“阿拉伯”、“阿拉根”插入到double array中。變量“根”、“伯”的編號是4和9,滿足base[k+4]=base[k+9]=check[k+4]=check[k+9]=0的最小的k是6,所以base[7]=6,“阿拉伯”和“阿拉根”對應的下標是10和15。同理把“阿拉伯人”插入到double array中。

下標

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

base

-1

1

1

0

1

-6

6

-8

-9

0

0

0

0

0

1

0

0

0

0

check

0

0

0

0

2

2

2

3

5

7

15

0

0

0

7

0

0

0

0

state

 

阿根

阿膠

阿拉

埃及

阿根廷

阿拉根

阿拉伯人

 

 

 

阿拉伯

 

 

 

 

最后,遍歷圖2所示的DFA,當某個節點已經是一個詞的結尾時按照step4中的方法修改其base值。

下標

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

base

-1

1

1

0

1

-6

6

-8

-9

-10

-11

0

0

0

-1

0

0

0

0

check

0

0

0

0

2

2

2

3

5

7

15

0

0

0

7

0

0

0

0

state

 

阿根

阿膠

阿拉

埃及

阿根廷

阿拉根

阿拉伯人

 

 

 

阿拉伯

 

 

 

 

 

double array建好之后,如何查詢一個詞是否在詞典中呢?

比如要查“阿膠及”,每個字的編號是已知的,我們畫出狀態轉移圖。

變量“阿”的編號是2,base[2]=1,變量“膠”的編號是5,base[2]+5=6,我們檢查一下check[6]是否等於2。check[6]確實等於2,則繼續看下一個狀態轉移。同時我們發現base[6]是負數,這說明“阿膠”已經是一個完整的詞了。

繼續看下一個狀態轉移,base[6]=-6,負數取其相反數,base[6]=6,變量“及”的編號是7,base[6]+7=13,我們檢查一下check[13]是否等於6,發現不滿足,則“阿膠及”不是一個詞,甚至都是不是任意一個詞的前綴。

github上一個日本人貢獻了他的java版的Darts(Darts本來是一種Double Array Trie的C++實現),代碼如下:

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * DoubleArrayTrie在構建雙數組的過程中也借助於一棵傳統的Trie樹,但這棵Trie樹並沒有被保存下來,
 * 如果要查找以prefix為前綴的所有詞不適合用DoubleArrayTrie,應該用傳統的Trie樹。
 * 
 * @author zhangchaoyang
 *
 */
public class DoubleArrayTrie {
	private final static int BUF_SIZE = 16384;// 2^14,java采用unicode編碼表示所有字符,每個字符固定用兩個字節表示。考慮到每個字節的符號位都是0,所以又可以節省兩個bit
	private final static int UNIT_SIZE = 8; // size of int + int

	private static class Node {
		int code;// 字符的unicode編碼
		int depth;// 在Trie樹中的深度
		int left;//
		int right;//
	};

	private int check[];
	private int base[];

	private boolean used[];
	private int size;
	private int allocSize;// base數組當前的長度
	private List<String> key;// 所有的詞
	private int keySize;
	private int length[];
	private int value[];
	private int progress;
	private int nextCheckPos;
	int error_;

	// 擴充base和check數組
	private int resize(int newSize) {
		int[] base2 = new int[newSize];
		int[] check2 = new int[newSize];
		boolean used2[] = new boolean[newSize];
		if (allocSize > 0) {
			System.arraycopy(base, 0, base2, 0, allocSize);// 如果allocSize超過了base2的長度,會拋出異常
			System.arraycopy(check, 0, check2, 0, allocSize);
			System.arraycopy(used, 0, used2, 0, allocSize);
		}

		base = base2;
		check = check2;
		used = used2;

		return allocSize = newSize;
	}

	private int fetch(Node parent, List<Node> siblings) {
		if (error_ < 0)
			return 0;

		int prev = 0;

		for (int i = parent.left; i < parent.right; i++) {
			if ((length != null ? length[i] : key.get(i).length()) < parent.depth)
				continue;

			String tmp = key.get(i);

			int cur = 0;
			if ((length != null ? length[i] : tmp.length()) != parent.depth)
				cur = (int) tmp.charAt(parent.depth) + 1;

			if (prev > cur) {
				error_ = -3;
				return 0;
			}

			if (cur != prev || siblings.size() == 0) {
				Node tmp_node = new Node();
				tmp_node.depth = parent.depth + 1;
				tmp_node.code = cur;
				tmp_node.left = i;
				if (siblings.size() != 0)
					siblings.get(siblings.size() - 1).right = i;

				siblings.add(tmp_node);
			}

			prev = cur;
		}

		if (siblings.size() != 0)
			siblings.get(siblings.size() - 1).right = parent.right;

		return siblings.size();
	}

	private int insert(List<Node> siblings) {
		if (error_ < 0)
			return 0;

		int begin = 0;
		int pos = ((siblings.get(0).code + 1 > nextCheckPos) ? siblings.get(0).code + 1
				: nextCheckPos) - 1;
		int nonzero_num = 0;
		int first = 0;

		if (allocSize <= pos)
			resize(pos + 1);

		outer: while (true) {
			pos++;

			if (allocSize <= pos)
				resize(pos + 1);

			if (check[pos] != 0) {
				nonzero_num++;
				continue;
			} else if (first == 0) {
				nextCheckPos = pos;
				first = 1;
			}

			begin = pos - siblings.get(0).code;
			if (allocSize <= (begin + siblings.get(siblings.size() - 1).code)) {
				// progress can be zero
				double l = (1.05 > 1.0 * keySize / (progress + 1)) ? 1.05 : 1.0
						* keySize / (progress + 1);
				resize((int) (allocSize * l));
			}

			if (used[begin])
				continue;

			for (int i = 1; i < siblings.size(); i++)
				if (check[begin + siblings.get(i).code] != 0)
					continue outer;

			break;
		}

		// -- Simple heuristics --
		// if the percentage of non-empty contents in check between the
		// index
		// 'next_check_pos' and 'check' is greater than some constant value
		// (e.g. 0.9),
		// new 'next_check_pos' index is written by 'check'.
		if (1.0 * nonzero_num / (pos - nextCheckPos + 1) >= 0.95)
			nextCheckPos = pos;

		used[begin] = true;
		size = (size > begin + siblings.get(siblings.size() - 1).code + 1) ? size
				: begin + siblings.get(siblings.size() - 1).code + 1;

		for (int i = 0; i < siblings.size(); i++)
			check[begin + siblings.get(i).code] = begin;

		for (int i = 0; i < siblings.size(); i++) {
			List<Node> new_siblings = new ArrayList<Node>();

			if (fetch(siblings.get(i), new_siblings) == 0) {
				base[begin + siblings.get(i).code] = (value != null) ? (-value[siblings
						.get(i).left] - 1) : (-siblings.get(i).left - 1);

				if (value != null && (-value[siblings.get(i).left] - 1) >= 0) {
					error_ = -2;
					return 0;
				}

				progress++;
				// if (progress_func_) (*progress_func_) (progress,
				// keySize);
			} else {
				int h = insert(new_siblings);
				base[begin + siblings.get(i).code] = h;
			}
		}
		return begin;
	}

	public DoubleArrayTrie() {
		check = null;
		base = null;
		used = null;
		size = 0;
		allocSize = 0;
		// no_delete_ = false;
		error_ = 0;
	}

	// no deconstructor

	// set_result omitted
	// the search methods returns (the list of) the value(s) instead
	// of (the list of) the pair(s) of value(s) and length(s)

	// set_array omitted
	// array omitted

	void clear() {
		// if (! no_delete_)
		check = null;
		base = null;
		used = null;
		allocSize = 0;
		size = 0;
		// no_delete_ = false;
	}

	public int getUnitSize() {
		return UNIT_SIZE;
	}

	public int getSize() {
		return size;
	}

	public int getTotalSize() {
		return size * UNIT_SIZE;
	}

	public int getNonzeroSize() {
		int result = 0;
		for (int i = 0; i < size; i++)
			if (check[i] != 0)
				result++;
		return result;
	}

	public int build(List<String> key) {
		return build(key, null, null, key.size());
	}

	public int build(List<String> _key, int _length[], int _value[],
			int _keySize) {
		if (_keySize > _key.size() || _key == null)
			return 0;

		// progress_func_ = progress_func;
		key = _key;
		length = _length;
		keySize = _keySize;
		value = _value;
		progress = 0;

		resize(65536 * 32);

		base[0] = 1;
		nextCheckPos = 0;

		Node root_node = new Node();
		root_node.left = 0;
		root_node.right = keySize;
		root_node.depth = 0;

		List<Node> siblings = new ArrayList<Node>();
		fetch(root_node, siblings);
		insert(siblings);

		// size += (1 << 8 * 2) + 1; // ???
		// if (size >= allocSize) resize (size);

		used = null;
		key = null;

		return error_;
	}

	public void open(String fileName) throws IOException {
		File file = new File(fileName);
		size = (int) file.length() / UNIT_SIZE;
		check = new int[size];
		base = new int[size];

		DataInputStream is = null;
		try {
			is = new DataInputStream(new BufferedInputStream(
					new FileInputStream(file), BUF_SIZE));
			for (int i = 0; i < size; i++) {
				base[i] = is.readInt();
				check[i] = is.readInt();
			}
		} finally {
			if (is != null)
				is.close();
		}
	}

	public void save(String fileName) throws IOException {
		DataOutputStream out = null;
		try {
			out = new DataOutputStream(new BufferedOutputStream(
					new FileOutputStream(fileName)));
			for (int i = 0; i < size; i++) {
				out.writeInt(base[i]);
				out.writeInt(check[i]);
			}
			out.close();
		} finally {
			if (out != null)
				out.close();
		}
	}

	public int exactMatchSearch(String key) {
		return exactMatchSearch(key, 0, 0, 0);
	}

	public int exactMatchSearch(String key, int pos, int len, int nodePos) {
		if (len <= 0)
			len = key.length();
		if (nodePos <= 0)
			nodePos = 0;

		int result = -1;

		char[] keyChars = key.toCharArray();

		int b = base[nodePos];
		int p;

		for (int i = pos; i < len; i++) {
			p = b + (int) (keyChars[i]) + 1;
			if (b == check[p])
				b = base[p];
			else
				return result;
		}

		p = b;
		int n = base[p];
		if (b == check[p] && n < 0) {
			result = -n - 1;
		}
		return result;
	}

	public List<Integer> commonPrefixSearch(String key) {
		return commonPrefixSearch(key, 0, 0, 0);
	}

	public List<Integer> commonPrefixSearch(String key, int pos, int len,
			int nodePos) {
		if (len <= 0)
			len = key.length();
		if (nodePos <= 0)
			nodePos = 0;

		List<Integer> result = new ArrayList<Integer>();

		char[] keyChars = key.toCharArray();

		int b = base[nodePos];
		int n;
		int p;

		for (int i = pos; i < len; i++) {
			p = b;
			n = base[p];

			if (b == check[p] && n < 0) {
				result.add(-n - 1);
			}

			p = b + (int) (keyChars[i]) + 1;
			if (b == check[p])
				b = base[p];
			else
				return result;
		}

		p = b;
		n = base[p];

		if (b == check[p] && n < 0) {
			result.add(-n - 1);
		}

		return result;
	}

	// debug
	public void dump() {
		for (int i = 0; i < size; i++) {
			System.err.println("i: " + i + " [" + base[i] + ", " + check[i]
					+ "]");
		}
	}
}

 

public class TestDoubleArrayTrie {

	/**
	 * 檢索key的前綴命中了詞典中的哪些詞<br>
	 * key的前綴有多個,所以有可能命中詞典中的多個詞
	 */
	@Test
	public void testPrefixMatch() {
		DoubleArrayTrie adt = new DoubleArrayTrie();
		List<String> list = new ArrayList<String>();
		list.add("阿膠");
		list.add("阿拉伯");
		list.add("阿拉伯人");
		list.add("埃及");
		// 所有詞必須先排序
		Collections.sort(list);
		// 構建DoubleArrayTrie
		adt.build(list);
		String key = "阿拉伯人";
		// 檢索key的前綴命中了詞典中的哪些詞
		List<Integer> rect = adt.commonPrefixSearch(key);
		for (int index : rect) {
			System.out.println("前綴  " + list.get(index) + " matched");
		}
		System.out.println("=================");
	}

	/**
	 * 檢索key是否完全命中了詞典中的某個詞
	 */
	@Test
	public void testFullMatch() {
		DoubleArrayTrie adt = new DoubleArrayTrie();
		List<String> list = new ArrayList<String>();
		list.add("阿膠");
		list.add("阿拉伯");
		list.add("阿拉伯人");
		list.add("埃及");
		// 所有詞必須先排序
		Collections.sort(list);
		// 構建DoubleArrayTrie
		adt.build(list);
		String key = "阿拉";
		// 檢索key是否完全命中了詞典中的某個詞
		int index = adt.exactMatchSearch(key);
		if (index >= 0) {
			System.out.println(key + " match " + list.get(index));
		} else {
			System.out.println(key + " not match any term");
		}
		key = "阿拉伯";
		index = adt.exactMatchSearch(key);
		if (index >= 0) {
			System.out.println(key + " match " + list.get(index));
		} else {
			System.out.println(key + " not match any term");
		}
		key = "阿拉伯人";
		index = adt.exactMatchSearch(key);
		if (index >= 0) {
			System.out.println(key + " match " + list.get(index));
		} else {
			System.out.println(key + " not match any term");
		}
		System.out.println("=================");
	}
}

  


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM