跳转至

WBLT

前言

Weight Balanced Leafy Tree,下稱 WBLT,是一種平衡樹,比起其它平衡樹主要有實現簡單、常數小的優點。

Weight Balanced Leafy Tree 顧名思義是 Weight Balanced Tree 和 Leafy Tree 的結合。

Weight Balanced Tree 的每個結點儲存這個結點下子樹的大小,並且通過保持左右子樹的大小關係在一定範圍來保證樹高。

Leafy Tree 維護的原始信息僅存儲在樹的 葉子節點 上,而非葉子節點僅用於維護子節點信息和維持數據結構的形態。我們熟知的線段樹就是一種 Leafy Tree。

平衡樹基礎操作

代碼約定

下文中,我們用 ls[x] 表示節點 \(x\) 的左兒子,rs[x] 表示節點 \(x\) 的右兒子,vl[x] 表示節點 \(x\) 的權值,sz[x] 表示節點 \(x\) 及其子樹中葉子節點的個數。

建樹

正如前言中所説的,WBLT 的原始信息僅存儲在葉子節點上。而我們規定每個非葉子節點一定有兩個子節點,這個節點要維護其子節點信息的合併。同時,每個節點還要維護自身及其子樹中葉子節點的數量,用於實現維護平衡。

和大多數的平衡樹一樣,每個非葉子節點的右兒子的權值大於等於左兒子的權值,且在 WBLT 中非葉子節點節點的權值等於右兒子的權值。不難看出每個節點的權值就是其子樹中的最大權值。

這樣聽起來就很像一棵維護區間最大值的動態開點線段樹了,且所有葉子從左到右是遞增的。事實上的建樹操作也與線段樹十分相似,只需要向下遞歸,直至區間長度為 \(1\) 時把要維護的信息放葉子節點上,回溯的時候合併區間信息即可。

代碼實現如下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
/* 添加一個權值為 v 的節點,返回這個節點的編號 */
int add(int v) {
  ++cnt;
  ls[cnt] = rs[cnt] = 0;
  sz[cnt] = 1;
  vl[cnt] = v;
  return cnt;
}

/* 更新節點編號為 x 的節點的信息 */
void pushup(int x) {
  vl[x] = vl[rs[x]];
  sz[x] = sz[ls[x]] + sz[rs[x]];
}

/* 遞歸建樹 */
int build(int l, int r) {
  if (l == r) {
    return add(a[l]);
  }
  int x = add(0);
  int k = l + ((r - l) >> 1);
  ls[x] = build(l, k);
  rs[x] = build(k + 1, r);
  pushup(x);
}

插入和刪除

由於 WBLT 的信息都存儲在葉子節點上,插入和刪除一個元素其實就是增加或減少了一個葉子節點。

對於插入操作,我們類似從根節點開始向下遞歸,直到找到權值大於等於插入元素的權值最小的葉子節點,再新建兩個節點,其中一個用來存儲新插入的值,另一個作為兩個葉子的新父親替代這個最小葉子節點的位置,再將這兩個葉子連接到這個父親上。

例如我們向以下樹中加入一個值為 \(4\) 的元素。

wblt-1

我們首先找到了葉子節點 \(5\),隨後新建了一個非葉子節點 \(D\),並將 \(4\)\(5\) 連接到了 \(D\) 上。

wblt-2

對於刪除,我們考慮上面過程的逆過程。即找到與要刪除的值權值相等的一個葉子節點,將它和它的父親節點刪除,並用其父親的另一個兒子代替父親的位置。

上面提到的建樹也可以通過不斷往樹裏插入節點實現,不過如果這樣做必須要加入一個權值為 \(\infty\)⁡ 的節點作為根,否則會導致插入第一個元素的時候找不到大於自己的葉子節點。

代碼實現:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
/* 將某一節點的全部信息複製到另一節點上 */
void copynode(int x, int y) {
  ls[x] = ls[y];
  rs[x] = rs[y];
  sz[x] = sz[y];
  vl[x] = vl[y];
}

/* 判斷某一節點是否為葉子節點 */
bool leaf(int x) { return !ls[x] || !rs[x]; }

void insert(int v) {
  if (leaf(x)) {
    ls[x] = add(std::min(v, vl[x]));
    rs[x] = add(std::max(v, vl[x]));
    pushup(x);
    maintain(x);
    return;
  }
  if (vl[ls[x]] >= v) {
    insert(ls[x], v);
  } else {
    insert(rs[x], v);
  }
  pushup(x);
  maintain(x);
}

void delete(int x, int v, int fa) {
  if (leaf(x)) {
    if (ls[fa] == x) {
      copynode(fa, rs[fa]);
    } else {
      copynode(fa, ls[fa]);
    }
    pushup(fa);
    return;
  }
  if (vl[ls[x]] >= v) {
    delete (ls[x], v, x);
  } else {
    delete (rs[x], v, x);
  }
  pushup(x);
  maintain(x);
}

維護平衡

類似替罪羊樹地,我們引入重構參數 \(\alpha \in (0, \dfrac{1}{2}]\),我們設一個節點的平衡度 \(\rho\) 為當前節點左子樹大小和節點大小的比值。當一個節點滿足 \(\rho \in[\alpha, 1-\alpha]\) 時,我們稱其為 \(\alpha\)- 平衡的。如果一棵 WBLT 的每一個節點都是 \(\alpha\)- 平衡的,那麼這棵樹的樹高一定能保證是 \(O(\log n)\) 量級的。證明是顯然的,我們從一個葉子節點往父親方向走,每次走到的節點維護的範圍至少擴大到原來的 \(\dfrac{1}{1 - \alpha}\) 倍,那麼樹高就是 \(O(\log_{\frac{1}{1-\alpha}}n) = O(\log n)\) 量級的。

當某個節點不滿足 \(\alpha\)- 平衡時,説明這個節點是失衡的,我們需要重新維護平衡。但是和替罪羊樹不同的是,WBLT 使用旋轉操作維護平衡。旋轉的大致過程為:將過重的兒子的兩個兒子拆下來,一個和過輕的兒子合併,另一個成為一個新的兒子。

我們來舉個例子:

wblt-3

這是一棵十分不平衡的 WBLT,節點 \(A\) 的右兒子顯著地重於左兒子。我們先把右兒子及其兩個兒子和左兒子都拆下來:

wblt-4

然後,我們將 \(1\)\(2\) 兩個節點合併作為 \(A\) 節點的左兒子,將 \(C\) 作為 \(A\) 的右兒子。由於 \(B\) 節點原本並不是葉子節點,因此其並不存儲原始信息,直接刪除就好。

wblt-5

旋轉之後我們的樹就變得十分平衡了。

但是上面的例子中,假設 \(A\) 節點的左子樹過於大,我們把它合併到 \(A\) 的左子樹上之後 \(A\) 的左子樹又會很大,這時 \(A\) 依然可能不平衡。

wblt-6

不失一般性,我們接下來僅討論一個方向上的旋轉,另一方向的旋轉是對稱的。我們不妨設 A 的平衡度為 \(\rho_1\)B 的平衡度為 \(\rho_2\)。那麼我們可以得到旋轉後 A 的平衡度 \(\gamma_1= \rho_1+(1 - \rho_1)\rho_2\)B 的平衡度 \(\gamma_2 =\dfrac{\rho_1}{\rho_1 + (1 - \rho_1)\rho_2}\),推導過程直接將各節點大小用 \(siz_A\) 表示後代入定義式化簡即可,這裏略去。

不難發現僅當 \(\rho_2 \le \dfrac{1 - 2\alpha}{1 - \alpha}\)\(\gamma_1, \gamma_2 \in [\alpha, 1 - \alpha]\)

為了旋轉後仍不平衡的情況出現,我們引入雙旋操作。具體地,我們在較大子樹上做一次相反方向的旋轉操作,然後再維護當前節點的平衡。

wblt-7

類似地定義 \(\rho_3,\gamma_3\),則有 \(\gamma_1=\rho_1+\rho_2\rho_3(1-\rho_1), \gamma_2=\dfrac{\rho_1}{\rho1+(1 - \rho_1)\rho2\rho3}, \gamma_3 = \dfrac{\rho_2(1-\rho_3)}{1-\rho_2\rho_3}\)。可以證明當 \(\alpha < 1- \dfrac{\sqrt2}{2} \approx 0.292\) 時一定有 \(\gamma_1, \gamma_2, \gamma_3 \in [\alpha, 1 - \alpha]\)

實現上,我們在 \(\rho_2 \le \dfrac{1 - 2\alpha}{1 - \alpha}\) 時進行單旋,否則進行雙旋。

代碼實現,這裏取 \(\alpha = 0.25\)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
const double alpha = 0.25;

void rotate(int x, int flag) {
  if (!flag) {
    rs[x] = merge(rs[ls[x]], rs[x]);
    ls[x] = ls[ls[x]];
  } else {
    ls[x] = merge(ls[x], ls[rs[x]]);
    rs[x] = rs[rs[x]];
  }
}

void maintain(int x) {
  if (sz[ls[x]] > sz[rs[x]] * 3) {
    if (sz[rs[ls[x]]] > sz[ls[ls[x]]] * 2) {
      rotate(ls[x], 1);
    }
    rotate(x, 0);
  } else if (sz[rs[x]] > sz[ls[x]] * 3) {
    if (sz[ls[rs[x]]] > sz[rs[rs[x]]] * 2) {
      rotate(rs[x], 0);
    }
    rotate(x, 1);
  }
}

查詢排名

我們發現 WBLT 的形態和線段樹十分相似,因此查詢排名可以使用類似線段樹上二分的方式:如果左子樹的最大值比大於等於待查值就往左兒子跳,否則就向右跳,同時答案加上左子樹的 size

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
int rank(int x, int v) {
  if (leaf(x)) {
    return 1;
  }
  if (vl[ls[x]] >= v) {
    return rank(ls[x], v);
  } else {
    return rank(rs[x], v) + sz[ls[x]];
  }
}

查詢第 k 大的數

依然是利用線段樹上二分的思想,只不過這裏比較的是節點的大小。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
int kth(int x, int v) {
  if (sz[x] == v) {
    return vl[x];
  }
  if (sz[ls[x]] >= v) {
    return kth(ls[x], v);
  } else {
    return kth(rs[x], v - sz[ls[x]]);
  }
}

總結

以上,我們利用 WBLT 完成了平衡樹基本的幾大操作。下面是用 WBLT 實現的 普通平衡樹模板

完整代碼
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
#include <bits/stdc++.h>

typedef long long ll;

const ll MAX = 2e6 + 5;
const ll INF = 0x7fffffff;

ll ans, lst, n, m, t, op, rt, cnt;
ll ls[MAX], rs[MAX], vl[MAX], sz[MAX];

void cp(ll x, ll y) {
  ls[x] = ls[y];
  rs[x] = rs[y];
  sz[x] = sz[y];
  vl[x] = vl[y];
}

ll add(ll v, ll s, ll l, ll r) {
  ++cnt;
  ls[cnt] = l;
  rs[cnt] = r;
  sz[cnt] = s;
  vl[cnt] = v;
  return cnt;
}

ll merge(ll x, ll y) { return add(vl[y], sz[x] + sz[y], x, y); }

void upd(ll x) {
  if (!ls[x]) {
    sz[x] = 1;
    return;
  }
  sz[x] = sz[ls[x]] + sz[rs[x]];
  vl[x] = vl[rs[x]];
}

void rot(int x, int flag) {
  if (!flag) {
    rs[x] = merge(rs[ls[x]], rs[x]);
    ls[x] = ls[ls[x]];
  } else {
    ls[x] = merge(ls[x], ls[rs[x]]);
    rs[x] = rs[rs[x]];
  }
}

void mat(int x) {
  if (sz[ls[x]] > sz[rs[x]] * 3) {
    if (sz[rs[ls[x]]] > sz[ls[ls[x]]] * 2) {
      rot(ls[x], 1);
    }
    rot(x, 0);
  } else if (sz[rs[x]] > sz[ls[x]] * 3) {
    if (sz[ls[rs[x]]] > sz[rs[rs[x]]] * 2) {
      rot(rs[x], 0);
    }
    rot(x, 1);
  }
}

void ins(ll x, ll v) {
  if (!ls[x]) {
    ls[x] = add(std::min(v, vl[x]), 1, 0, 0);
    rs[x] = add(std::max(v, vl[x]), 1, 0, 0);
    upd(x);
    mat(x);
    return;
  }
  if (vl[ls[x]] >= v) {
    ins(ls[x], v);
  } else {
    ins(rs[x], v);
  }
  upd(x);
  mat(x);
  return;
}

void del(ll x, ll v, ll fa) {
  if (!ls[x]) {
    if (vl[ls[fa]] == v) {
      cp(fa, rs[fa]);
    } else if (vl[rs[fa]] == v) {
      cp(fa, ls[fa]);
    }
    return;
  }
  if (vl[ls[x]] >= v) {
    del(ls[x], v, x);
  } else {
    del(rs[x], v, x);
  }
  upd(x);
  mat(x);
  return;
}

ll rnk(ll x, ll v) {
  if (sz[x] == 1) {
    return 1;
  }
  if (vl[ls[x]] >= v) {
    return rnk(ls[x], v);
  } else {
    return rnk(rs[x], v) + sz[ls[x]];
  }
}

ll kth(ll x, ll v) {
  if (sz[x] == v) {
    return vl[x];
  }
  if (sz[ls[x]] >= v) {
    return kth(ls[x], v);
  } else {
    return kth(rs[x], v - sz[ls[x]]);
  }
}

ll pre(ll x) { return kth(rt, rnk(rt, x) - 1); }

ll nxt(ll x) { return kth(rt, rnk(rt, x + 1)); }

int main() {
  scanf("%lld", &m);
  rt = add(INF, 1, 0, 0);
  while (m--) {
    scanf("%lld%lld", &op, &t);
    if (op == 1) {
      ins(rt, t);
    } else if (op == 2) {
      del(rt, t, -1);
    } else if (op == 3) {
      printf("%lld\n", rnk(rt, t));
    } else if (op == 4) {
      printf("%lld\n", kth(rt, t));
    } else if (op == 5) {
      printf("%lld\n", pre(t));
    } else {
      printf("%lld\n", nxt(t));
    }
  }
  return 0;
}