跳转至

樹分治

點分治

點分治適合處理大規模的樹上路徑信息問題。

例題 1 Luogu P3806【模板】點分治 1

給定一棵有 \(n\) 個點的帶邊權樹,\(m\) 次詢問,每次詢問給出 \(k\),詢問樹上距離為 \(k\) 的點對是否存在。

\(n\le 10000,m\le 100,k\le 10000000\)

我們先隨意選擇一個節點作為根節點 \(\mathit{rt}\),所有完全位於其子樹中的路徑可以分為兩種,一種是經過當前根節點的路徑,一種是不經過當前根節點的路徑。對於經過當前根節點的路徑,又可以分為兩種,一種是以根節點為一個端點的路徑,另一種是兩個端點都不為根節點的路徑。而後者又可以由兩條屬於前者鏈合併得到。所以,對於枚舉的根節點 \(rt\),我們先計算在其子樹中且經過該節點的路徑對答案的貢獻,再遞歸其子樹對不經過該節點的路徑進行求解。

在本題中,對於經過根節點 \(\mathit{rt}\) 的路徑,我們先枚舉其所有子節點 \(\mathit{ch}\),以 \(\mathit{ch}\) 為根計算 \(\mathit{ch}\) 子樹中所有節點到 \(\mathit{rt}\) 的距離。記節點 \(i\) 到當前根節點 \(rt\) 的距離為 \(\mathit{dist}_i\)\(\mathit{tf}_{d}\) 表示之前處理過的子樹中是否存在一個節點 \(v\) 使得 \(\mathit{dist}_v=d\)。若一個詢問的 \(k\) 滿足 \(tf_{k-\mathit{dist}_i}=true\),則存在一條長度為 \(k\) 的路徑。在計算完 \(\mathit{ch}\) 子樹中所連的邊能否成為答案後,我們將這些新的距離加入 \(\mathit{tf}\) 數組中。

注意在清空 \(\mathit{tf}\) 數組的時候不能直接用 memset,而應將之前佔用過的 \(\mathit{tf}\) 位置加入一個隊列中,進行清空,這樣才能保證時間複雜度。

點分治過程中,每一層的所有遞歸過程合計對每個點處理一次,假設共遞歸 \(h\) 層,則總時間複雜度為 \(O(hn)\)

若我們 每次選擇子樹的重心作為根節點,可以保證遞歸層數最少,時間複雜度為 \(O(n\log n)\)

請注意在重新選擇根節點之後一定要重新計算子樹的大小,否則一點看似微小的改動就可能會使時間複雜度錯誤或正確性難以保證。

參考代碼
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <queue>
using namespace std;
const int maxn = 20010;
const int inf = 2e9;
int n, m, a, b, c, q[maxn], rt, siz[maxn], maxx[maxn], dist[maxn];
int cur, h[maxn], nxt[maxn], p[maxn], w[maxn];
bool tf[10000010], ret[maxn], vis[maxn];

void add_edge(int x, int y, int z) {
  cur++;
  nxt[cur] = h[x];
  h[x] = cur;
  p[cur] = y;
  w[cur] = z;
}

int sum;

void calcsiz(int x, int fa) {
  siz[x] = 1;
  maxx[x] = 0;
  for (int j = h[x]; j; j = nxt[j])
    if (p[j] != fa && !vis[p[j]]) {
      calcsiz(p[j], x);
      maxx[x] = max(maxx[x], siz[p[j]]);
      siz[x] += siz[p[j]];
    }
  maxx[x] = max(maxx[x], sum - siz[x]);
  if (maxx[x] < maxx[rt]) rt = x;
}

int dd[maxn], cnt;

void calcdist(int x, int fa) {
  dd[++cnt] = dist[x];
  for (int j = h[x]; j; j = nxt[j])
    if (p[j] != fa && !vis[p[j]])
      dist[p[j]] = dist[x] + w[j], calcdist(p[j], x);
}

queue<int> tag;

void dfz(int x, int fa) {
  tf[0] = true;
  tag.push(0);
  vis[x] = true;
  for (int j = h[x]; j; j = nxt[j])
    if (p[j] != fa && !vis[p[j]]) {
      dist[p[j]] = w[j];
      calcdist(p[j], x);
      for (int k = 1; k <= cnt; k++)
        for (int i = 1; i <= m; i++)
          if (q[i] >= dd[k]) ret[i] |= tf[q[i] - dd[k]];
      for (int k = 1; k <= cnt; k++)
        if (dd[k] < 10000010) tag.push(dd[k]), tf[dd[k]] = true;
      cnt = 0;
    }
  while (!tag.empty()) tf[tag.front()] = false, tag.pop();
  for (int j = h[x]; j; j = nxt[j])
    if (p[j] != fa && !vis[p[j]]) {
      sum = siz[p[j]];
      rt = 0;
      maxx[rt] = inf;
      calcsiz(p[j], x);
      calcsiz(rt, -1);
      dfz(rt, x);
    }
}

int main() {
  scanf("%d%d", &n, &m);
  for (int i = 1; i < n; i++)
    scanf("%d%d%d", &a, &b, &c), add_edge(a, b, c), add_edge(b, a, c);
  for (int i = 1; i <= m; i++) scanf("%d", q + i);
  rt = 0;
  maxx[rt] = inf;
  sum = n;
  calcsiz(1, -1);
  calcsiz(rt, -1);
  dfz(rt, -1);
  for (int i = 1; i <= m; i++)
    if (ret[i])
      printf("AYE\n");
    else
      printf("NAY\n");
  return 0;
}
例題 2 Luogu P4178 Tree

給定一棵有 \(n\) 個點的帶權樹,給出 \(k\),詢問樹上距離小於等於 \(k\) 的點對數量。

\(n\le 40000,k\le 20000,w_i\le 1000\)

由於這裏查詢的是樹上距離為 \([0,k]\) 的點對數量,所以我們用線段樹來支持維護和查詢。

參考代碼
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <queue>
#define int long long
using namespace std;
const int maxn = 2000010;
const int inf = 2e9;
int n, a, b, c, q, rt, siz[maxn], maxx[maxn], dist[maxn];
int cur, h[maxn], nxt[maxn], p[maxn], w[maxn], ret;
bool vis[maxn];

void add_edge(int x, int y, int z) {
  cur++;
  nxt[cur] = h[x];
  h[x] = cur;
  p[cur] = y;
  w[cur] = z;
}

int sum;

void calcsiz(int x, int fa) {
  siz[x] = 1;
  maxx[x] = 0;
  for (int j = h[x]; j; j = nxt[j])
    if (p[j] != fa && !vis[p[j]]) {
      calcsiz(p[j], x);
      maxx[x] = max(maxx[x], siz[p[j]]);
      siz[x] += siz[p[j]];
    }
  maxx[x] = max(maxx[x], sum - siz[x]);
  if (maxx[x] < maxx[rt]) rt = x;
}

int dd[maxn], cnt;

void calcdist(int x, int fa) {
  dd[++cnt] = dist[x];
  for (int j = h[x]; j; j = nxt[j])
    if (p[j] != fa && !vis[p[j]])
      dist[p[j]] = dist[x] + w[j], calcdist(p[j], x);
}

queue<int> tag;

struct segtree {
  int cnt, rt, lc[maxn], rc[maxn], sum[maxn];

  void clear() {
    while (!tag.empty()) update(rt, 1, 20000000, tag.front(), -1), tag.pop();
    cnt = 0;
  }

  void print(int o, int l, int r) {
    if (!o || !sum[o]) return;
    if (l == r) {
      printf("%lld %lld\n", l, sum[o]);
      return;
    }
    int mid = (l + r) >> 1;
    print(lc[o], l, mid);
    print(rc[o], mid + 1, r);
  }

  void update(int& o, int l, int r, int x, int v) {
    if (!o) o = ++cnt;
    if (l == r) {
      sum[o] += v;
      if (!sum[o]) o = 0;
      return;
    }
    int mid = (l + r) >> 1;
    if (x <= mid)
      update(lc[o], l, mid, x, v);
    else
      update(rc[o], mid + 1, r, x, v);
    sum[o] = sum[lc[o]] + sum[rc[o]];
    if (!sum[o]) o = 0;
  }

  int query(int o, int l, int r, int ql, int qr) {
    if (!o) return 0;
    if (r < ql || l > qr) return 0;
    if (ql <= l && r <= qr) return sum[o];
    int mid = (l + r) >> 1;
    return query(lc[o], l, mid, ql, qr) + query(rc[o], mid + 1, r, ql, qr);
  }
} st;

void dfz(int x, int fa) {
  // tf[0]=true;tag.push(0);
  st.update(st.rt, 1, 20000000, 1, 1);
  tag.push(1);
  vis[x] = true;
  for (int j = h[x]; j; j = nxt[j])
    if (p[j] != fa && !vis[p[j]]) {
      dist[p[j]] = w[j];
      calcdist(p[j], x);
      for (int k = 1; k <= cnt; k++)
        if (q - dd[k] >= 0)
          ret += st.query(st.rt, 1, 20000000, max(0ll, 1 - dd[k]) + 1,
                          max(0ll, q - dd[k]) + 1);
      for (int k = 1; k <= cnt; k++)
        st.update(st.rt, 1, 20000000, dd[k] + 1, 1), tag.push(dd[k] + 1);
      cnt = 0;
    }
  st.clear();
  for (int j = h[x]; j; j = nxt[j])
    if (p[j] != fa && !vis[p[j]]) {
      sum = siz[p[j]];
      rt = 0;
      maxx[rt] = inf;
      calcsiz(p[j], x);
      calcsiz(rt, -1);
      dfz(rt, x);
    }
}

signed main() {
  scanf("%lld", &n);
  for (int i = 1; i < n; i++)
    scanf("%lld%lld%lld", &a, &b, &c), add_edge(a, b, c), add_edge(b, a, c);
  scanf("%lld", &q);
  rt = 0;
  maxx[rt] = inf;
  sum = n;
  calcsiz(1, -1);
  calcsiz(rt, -1);
  dfz(rt, -1);
  printf("%lld\n", ret);
  return 0;
}
例題 3 Luogu P2664 樹上游戲

一棵每個節點都給定顏色的樹,定義 \(s(i,j)\)\(\mathit{i}\)\(\mathit{j}\) 的顏色數量,\(\mathit{sum_{i}}=\sum_{j=1}^n s(i,j)\)。對所有的 \(1\leq i\leq n\),求 \(sum_i\)。(\(1 \le n, c_i \le 10^5\)

這道題很考驗對點分治思想的理解和應用,適合作為點分治的難度較高的例題和練習題。

首先,我們需要想明白一個轉化。題目定義 \(\mathit{sum_i}\)\(i\) 到所有節點路徑上的顏色數量之和,可是如果用這個方法,在點分治中是不好統計答案的,因為這樣很難合併從當前根出發的兩棵子樹的信息。所以我們想到將 \(\mathit{sum_i}\) 的意義轉化。對於每個顏色 \(j\), 其中一個端點為 \(i\) 且含有顏色 \(j\) 的路徑數量記為 \(\mathit{cnt_j}\)\(\mathit{sum_i}\) 其實就是 \(\sum \mathit{cnt_j}\)。這一步轉化其實就是換了個觀察對象,考慮的是每個顏色對 \(\mathit{sum_i}\) 的 貢獻。而 \(\mathit{cnt_j}\) 其實很好處理出來,只需要每遇到一個新顏色,就 \(\mathit{cnt_{col_u}}+=\mathit{size_u}\) 即可,其中 \(\mathit{size_u}\) 為 u 的子樹大小,意味着這個子樹裏的所有節點都在這個顏色上對 \(u\) 的答案有一個貢獻。

考慮到點分治過程中,我們只需要分別考慮統計:

  1. 子樹中以當前根節點為端點的路徑對根的貢獻
  2. lca 為當前根節點的路徑對子樹內每個點的貢獻

1 部分比較好辦,由於點分治中,遞歸層數不超過 \(\log{n}\),每一層我們都可以遍歷全部子樹,這個時候就可以使用 \(\mathit{sum_i}\) 的定義式來在遍歷子樹的過程中順便統計了。

而針對 2 部分,設當前根節點 \(u\) 的一個子節點為 \(d\),\(d\) 的子樹裏任取一個點為 \(v\),那麼 \(v\) 的答案可以分為兩部分:

  1. \((u, v)\) 路徑上出現過的顏色,數量設為 \(\mathit{num}\)\(u\) 除了 \(d\) 以外的其他所有子樹的總大小設為 \(\mathit{siz1}\), 那麼這些出現過的顏色對 \(v\) 的答案貢獻為 \(\mathit{num}\times \mathit{siz1}\)
  2. \((u, v)\) 路徑上沒有出現過的顏色 \(j\),它們的貢獻來自於 \(u\) 除了 \(d\) 以外的其他所有子樹的 \(\mathit{cnt_j}\),這部分答案為 \(\sum_{j \notin (u, v)} \mathit{cnt_j}\)

以上是全部統計思路,實現細節詳見參考代碼。

參考代碼
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#include <bits/stdc++.h>
using namespace std;
#define rep(i, a, b) for (int i = (a); i <= (b); ++i)
const int N = 200005;
int h[N], nxt[N * 2], to[N * 2], c[N], gr;
#define il inline

il void tu(int x, int y) { to[++gr] = y, nxt[gr] = h[x], h[x] = gr; }

typedef long long ll;
int n, nn, siz[N], mn, rt;
bool vis[N];

void get_root(int u, int f) {
  siz[u] = 1;
  int mx = 0;
  for (int i = h[u]; i; i = nxt[i]) {
    int d = to[i];
    if (vis[d] || d == f) continue;
    get_root(d, u);
    siz[u] += siz[d];
    mx = max(mx, siz[d]);
  }
  mx = max(mx, nn - siz[u]);
  if (mx < mn) mn = mx, rt = u;
}

ll ans[N], sum;
int cnt[N], v[N];
// sum实时统计的是cnt[i]的和
int nowrt;

void get_dis(int u, int f, int now) {  // now为当前树链上的颜色数量(不含u)
  siz[u] = 1;
  if (!v[c[u]]) {
    sum -= cnt[c[u]];  // 减去在之前子树中已经出现过的颜色信息
    now++;
  }
  v[c[u]]++;
  ans[u] += sum + now * siz[nowrt];  // 统计过u点的路径对u的贡献
  for (int i = h[u]; i; i = nxt[i]) {
    int d = to[i];
    if (d == f || vis[d]) continue;
    get_dis(d, u, now);
    siz[u] += siz[d];
  }
  v[c[u]]--;
  if (!v[c[u]]) {
    sum += cnt[c[u]];  // 回溯
  }
}

void get_cnt(int u, int f) {
  if (!v[c[u]]) {
    cnt[c[u]] += siz[u];
    sum += siz[u];  // 将刚遍历过的子树的信息整合到cnt[i]和sum上去
  }
  v[c[u]]++;
  for (int i = h[u]; i; i = nxt[i]) {
    int d = to[i];
    if (vis[d] || d == f) continue;
    get_cnt(d, u);
  }
  v[c[u]]--;
}

void clear(int u, int f, int now) {
  if (!v[c[u]]) now++;
  v[c[u]]++;
  ans[u] -= now;
  ans[nowrt] += now;
  for (int i = h[u]; i; i = nxt[i]) {
    int d = to[i];
    if (vis[d] || d == f) continue;
    clear(d, u, now);
  }
  v[c[u]]--;
  cnt[c[u]] = 0;
}

void clear2(int u, int f) {
  cnt[c[u]] = 0;
  for (int i = h[u]; i; i = nxt[i]) {
    int d = to[i];
    if (vis[d] || d == f) continue;
    clear2(d, u);
  }
}

int son[N];

void divid(int u) {
  vis[u] = 1;
  int tot = 0;
  nowrt = u;
  ans[u]++;
  for (int i = h[u]; i; i = nxt[i]) {
    if (vis[to[i]]) continue;
    son[++tot] = to[i];
  }
  siz[u] = sum = cnt[c[u]] = 1;
  v[c[u]]++;
  rep(i, 1, tot) {  // 统计每个子树和它之前的所有子树中节点组合产生的贡献
    int d = son[i];
    get_dis(d, u, 0);
    get_cnt(d, u);
    siz[u] += siz[d];
    cnt[c[u]] += siz[d];
    sum += siz[d];
  }
  clear2(u, 0);  // 清空数组,记得不可以用memset
  siz[u] = sum = cnt[c[u]] = 1;
  for (int i = tot; i >= 1;
       --i) {  // 统计每个子树和它之后的所有子树中节点组合产生的贡献
    int d = son[i];
    get_dis(d, u, 0);
    get_cnt(d, u);
    siz[u] += siz[d];
    cnt[c[u]] += siz[d];
    sum += siz[d];
  }
  v[c[u]]--;
  clear(u, 0, 0);                      // 清空的同时统计答案
  for (int i = h[u]; i; i = nxt[i]) {  // 继续向下进行点分治
    int d = to[i];
    if (vis[d]) continue;
    nn = siz[d], mn = n + 1, rt = 0;
    get_root(d, u);
    divid(rt);
  }
}

int main() {
  scanf("%d", &n);
  int u, v;
  rep(i, 1, n) scanf("%d", &c[i]);
  rep(i, 2, n) scanf("%d%d", &u, &v), tu(u, v), tu(v, u);
  rt = 0, nn = n, mn = n + 1;
  get_root(1, 0);
  divid(rt);
  rep(i, 1, n) printf("%lld\n", ans[i]);
  return 0;
}

邊分治

與上面的點分治類似,我們選取一條邊,把樹儘量均勻地分成兩部分(使邊連接的兩個子樹的 \(\mathit{size}\) 儘量接近)。然後遞歸處理左右子樹,統計信息。

但是這是不行的,考慮一個菊花圖:

菊花圖

我們發現當一個點下有多個 \(\mathit{size}\) 接近的兒子時,應用邊分治的時間複雜度是無法接受的。

如果這個圖是個二叉樹,就可以避免上面菊花圖中應用邊分治的弊端。因此我們考慮把一個多叉樹轉化成二叉樹。

顯然,我們只需像線段樹那樣建樹就可以了。就像這樣

建樹

新建出來的點根據題目要求給予恰當的信息即可。例如:統計路徑長度時,將原邊邊權賦為 \(1\), 將新建的邊邊權賦為 \(0\) 即可。

分析複雜度,發現最多會增加 \(O(n)\) 個點,則總複雜度為 \(O(n\log n)\)

幾乎所有點分治的題邊分都能做(常數上有差距,但是不卡),所以就不放例題了。

點分樹

點分樹是通過更改原樹形態使樹的層數變為穩定 \(\log n\) 的一種重構樹。

常用於解決與樹原形態無關的帶修改問題。

算法分析

我們通過點分治每次找重心的方式來對原樹進行重構。

將每次找到的重心與上一層的重心締結父子關係,這樣就可以形成一棵 \(\log n\) 層的樹。

由於樹是 \(\log n\) 層的,很多原來並不對勁的暴力在點分樹上均有正確的複雜度。

代碼實現

有一個小技巧:每次用遞歸上一層的總大小 \(\mathit{tot}\) 減去上一層的點的重兒子大小,得到的就是這一層的總大小。這樣求重心就只需一次 DFS 了。

參考代碼
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include <bits/stdc++.h>

using namespace std;

typedef vector<int>::iterator IT;

struct Edge {
  int to, nxt, val;

  Edge() {}

  Edge(int to, int nxt, int val) : to(to), nxt(nxt), val(val) {}
} e[300010];

int head[150010], cnt;

void addedge(int u, int v, int val) {
  e[++cnt] = Edge(v, head[u], val);
  head[u] = cnt;
}

int siz[150010], son[150010];
bool vis[150010];

int tot, lasttot;
int maxp, root;

void getG(int now, int fa) {
  siz[now] = 1;
  son[now] = 0;
  for (int i = head[now]; i; i = e[i].nxt) {
    int vs = e[i].to;
    if (vs == fa || vis[vs]) continue;
    getG(vs, now);
    siz[now] += siz[vs];
    son[now] = max(son[now], siz[vs]);
  }
  son[now] = max(son[now], tot - siz[now]);
  if (son[now] < maxp) {
    maxp = son[now];
    root = now;
  }
}

struct Node {
  int fa;
  vector<int> anc;
  vector<int> child;
} nd[150010];

int build(int now, int ntot) {
  tot = ntot;
  maxp = 0x7f7f7f7f;
  getG(now, 0);
  int g = root;
  vis[g] = 1;
  for (int i = head[g]; i; i = e[i].nxt) {
    int vs = e[i].to;
    if (vis[vs]) continue;
    int tmp = build(vs, ntot - son[vs]);
    nd[tmp].fa = now;
    nd[now].child.push_back(tmp);
  }
  return g;
}

int virtroot;

int main() {
  int n;
  cin >> n;
  for (int i = 1; i < n; i++) {
    int u, v, val;
    cin >> u >> v >> val;
    addedge(u, v, val);
    addedge(v, u, val);
  }
  virtroot = build(1, n);
}