跳转至

字典樹 (Trie)

定義

字典樹,英文名 trie。顧名思義,就是一個像字典一樣的樹。

引入

先放一張圖:

trie1

可以發現,這棵字典樹用邊來代表字母,而從根結點到樹上某一結點的路徑就代表了一個字符串。舉個例子,\(1\to4\to 8\to 12\) 表示的就是字符串 caa

trie 的結構非常好懂,我們用 \(\delta(u,c)\) 表示結點 \(u\)\(c\) 字符指向的下一個結點,或着説是結點 \(u\) 代表的字符串後面添加一個字符 \(c\) 形成的字符串的結點。(\(c\) 的取值範圍和字符集大小有關,不一定是 \(0\sim 26\)。)

有時需要標記插入進 trie 的是哪些字符串,每次插入完成時在這個字符串所代表的節點處打上標記即可。

實現

放一個結構體封裝的模板:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
struct trie {
  int nex[100000][26], cnt;
  bool exist[100000];  // 該結點結尾的字符串是否存在

  void insert(char *s, int l) {  // 插入字符串
    int p = 0;
    for (int i = 0; i < l; i++) {
      int c = s[i] - 'a';
      if (!nex[p][c]) nex[p][c] = ++cnt;  // 如果沒有,就添加結點
      p = nex[p][c];
    }
    exist[p] = 1;
  }

  bool find(char *s, int l) {  // 查找字符串
    int p = 0;
    for (int i = 0; i < l; i++) {
      int c = s[i] - 'a';
      if (!nex[p][c]) return 0;
      p = nex[p][c];
    }
    return exist[p];
  }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class trie:
    def __init__(self):
        self.nex = [[0 for i in range(26)] for j in range(100000)]
        self.cnt = 0
        self.exist = [False] * 100000  # 該結點結尾的字符串是否存在

    def insert(self, s):  # 插入字符串
        p = 0
        for i in s:
            c = ord(i) - ord('a')
            if not self.nex[p][c]:
                self.cnt += 1
                self.nex[p][c] = self.cnt  # 如果沒有,就添加結點
            p = self.nex[p][c]
        self.exist[p] = True

    def find(self, s):  # 查找字符串
        p = 0
        for i in s:
            c = ord(i) - ord('a')
            if not self.nex[p][c]:
                return False
            p = self.nex[p][c]
        return self.exist[p]

應用

檢索字符串

字典樹最基礎的應用——查找一個字符串是否在「字典」中出現過。

於是他錯誤的點名開始了

給你 \(n\) 個名字串,然後進行 \(m\) 次點名,每次你需要回答「名字不存在」、「第一次點到這個名字」、「已經點過這個名字」之一。

\(1\le n\le 10^4\)\(1\le m\le 10^5\),所有字符串長度不超過 \(50\)

題解

對所有名字建 trie,再在 trie 中查詢字符串是否存在、是否已經點過名,第一次點名時標記為點過名。

參考代碼
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#include <cstdio>

const int N = 500010;

char s[60];
int n, m, ch[N][26], tag[N], tot = 1;

int main() {
  scanf("%d", &n);

  for (int i = 1; i <= n; ++i) {
    scanf("%s", s + 1);
    int u = 1;
    for (int j = 1; s[j]; ++j) {
      int c = s[j] - 'a';
      if (!ch[u][c])
        ch[u][c] =
            ++tot;  // 如果这个节点的子节点中没有这个字符,添加上并将该字符的节点号记录为++tot
      u = ch[u][c];  // 往更深一层搜索
    }
    tag[u] = 1;  // 最后一个字符为节点 u 的名字未被访问到记录为 1
  }

  scanf("%d", &m);

  while (m--) {
    scanf("%s", s + 1);
    int u = 1;
    for (int j = 1; s[j]; ++j) {
      int c = s[j] - 'a';
      u = ch[u][c];
      if (!u) break;  // 不存在对应字符的出边说明名字不存在
    }
    if (tag[u] == 1) {
      tag[u] = 2;  // 最后一个字符为节点 u 的名字已经被访问
      puts("OK");
    } else if (tag[u] == 2)  // 已经被访问,重复访问
      puts("REPEAT");
    else
      puts("WRONG");
  }

  return 0;
}

AC 自動機

trie 是 AC 自動機 的一部分。

維護異或極值

將數的二進制表示看做一個字符串,就可以建出字符集為 \(\{0,1\}\) 的 trie 樹。

BZOJ1954 最長異或路徑

給你一棵帶邊權的樹,求 \((u, v)\) 使得 \(u\)\(v\) 的路徑上的邊權異或和最大,輸出這個最大值。這裏的異或和指的是所有邊權的異或。

點數不超過 \(10^5\),邊權在 \([0,2^{31})\) 內。

題解

隨便指定一個根 \(root\),用 \(T(u, v)\) 表示 \(u\)\(v\) 之間的路徑的邊權異或和,那麼 \(T(u,v)=T(root, u)\oplus T(root,v)\),因為 LCA 以上的部分異或兩次抵消了。

那麼,如果將所有 \(T(root, u)\) 插入到一棵 trie 中,就可以對每個 \(T(root, u)\) 快速求出和它異或和最大的 \(T(root, v)\)

從 trie 的根開始,如果能向和 \(T(root, u)\) 的當前位不同的子樹走,就向那邊走,否則沒有選擇。

貪心的正確性:如果這麼走,這一位為 \(1\);如果不這麼走,這一位就會為 \(0\)。而高位是需要優先儘量大的。

參考代碼
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#include <algorithm>
#include <cstdio>
using namespace std;

const int N = 100010;

int head[N], nxt[N << 1], to[N << 1], weight[N << 1], cnt;
int n, dis[N], ch[N << 5][2], tot = 1, ans;

void insert(int x) {
  for (int i = 30, u = 1; i >= 0; --i) {
    int c = ((x >> i) & 1);  // 二进制一位一位向下取
    if (!ch[u][c]) ch[u][c] = ++tot;
    u = ch[u][c];
  }
}

void get(int x) {
  int res = 0;
  for (int i = 30, u = 1; i >= 0; --i) {
    int c = ((x >> i) & 1);
    if (ch[u][c ^ 1]) {  // 如果能向和当前位不同的子树走,就向那边走
      u = ch[u][c ^ 1];
      res |= (1 << i);
    } else
      u = ch[u][c];
  }
  ans = max(ans, res);  // 更新答案
}

void add(int u, int v, int w) {  // 建边
  nxt[++cnt] = head[u];
  head[u] = cnt;
  to[cnt] = v;
  weight[cnt] = w;
}

void dfs(int u, int fa) {
  insert(dis[u]);
  get(dis[u]);
  for (int i = head[u]; i; i = nxt[i]) {  // 遍历子节点
    int v = to[i];
    if (v == fa) continue;
    dis[v] = dis[u] ^ weight[i];
    dfs(v, u);
  }
}

int main() {
  scanf("%d", &n);

  for (int i = 1; i < n; ++i) {
    int u, v, w;
    scanf("%d%d%d", &u, &v, &w);
    add(u, v, w);  // 双向边
    add(v, u, w);
  }

  dfs(1, 0);

  printf("%d", ans);

  return 0;
}

維護異或和

01-trie 是指字符集為 \(\{0,1\}\) 的 trie。01-trie 可以用來維護一些數字的異或和,支持修改(刪除 + 重新插入),和全局加一(即:讓其所維護所有數值遞增 1,本質上是一種特殊的修改操作)。

如果要維護異或和,需要按值從低位到高位建立 trie。

一個約定:文中説當前節點 往上 指當前節點到根這條路徑,當前節點 往下 指當前結點的子樹。

插入 & 刪除

如果要維護異或和,我們 只需要 知道某一位上 01 個數的 奇偶性 即可,也就是對於數字 1 來説,當且僅當這一位上數字 1 的個數為奇數時,這一位上的數字才是 1,請時刻記住這段文字:如果只是維護異或和,我們只需要知道某一位上 1 的數量即可,而不需要知道 trie 到底維護了哪些數字。

對於每一個節點,我們需要記錄以下三個量:

  • ch[o][0/1] 指節點 o 的兩個兒子,ch[o][0] 指下一位是 0,同理 ch[o][1] 指下一位是 1
  • w[o] 指節點 o 到其父親節點這條邊上數值的數量(權值)。每插入一個數字 xx 二進制拆分後在 trie 上 路徑的權值都會 +1
  • xorv[o] 指以 o 為根的子樹維護的異或和。

具體維護結點的代碼如下所示。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
void maintain(int o) {
  w[o] = xorv[o] = 0;
  if (ch[o][0]) {
    w[o] += w[ch[o][0]];
    xorv[o] ^= xorv[ch[o][0]] << 1;
  }
  if (ch[o][1]) {
    w[o] += w[ch[o][1]];
    xorv[o] ^= (xorv[ch[o][1]] << 1) | (w[ch[o][1]] & 1);
  }
  // w[o] = w[o] & 1;
  // 只需知道奇偶性即可,不需要具體的值。當然這句話刪掉也可以,因為上文就只利用了他的奇偶性。
}

插入和刪除的代碼非常相似。

需要注意的地方就是:

  • 這裏的 MAXH 指 trie 的深度,也就是強制讓每一個葉子節點到根的距離為 MAXH。對於一些比較小的值,可能有時候不需要建立這麼深(例如:如果插入數字 4,分解成二進制後為 100,從根開始插入 001 這三位即可),但是我們強制插入 MAXH 位。這樣做的目的是為了便於全局 +1 時處理進位。例如:如果原數字是 311),遞增之後變成 4100),如果當初插入 3 時只插入了 2 位,那這裏的進位就沒了。

  • 插入和刪除,只需要修改葉子節點的 w[] 即可,在回溯的過程中一路維護即可。

實現
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
namespace trie {
const int MAXH = 21;
int ch[_ * (MAXH + 1)][2], w[_ * (MAXH + 1)], xorv[_ * (MAXH + 1)];
int tot = 0;

int mknode() {
  ++tot;
  ch[tot][1] = ch[tot][0] = w[tot] = xorv[tot] = 0;
  return tot;
}

void maintain(int o) {
  w[o] = xorv[o] = 0;
  if (ch[o][0]) {
    w[o] += w[ch[o][0]];
    xorv[o] ^= xorv[ch[o][0]] << 1;
  }
  if (ch[o][1]) {
    w[o] += w[ch[o][1]];
    xorv[o] ^= (xorv[ch[o][1]] << 1) | (w[ch[o][1]] & 1);
  }
  w[o] = w[o] & 1;
}

void insert(int &o, int x, int dp) {
  if (!o) o = mknode();
  if (dp > MAXH) return (void)(w[o]++);
  insert(ch[o][x & 1], x >> 1, dp + 1);
  maintain(o);
}

void erase(int o, int x, int dp) {
  if (dp > 20) return (void)(w[o]--);
  erase(ch[o][x & 1], x >> 1, dp + 1);
  maintain(o);
}
}  // namespace trie

全局加一

所謂全局加一就是指,讓這棵 trie 中所有的數值 +1

形式化的講,設 trie 中維護的數值有 \(V_1, V_2, V_3 \dots V_n\), 全局加一後 其中維護的值應該變成 \(V_1+1, V_2+1, V_3+1 \dots V_n+1\)

1
2
3
4
5
void addall(int o) {
  swap(ch[o][0], ch[o][1]);
  if (ch[o][0]) addall(ch[o][0]);
  maintain(o);
}
過程

我們思考一下二進制意義下 +1 是如何操作的。

我們只需要從低位到高位開始找第一個出現的 0,把它變成 1,然後這個位置後面的 1 都變成 0 即可。

下面給出幾個例子感受一下:(括號內的數字表示其對應的十進制數字)

1
2
3
4
5
1000(8)  + 1 = 1001(9)  ;
10011(19) + 1 = 10100(20) ;
11111(31) + 1 = 100000(32);
10101(21) + 1 = 10110(22) ;
100000000111111(16447) + 1 = 100000001000000(16448);

對應 trie 的操作,其實就是交換其左右兒子,順着 交換後0 邊往下遞歸操作即可。

回顧一下 w[o] 的定義:w[o] 指節點 o 到其父親節點這條邊上數值的數量(權值)。

有沒有感覺這個定義有點怪呢?如果在父親結點存儲到兩個兒子的這條邊的邊權也許會更接近於習慣。但是在這裏,在交換左右兒子的時候,在兒子結點存儲到父親這條邊的距離,顯然更加方便。

01-trie 合併

指的是將上述的兩個 01-trie 進行合併,同時合併維護的信息。

可能關於合併 trie 的文章比較少,其實合併 trie 和合併線段樹的思路非常相似,可以搜索「合併線段樹」來學習如何合併 trie。

其實合併 trie 非常簡單,就是考慮一下我們有一個 int merge(int a, int b) 函數,這個函數傳入兩個 trie 樹位於同一相對位置的結點編號,然後合併完成後返回合併完成的結點編號。

過程

考慮怎麼實現?

分三種情況:

  • 如果 a 沒有這個位置上的結點,新合併的結點就是 b
  • 如果 b 沒有這個位置上的結點,新合併的結點就是 a
  • 如果 a,b 都存在,那就把 b 的信息合併到 a 上,新合併的結點就是 a,然後遞歸操作處理 a 的左右兒子。

    提示:如果需要的合併是將 a,b 合併到一棵新樹上,這裏可以新建結點,然後合併到這個新結點上,這裏的代碼實現僅僅是將 b 的信息合併到 a 上。

實現

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
int merge(int a, int b) {
  if (!a) return b;  // 如果 a 沒有這個位置上的結點,返回 b
  if (!b) return a;  // 如果 b 沒有這個位置上的結點,返回 a
  /*
    如果 `a`, `b` 都存在,
    那就把 `b` 的信息合併到 `a` 上。
  */
  w[a] = w[a] + w[b];
  xorv[a] ^= xorv[b];
  /* 不要使用 maintain(),
    maintain() 是合併a的兩個兒子的信息
    而這裏需要 a b 兩個節點進行信息合併
   */
  ch[a][0] = merge(ch[a][0], ch[b][0]);
  ch[a][1] = merge(ch[a][1], ch[b][1]);
  return a;
}

其實 trie 都可以合併,換句話説,trie 合併不僅僅限於 01-trie。

【luogu-P6018】【Ynoi2010】Fusion tree

給你一棵 \(n\) 個結點的樹,每個結點有權值。\(m\) 次操作。 需要支持以下操作。

  • 將樹上與一個節點 \(x\) 距離為 \(1\) 的節點上的權值 \(+1\)。這裏樹上兩點間的距離定義為從一點出發到另外一點的最短路徑上邊的條數。
  • 在一個節點 \(x\) 上的權值 \(-v\)
  • 詢問樹上與一個節點 \(x\) 距離為 \(1\) 的所有節點上的權值的異或和。 對於 \(100\%\) 的數據,滿足 \(1\le n \le 5\times 10^5\)\(1\le m \le 5\times 10^5\)\(0\le a_i \le 10^5\)\(1 \le x \le n\)\(opt\in\{1,2,3\}\)。 保證任意時刻每個節點的權值非負。
題解

每個結點建立一棵 trie 維護其兒子的權值,trie 應該支持全局加一。 可以使用在每一個結點上設置懶標記來標記兒子的權值的增加量。

參考代碼
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#include <bits/stdc++.h>
using namespace std;
const int _ = 5e5 + 10;

namespace trie {
const int _n = _ * 25;
int rt[_];
int ch[_n][2];
int w[_n];  //`w[o]` 指节点 `o` 到其父亲节点这条边上数值的数量(权值)。
int xorv[_n];
int tot = 0;

void maintain(int o) {  // 维护w数组和xorv(权值的异或)数组
  w[o] = xorv[o] = 0;
  if (ch[o][0]) {
    w[o] += w[ch[o][0]];
    xorv[o] ^= xorv[ch[o][0]] << 1;
  }
  if (ch[o][1]) {
    w[o] += w[ch[o][1]];
    xorv[o] ^= (xorv[ch[o][1]] << 1) | (w[ch[o][1]] & 1);
  }
}

int mknode() {  // 创造一个新的节点
  ++tot;
  ch[tot][0] = ch[tot][1] = 0;
  w[tot] = 0;
  return tot;
}

void insert(int &o, int x, int dp) {  // x是权重,dp是深度
  if (!o) o = mknode();
  if (dp > 20) return (void)(w[o]++);
  insert(ch[o][x & 1], x >> 1, dp + 1);
  maintain(o);
}

void erase(int o, int x, int dp) {
  if (dp > 20) return (void)(w[o]--);
  erase(ch[o][x & 1], x >> 1, dp + 1);
  maintain(o);
}

void addall(int o) {  // 对所有节点+1即将所有节点的ch[o][1]和ch[o][0]交换
  swap(ch[o][1], ch[o][0]);
  if (ch[o][0]) addall(ch[o][0]);
  maintain(o);
}
}  // namespace trie

int head[_];

struct edges {
  int node;
  int nxt;
} edge[_ << 1];

int tot = 0;

void add(int u, int v) {
  edge[++tot].nxt = head[u];
  head[u] = tot;
  edge[tot].node = v;
}

int n, m;
int rt;
int lztar[_];
int fa[_];

void dfs0(int o, int f) {  // 得到fa数组
  fa[o] = f;
  for (int i = head[o]; i; i = edge[i].nxt) {  // 遍历子节点
    int node = edge[i].node;
    if (node == f) continue;
    dfs0(node, o);
  }
}

int V[_];

int get(int x) { return (fa[x] == -1 ? 0 : lztar[fa[x]]) + V[x]; }  // 权值函数

int main() {
  cin >> n >> m;
  for (int i = 1; i < n; i++) {
    int u, v;
    cin >> u >> v;
    add(u, v);  // 双向建边
    add(rt = v, u);
  }
  dfs0(rt, -1);  // rt是随机的一个点
  for (int i = 1; i <= n; i++) {
    cin >> V[i];
    if (fa[i] != -1) trie::insert(trie::rt[fa[i]], V[i], 0);
  }
  while (m--) {
    int opt, x;
    cin >> opt >> x;
    if (opt == 1) {
      lztar[x]++;
      if (x != rt) {
        if (fa[fa[x]] != -1) trie::erase(trie::rt[fa[fa[x]]], get(fa[x]), 0);
        V[fa[x]]++;
        if (fa[fa[x]] != -1)
          trie::insert(trie::rt[fa[fa[x]]], get(fa[x]), 0);  // 重新插入
      }
      trie::addall(trie::rt[x]);  // 对所有节点+1
    } else if (opt == 2) {
      int v;
      cin >> v;
      if (x != rt) trie::erase(trie::rt[fa[x]], get(x), 0);
      V[x] -= v;
      if (x != rt) trie::insert(trie::rt[fa[x]], get(x), 0);  // 重新插入
    } else {
      int res = 0;
      res = trie::xorv[trie::rt[x]];
      res ^= get(fa[x]);
      printf("%d\n", res);
    }
  }
  return 0;
}
【luogu-P6623】【省選聯考 2020 A 卷】樹

給定一棵 \(n\) 個結點的有根樹 \(T\),結點從 \(1\) 開始編號,根結點為 \(1\) 號結點,每個結點有一個正整數權值 \(v_i\)。 設 \(x\) 號結點的子樹內(包含 \(x\) 自身)的所有結點編號為 \(c_1,c_2,\dots,c_k\),定義 \(x\) 的價值為:
\(val(x)=(v_{c_1}+d(c_1,x)) \oplus (v_{c_2}+d(c_2,x)) \oplus \cdots \oplus (v_{c_k}+d(c_k, x))\) 其中 \(d(x,y)\)
表示樹上 \(x\) 號結點與 \(y\) 號結點間唯一簡單路徑所包含的邊數,\(d(x,x) = 0\)\(\oplus\) 表示異或運算。 請你求出 \(\sum\limits_{i=1}^n val(i)\) 的結果。

題解

考慮每個結點對其所有祖先的貢獻。 每個結點建立 trie,初始先只存這個結點的權值,然後從底向上合併每個兒子結點上的 trie,然後再全局加一,完成後統計答案。

參考代碼
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
const int _ = 526010;
int n;
int V[_];
int debug = 0;

namespace trie {
const int MAXH = 21;
int ch[_ * (MAXH + 1)][2], w[_ * (MAXH + 1)], xorv[_ * (MAXH + 1)];
int tot = 0;

int mknode() {
  ++tot;
  ch[tot][1] = ch[tot][0] = w[tot] = xorv[tot] = 0;
  return tot;
}

void maintain(int o) {
  w[o] = xorv[o] = 0;
  if (ch[o][0]) {
    w[o] += w[ch[o][0]];
    xorv[o] ^= xorv[ch[o][0]] << 1;
  }
  if (ch[o][1]) {
    w[o] += w[ch[o][1]];
    xorv[o] ^= (xorv[ch[o][1]] << 1) | (w[ch[o][1]] & 1);
  }
  w[o] = w[o] & 1;
}

void insert(int &o, int x, int dp) {
  if (!o) o = mknode();
  if (dp > MAXH) return (void)(w[o]++);
  insert(ch[o][x & 1], x >> 1, dp + 1);
  maintain(o);
}

int merge(int a, int b) {
  if (!a) return b;
  if (!b) return a;
  w[a] = w[a] + w[b];
  xorv[a] ^= xorv[b];
  ch[a][0] = merge(ch[a][0], ch[b][0]);
  ch[a][1] = merge(ch[a][1], ch[b][1]);
  return a;
}

void addall(int o) {
  swap(ch[o][0], ch[o][1]);
  if (ch[o][0]) addall(ch[o][0]);
  maintain(o);
}
}  // namespace trie

int rt[_];
long long Ans = 0;
vector<int> E[_];

void dfs0(int o) {
  for (int i = 0; i < E[o].size(); i++) {
    int node = E[o][i];
    dfs0(node);
    rt[o] = trie::merge(rt[o], rt[node]);
  }
  trie::addall(rt[o]);
  trie::insert(rt[o], V[o], 0);
  Ans += trie::xorv[rt[o]];
}

int main() {
  n = read();
  for (int i = 1; i <= n; i++) V[i] = read();
  for (int i = 2; i <= n; i++) E[read()].push_back(i);
  dfs0(1);
  printf("%lld", Ans);
  return 0;
}

可持久化字典樹

參見 可持久化字典樹