Trie


Trie,又稱單詞查找樹,Trie 樹,是一種樹形結構,是一種哈希樹的變種。典型應
用是用於統計,排序和保存大量的字符串(但不僅限於字符串),所以經常被搜索
引擎系統用於文本詞頻統計。它的優點是:利用字符串的公共前綴來減少查詢時
間,最大限度地減少無謂的字符串比較,查詢效率比哈希樹高。 ——百度百科


Trie 樹是一種性能優異的哈希樹,是一種常用的樹狀結構,常用於字符串保存、查找等操作,由於其在樹上查找字符串的方式與查字典相似,所以常被稱為字典樹。它使用不同字符串的公共前綴減少查詢時間和存儲空間,減少字符串比較,可以快速的在 \(O(n)\) 的時間內於任意多的字符串中保存或是查找字符串。

想到我們查英語字典的時候,我們對於想查的單詞(如 \(animal\)),我們會先在整本詞典中查找它的第一位字母 \(a\),再在其第一位字母\(a\)以下的區域查找第二位字母 \(n\) 所在位置……再在第五位字母 \(a\) 以下的區域查找最后一位字母 \(l\),就可以在詞典幾千個單詞中查找單詞長度 \(6\) 次找出需要查詢的單詞信息。

這就是 Trie 查找的原理,寫入的方式也比較相似。

如果我們將一些的字符串拉成鏈,全部掛在一個點上,那么就可以形成一棵龐大的樹。然后若我們將能合並的結點都合並(如 \(abcd\)\(abce\) 同時掛在根節點下,我們可以考慮將兩序列都擁有的 \(abc\) 三節點合並,在 \(c\) 結點下掛 \(d\)\(e\) 兩子節點),最后我們可以讓這課龐大的樹的占用空間大幅度下降,並且可以像上述查字典一樣的方式一層一層向下查找來找到任意一個開始時放入的字符串。這種改進后的數據結構就是Trie

如果你沒聽懂,沒有關系,請看下面的解釋。

在最開始的時候我們現在圖中找到一個起點作為樹的起點(這里記作 \(0\) 號點)。

pict-1

如果現在我們加入第一個單詞 \(he\),就該單詞拉成一條鏈掛在 \(0\) 號點下。

pict-2

再加入 \(she\),由於 \(she\)\(he\) 沒有共同前綴,所以 \(she\) 的處理方法與 \(he\) 相同。

pict-3

如果加入 \(hi\),從根節點開始向下查找,發現根節點已擁有 \(h\) 結點作為孩子,那么通過 \(h\) 向下查找,發現 \(h\) 並沒有 \(i\) 子節點,所以在 \(h\) 下面掛上一個 \(i\) 節點,那么 \(hi\) 就與 \(he\) 共用一個 \(h\) 的前綴,如下圖所示。

pict-4

再插入 \(sha\)\(sad\) 作示范,方法與上面相同

pict-5

插入時只會注意前綴相同的部分,后面即使有相同的字母也不會產生影響

可以發現,從根節點開始(不包括根節點),任意選擇一條通往葉子節點的路徑,路徑上經過的字符來連起來可以組成輸入的一個字符串。同時每一個節點不可能出現兩個擁有相同字符的孩子節點,且每個字符串在樹上只有一種表達方式。


細節

如果此時我需要插入一個 \(her\),樹會變成這樣:

pict-6

那么如何確定這個樹上到底有沒有 \(he\) 這個單詞呢?

添加結束標記

我們對每一個點加入一個布爾標記,記錄其是否為單詞結尾。為 \(True\) 表示從根節點到這里的路徑表示的是一個單詞,如果為 \(False\) 表示這不是一個單詞。

下圖將被記為單詞結尾的結點標記成紅色。那么樹就變成這樣的了。

pict-7

這樣我們就可以區分出樹上的每一個字符串了。


代碼

接下來通過一些代碼來講解 Trie 上執行操作的方法。

定義

