跳转至

後綴平衡樹

定義

後綴之間的大小由字典序定義,後綴平衡樹就是一個維護這些後綴順序的平衡樹,即字符串 \(T\) 的後綴平衡樹是 \(T\) 所有後綴的有序集合。後綴平衡樹上的一個節點相當於原字符串的一個後綴。

特別地,後綴平衡樹的中序遍歷即為後綴數組。

構造過程

對長度為 \(n\) 的字符串 \(T\) 建立其後綴平衡樹,考慮逆序將其後綴加入後綴平衡樹。

記後綴平衡樹維護的集合為 \(X\),當前添加的後綴為 \(S\),則添加下一個後綴就是向 \(X\) 中加入 \(\texttt{c}S\)(亦可理解為後綴平衡樹維護的字符串為 \(S\),下一步往 \(S\) 前加入一個字符 \(\texttt{c}\))。這一操作其實就是向平衡樹中插入節點。

這裏使用期望樹高為 \(O(\log n)\) 的平衡樹,例如替罪羊樹或 Treap 等。

做法 1

插入時,暴力比較兩個後綴之間的大小關係,從而判斷之後是往哪一個子樹添加。這樣子,單次插入至多比較 \(O(\log n)\) 次,單次比較的時間複雜度至多為 \(O(n)\),一共 \(O(n\log n)\)

一共會插入 \(n\) 次,所以該做法的時間複雜度存在上界 \(O(n^2 \log n)\)

做法 2

注意到 \(\texttt{c}S\)\(S\) 的區別僅在於 \(\texttt{c}\),且 \(S\) 已經屬於 \(X\) 了,可以利用這一點來優化插入操作。

假設當前要比較 \(\texttt{c}S\)\(A\) 兩個字符串的大小,且 \(A, S \in X\)。每次比較時,首先比較兩串的首字符。若首字符不等,則兩串的大小關係就已經確定了;若首字符相等,那麼就只需要判斷去除首字符後兩字符串的大小關係。而兩串去除首字符後都已經屬於 \(X\) 了,這時候可以藉助平衡樹 \(O(\log n)\) 求排名的操作來完成後續的比較。這樣,單次插入的操作至多 \(O(\log^2 n)\)

一共會插入 \(n\) 次,所以該做法的時間複雜度存在上界 \(O(n \log^2 n)\)

做法 3

根據做法 2,如果能夠 \(O(1)\) 判斷平衡樹中兩個節點之間的大小關係,那麼就可以在 \(O(n \log n)\) 的時間內完成後綴平衡樹的構造。

\(val_i\) 表示節點 \(i\) 的值。如果在建平衡樹時,每個節點多維護一個標記 \(tag_i\),使得若 \(tag_i > tag_j \iff val_i > val_j\),那麼就可以根據 \(tag_i\) 的大小 \(O(1)\) 判斷平衡樹中兩個節點的大小。

不妨令平衡樹中每個節點對應一個實數區間,令根節點對應 \((0, 1)\)。對於節點 \(i\),記其對應的實數區間為 \((l, r)\),則 \(tag_i = \frac{l + r}{2}\),其左子樹對應實數區間 \((l, tag_i)\),其右子樹對應實數區間 \((tag_i, r)\)。易證 \(tag_i\) 滿足上述要求。

由於使用了期望樹高為 \(O(\log n)\) 的平衡樹,所以精度是有一定保證的。實際實現時也可以用一個較大的區間來做,例如讓根對應 \((0, 10^{18})\)

做法 4

其實可以先構建出後綴數組,然後再根據後綴數組構建後綴平衡樹。這樣做的複雜度瓶頸在於後綴數組的構建複雜度或者所用平衡樹一次性插入 \(n\) 個元素的複雜度。

刪除操作

假設當前添加的後綴為 \(\texttt{c}S\),上一個添加的後綴為 \(S\)。後綴平衡樹還支持刪除後綴 \(\texttt{c}S\) 的操作(亦可理解為後綴平衡樹維護的字符串為 \(\texttt{c}S\),將開頭的 \(\texttt{c}\) 刪除)。

類似於插入操作,藉助平衡樹的刪除節點操作可以完成刪除 \(\texttt{c}S\) 的操作。

後綴平衡樹的優點

  • 後綴平衡樹的思路比較清晰,相比後綴自動機等後綴結構更好理解,會寫平衡樹就能寫。
  • 後綴平衡樹的複雜度不依賴於字符集的大小
  • 後綴平衡樹支持在字符串開頭刪除一個字符
  • 如果使用支持可持久化的平衡樹,那麼後綴平衡樹也能可持久化

例題

P3809【模板】後綴排序

後綴數組的模板題,建出後綴平衡樹之後,通過中序遍歷得到後綴數組。

SGT 版本的參考代碼
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
#include <bits/stdc++.h>
using namespace std;

const int N = 1e6 + 5;
const double INF = 1e18;

int n, m, sa[N];
char t[N];

// SuffixBST(SGT Ver)

// 顺序加入,查询时将询问串翻转
// 以i开始的后缀,对应节点的编号为i
const double alpha = 0.75;
int root;
int sz[N], L[N], R[N];
double tag[N];
int buffer_size, buffer[N];

bool cmp(int x, int y) {
  if (t[x] != t[y]) return t[x] < t[y];
  return tag[x + 1] < tag[y + 1];
}

void init() { root = 0; }

void new_node(int& rt, int p, double lv, double rv) {
  rt = p;
  sz[rt] = 1;
  tag[rt] = (lv + rv) / 2;
  L[rt] = R[rt] = 0;
}

void push_up(int x) {
  if (!x) return;
  sz[x] = sz[L[x]] + 1 + sz[R[x]];
}

bool balance(int rt) { return alpha * sz[rt] > max(sz[L[rt]], sz[R[rt]]); }

void flatten(int rt) {
  if (!rt) return;
  flatten(L[rt]);
  buffer[++buffer_size] = rt;
  flatten(R[rt]);
}

void build(int& rt, int l, int r, double lv, double rv) {
  if (l > r) {
    rt = 0;
    return;
  }
  int mid = (l + r) >> 1;
  double mv = (lv + rv) / 2;

  rt = buffer[mid];
  tag[rt] = mv;
  build(L[rt], l, mid - 1, lv, mv);
  build(R[rt], mid + 1, r, mv, rv);
  push_up(rt);
}

void rebuild(int& rt, double lv, double rv) {
  buffer_size = 0;
  flatten(rt);
  build(rt, 1, buffer_size, lv, rv);
}

void insert(int& rt, int p, double lv, double rv) {
  if (!rt) {
    new_node(rt, p, lv, rv);
    return;
  }

  if (cmp(p, rt))
    insert(L[rt], p, lv, tag[rt]);
  else
    insert(R[rt], p, tag[rt], rv);

  push_up(rt);
  if (!balance(rt)) rebuild(rt, lv, rv);
}

void inorder(int rt) {
  if (!rt) return;
  inorder(L[rt]);
  sa[++m] = rt;
  inorder(R[rt]);
}

void solve(int Case) {
  scanf("%s", t + 1);
  n = strlen(t + 1);

  init();
  for (int i = n; i >= 1; --i) {
    insert(root, i, 0, INF);
  }

  // 后缀平衡树的中序遍历即为后缀数组
  m = 0;
  inorder(root);

  for (int i = 1; i <= n; ++i) printf("%d ", sa[i]);
  printf("\n");
}

int main() {
  solve(1);
  return 0;
}

P6164【模板】後綴平衡樹

題意

給定初始字符串 \(s\)\(q\) 個操作:

  1. 在當前字符串的後面插入若干個字符。
  2. 在當前字符串的後面刪除若干個字符。
  3. 詢問字符串 \(t\) 作為連續子串在當前字符串中出現了幾次?

題目 強制在線,字符串變化長度以及初始長度 \(\le 8 \times 10^5\)\(q \le 10^5\),詢問的總長度 \(\le 3 \times 10^6\)

對於操作 1 和操作 2,由於後綴平衡樹維護頭插和頭刪操作比較方便,所以想到把尾插和尾刪操作搞成頭插和頭刪。這裏如果維護 \(s\) 的反串的後綴平衡樹,而非 \(s\) 的後綴平衡樹,就可以完成上述轉換。平衡樹的添加和刪除都是 \(O(\log n)\) 的,所以添加或者刪除一個字符的時間複雜度為 \(O(\log n)\)。記添加和刪除的總字符數為 \(N\),那麼這一部分總的時間複雜度為 \(O(N \log n)\)

對於操作 3,\(t\) 的出現次數等於以 \(t\) 為前綴的後綴數量,而以 \(t\) 為前綴的後綴數量等於其後繼的排名減去其前驅的排名。在 \(t\) 後面加入一個極大的字符,就可以構造出 \(t\) 的一個後繼。將 \(t\) 的最後一個字符減小 1,就可以構造出 \(t\) 的一個前驅。

現在要查詢某一個串 \(t\) 在後綴平衡樹中排名,由於不能保證 \(t\) 在後綴平衡樹中出現過,所以每次只能暴力比較字符串大小。單次比較的時間複雜度為 \(O(|t|)\),每次查詢至多比較 \(O(\log n)\) 次,所以單次查詢的複雜度為 \(O(|t|\log n)\)。記所有詢問串的長度和為 \(L\),那麼這一部分總的時間複雜度為 \(O(L \log n)\)

SGT 版本的參考代碼
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
#include <bits/stdc++.h>
using namespace std;

const int N = 8e5 + 5;
const double INF = 1e18;

void decode(char* s, int len, int mask) {
  for (int i = 0; i < len; ++i) {
    mask = (mask * 131 + i) % len;
    swap(s[i], s[mask]);
  }
}

int q, n, na;
char a[N], t[N];

// SuffixBST(SGT Ver)

// 顺序加入,查询时将询问串翻转
// 以i结束的前缀,对应节点的编号为i
// 注意:不能写懒惰删除,否则可能会破坏树的结构
const double alpha = 0.75;
int root;
int sz[N], L[N], R[N];
double tag[N];
int buffer_size, buffer[N];

bool cmp(int x, int y) {
  if (t[x] != t[y]) return t[x] < t[y];
  return tag[x - 1] < tag[y - 1];
}

void init() { root = 0; }

void new_node(int& rt, int p, double lv, double rv) {
  rt = p;
  sz[rt] = 1;
  tag[rt] = (lv + rv) / 2;
  L[rt] = R[rt] = 0;
}

void push_up(int x) {
  if (!x) return;
  sz[x] = sz[L[x]] + 1 + sz[R[x]];
}

bool balance(int rt) { return alpha * sz[rt] > max(sz[L[rt]], sz[R[rt]]); }

void flatten(int rt) {
  if (!rt) return;
  flatten(L[rt]);
  buffer[++buffer_size] = rt;
  flatten(R[rt]);
}

void build(int& rt, int l, int r, double lv, double rv) {
  if (l > r) {
    rt = 0;
    return;
  }
  int mid = (l + r) >> 1;
  double mv = (lv + rv) / 2;

  rt = buffer[mid];
  tag[rt] = mv;
  build(L[rt], l, mid - 1, lv, mv);
  build(R[rt], mid + 1, r, mv, rv);
  push_up(rt);
}

void rebuild(int& rt, double lv, double rv) {
  buffer_size = 0;
  flatten(rt);
  build(rt, 1, buffer_size, lv, rv);
}

void insert(int& rt, int p, double lv, double rv) {
  if (!rt) {
    new_node(rt, p, lv, rv);
    return;
  }

  if (cmp(p, rt))
    insert(L[rt], p, lv, tag[rt]);
  else
    insert(R[rt], p, tag[rt], rv);

  push_up(rt);
  if (!balance(rt)) rebuild(rt, lv, rv);
}

void remove(int& rt, int p, double lv, double rv) {
  if (!rt) return;

  if (rt == p) {
    if (!L[rt] || !R[rt]) {
      rt = (L[rt] | R[rt]);
      rebuild(rt, lv, rv);
    } else {
      // 找到rt的前驱来替换rt
      int nrt = L[rt];
      while (R[nrt]) {
        nrt = R[nrt];
      }
      remove(L[rt], nrt, lv, tag[rt]);
      L[nrt] = L[rt];
      R[nrt] = R[rt];
      rt = nrt;
      tag[rt] = (lv + rv) / 2;
    }
  } else {
    double mv = (lv + rv) / 2;
    if (cmp(p, rt))
      remove(L[rt], p, lv, mv);
    else
      remove(R[rt], p, mv, rv);
  }

  push_up(rt);
  if (!balance(rt)) rebuild(rt, lv, rv);
}

bool cmp1(char* s, int len, int p) {
  for (int i = 1; i <= len; ++i, --p) {
    if (s[i] < t[p]) return true;
    if (s[i] > t[p]) return false;
  }
  return false;
}

int query(int rt, char* s, int len) {
  if (!rt) return 0;
  if (cmp1(s, len, rt))
    return query(L[rt], s, len);
  else
    return sz[L[rt]] + 1 + query(R[rt], s, len);
}

void solve() {
  n = 0;
  scanf("%d", &q);
  init();

  scanf("%s", a + 1);
  na = strlen(a + 1);
  for (int i = 1; i <= na; ++i) {
    t[++n] = a[i];
    insert(root, n, 0, INF);
  }

  int mask = 0;
  char op[10];
  for (int i = 1; i <= q; ++i) {
    scanf("%s", op);

    // 三种情况分别处理

    if (op[0] == 'A') {  // ADD
      scanf("%s", a + 1);
      na = strlen(a + 1);
      decode(a + 1, na, mask);

      for (int i = 1; i <= na; ++i) {
        t[++n] = a[i];
        insert(root, n, 0, INF);
      }
    } else if (op[0] == 'D') {  // DEL
      int x;
      scanf("%d", &x);
      while (x) {
        remove(root, n, 0, INF);
        --n;
        --x;
      }
    } else if (op[0] == 'Q') {  // QUERY
      scanf("%s", a + 1);
      na = strlen(a + 1);
      decode(a + 1, na, mask);

      reverse(a + 1, a + 1 + na);

      a[na + 1] = 'Z' + 1;
      a[na + 2] = 0;
      int ans = query(root, a, na + 1);

      --a[na];
      ans -= query(root, a, na + 1);

      printf("%d\n", ans);
      mask ^= ans;
    }
  }
}

int main() {
  solve();
  return 0;
}

參考資料

  • 陳立傑 -《重量平衡樹和後綴平衡樹在信息學奧賽中的應用》