跳转至

可持久化字典樹

引入

可持久化 Trie 的方式和可持久化線段樹的方式是相似的,即每次只修改被添加或值被修改的節點,而保留沒有被改動的節點,在上一個版本的基礎上連邊,使最後每個版本的 Trie 樹的根遍歷所能分離出的 Trie 樹都是完整且包含全部信息的。

大部分的可持久化 Trie 題中,Trie 都是以 01-Trie 的形式出現的。

例題 最大異或和

對一個長度為 \(n\) 的數組 \(a\) 維護以下操作:

  1. 在數組的末尾添加一個數 \(x\),數組的長度 \(n\) 自增 \(1\)
  2. 給出查詢區間 \([l,r]\) 和一個值 \(k\),求當 \(l\le p\le r\) 時,\(k \oplus \bigoplus^{n}_{i=p} a_i\) 的最大值。

過程

這個求的值可能有些麻煩,利用常用的處理連續異或的方法,記 \(s_x=\bigoplus_{i=1}^x a_i\),則原式等價於 \(s_{p-1}\oplus s_n\oplus k\),觀察到 \(s_n \oplus k\) 在查詢的過程中是固定的,題目的查詢變化為查詢在區間 \([l-1,r-1]\) 中異或定值(\(s_n\oplus k\))的最大值。

繼續按類似於可持久化線段樹的思路,考慮每次的查詢都查詢整個區間。我們只需把這個區間建一棵 Trie 樹,將這個區間中的每個樹都加入這棵 Trie 中,查詢的時候,儘量往與當前位不相同的地方跳。

查詢區間,只需要利用前綴和和差分的思想,用兩棵前綴 Trie 樹(也就是按順序添加數的兩個歷史版本)相減即得到該區間的 Trie 樹。再利用動態開點的思想,不添加沒有計算過的點,以減少空間佔用。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include <algorithm>
#include <cstdio>
#include <cstring>
using namespace std;
const int maxn = 600010;
int n, q, a[maxn], s[maxn], l, r, x;
char op;

struct Trie {
  int cnt, rt[maxn], ch[maxn * 33][2], val[maxn * 33];

  void insert(int o, int lst, int v) {
    for (int i = 28; i >= 0; i--) {
      val[o] = val[lst] + 1;  // 在原版本的基础上更新
      if ((v & (1 << i)) == 0) {
        if (!ch[o][0]) ch[o][0] = ++cnt;
        ch[o][1] = ch[lst][1];
        o = ch[o][0];
        lst = ch[lst][0];
      } else {
        if (!ch[o][1]) ch[o][1] = ++cnt;
        ch[o][0] = ch[lst][0];
        o = ch[o][1];
        lst = ch[lst][1];
      }
    }
    val[o] = val[lst] + 1;
    // printf("%d\n",o);
  }

  int query(int o1, int o2, int v) {
    int ret = 0;
    for (int i = 28; i >= 0; i--) {
      // printf("%d %d %d\n",o1,o2,val[o1]-val[o2]);
      int t = ((v & (1 << i)) ? 1 : 0);
      if (val[ch[o1][!t]] - val[ch[o2][!t]])
        ret += (1 << i), o1 = ch[o1][!t],
                         o2 = ch[o2][!t];  // 尽量向不同的地方跳
      else
        o1 = ch[o1][t], o2 = ch[o2][t];
    }
    return ret;
  }
} st;

int main() {
  scanf("%d%d", &n, &q);
  for (int i = 1; i <= n; i++) scanf("%d", a + i), s[i] = s[i - 1] ^ a[i];
  for (int i = 1; i <= n; i++)
    st.rt[i] = ++st.cnt, st.insert(st.rt[i], st.rt[i - 1], s[i]);
  while (q--) {
    scanf(" %c", &op);
    if (op == 'A') {
      n++;
      scanf("%d", a + n);
      s[n] = s[n - 1] ^ a[n];
      st.rt[n] = ++st.cnt;
      st.insert(st.rt[n], st.rt[n - 1], s[n]);
    }
    if (op == 'Q') {
      scanf("%d%d%d", &l, &r, &x);
      l--;
      r--;
      if (l == 0)
        printf("%d\n", max(s[n] ^ x, st.query(st.rt[r], st.rt[0], s[n] ^ x)));
      else
        printf("%d\n", st.query(st.rt[r], st.rt[l - 1], s[n] ^ x));
    }
  }
  return 0;
}