struct node
{
    bool tail;
    int visit;
    int child[26];
};
std::vector<node> trie;

這里的結構體代表的是 Trie 上每一個結點的類型。這里的 \(tail\) 存放的布爾類型表示該節點為幾個字符串的結尾,\(visit\) 表示該節點表示的字符串被訪問過幾次(如果該節點表示的字符串多於一個,那么 \(visit\) 將成倍增加),\(child[c]\) 表示該節點的 \(c\) 孩子的數組下標(如果不存在該子節點,指向\(0\))。最后的 \(vector\) 就是 Trie 樹的表達方式了,用 \(vector\) 存放結點可以更有效地節省空間。

加入字符串

void add(std::string s)
{
    int p=0;
    int size=s.size();
    for(int sp=0;sp<size;sp++)
    {
        int c=s[sp]-'a';
        if(trie[p].child[c]==0)
        {
            trie.push_back({});
            trie[p].child[c]=trie.size()-1;
        }
        p=trie[p].child[c];
    }
    trie[p].tail++;
    return;
}

形參\(s\)就是我們需要加入到 Trie 中的字符串,我們用 \(sp\) 遍歷字符串,對於每一個 \(s[sp]\) 都會有一個 \(s[sp+1]\) 的子節點。我們從根節點開始向下搜索,如果當前結點具有我們正在匹配的這位字符,則遍歷到對應子節點,否則新建子節點,並沿該新建節點繼續向下遍歷直至整個字符串的字符全部匹配完,在最后一個節點上將結尾標記加 \(1\)

查找字符串

int find(std::string s)
{
    int point=0;
    int sp=0;
    int size=s.size();
    while(sp<size)
    {
        int c=s[sp]-'a';
        if(trie[point].child[c]!=0)
        {
            point=trie[point].child[c];
            trie[point].visit+=trie[point].tail;
            sp++;
        }
        else return -1;
    }
    if(trie[point].tail==0) return -1;
    return trie[point].visit;
}

從根節點開始向下匹配字符串,在根節點子節點中找出 \(s[1]\) 對應的子節點,再沿該子節點向下找出 \(s[2]\) 對應的子節點……直至匹配完,返回最后的節點上的訪問值(如果最后的這個結點並非一個單詞的結尾,返回 \(-1\) 表示沒有找到該字符串)。如果中途發現在某一處匹配子節點失敗,則返回 \(-1\) 表示沒有找到這個串。


代碼背景

P2580

#include <iostream>
#include <vector>
#include <string>
#include <queue>
#include <map>

struct node
{
    int tail;
    int visit;
    int child[26];
};
std::vector<node> trie;

void add(std::string s)
{
    int p=0;
    int size=s.size();
    for(int sp=0;sp<size;sp++)
    {
        int c=s[sp]-'a';
        if(trie[p].child[c]==0)
        {
            trie.push_back({});
            trie[p].child[c]=trie.size()-1;
        }
        p=trie[p].child[c];
    }
    trie[p].tail++;
    return;
}

int find(std::string s)
{
    int point=0;
    int sp=0;
    int size=s.size();
    while(sp<size)
    {
        int c=s[sp]-'a';
        if(trie[point].child[c]!=0)
        {
            point=trie[point].child[c];
            trie[point].visit+=trie[point].tail;
            sp++;
        }
        else return -1;
    }
    if(trie[point].tail==0) return -1;
    return trie[point].visit;
}

int cost[305];

int main()
{
    std::ios::sync_with_stdio(false);
    int n,m;
    std::cin>>n;
    trie.resize(1);
    for(int i=0;i<n;i++)
    {
        std::string s;
        std::cin>>s;
        add(s);
    }
    std::cin>>m;
    for(int i=0;i<m;i++)
    {
        std::string s;
        std::cin>>s;
        int t=find(s);
        if(t==-1) printf("WRONG\n");
        else if(t==1) printf("OK\n");
        else printf("REPEAT\n");
    }
    return 0;
}


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM