跳转至

K-D Tree

k-D Tree(KDT , k-Dimension Tree) 是一種可以 高效處理 \(k\) 維空間信息 的數據結構。

在結點數 \(n\) 遠大於 \(2^k\) 時,應用 k-D Tree 的時間效率很好。

在算法競賽的題目中,一般有 \(k=2\)。在本頁面分析時間複雜度時,將認為 \(k\) 是常數。

建樹

k-D Tree 具有二叉搜索樹的形態,二叉搜索樹上的每個結點都對應 \(k\) 維空間內的一個點。其每個子樹中的點都在一個 \(k\) 維的超長方體內,這個超長方體內的所有點也都在這個子樹中。

假設我們已經知道了 \(k\) 維空間內的 \(n\) 個不同的點的座標,要將其構建成一棵 k-D Tree,步驟如下:

  1. 若當前超長方體中只有一個點,返回這個點。

  2. 選擇一個維度,將當前超長方體按照這個維度分成兩個超長方體。

  3. 選擇切割點:在選擇的維度上選擇一個點,這一維度上的值小於這個點的歸入一個超長方體(左子樹),其餘的歸入另一個超長方體(右子樹)。

  4. 將選擇的點作為這棵子樹的根節點,遞歸對分出的兩個超長方體構建左右子樹,維護子樹的信息。

為了方便理解,我們舉一個 \(k=2\) 時的例子。

其構建出 k-D Tree 的形態可能是這樣的:

其中樹上每個結點上的座標是選擇的分割點的座標,非葉子結點旁的 \(x\)\(y\) 是選擇的切割維度。

這樣的複雜度無法保證。對於 \(2,3\) 兩步,我們提出兩個優化:

  1. 輪流選擇 \(k\) 個維度,以保證在任意連續 \(k\) 層裏每個維度都被切割到。
  2. 每次在維度上選擇切割點時選擇該維度上的 中位數,這樣可以保證每次分成的左右子樹大小盡量相等。

可以發現,使用優化 \(2\) 後,構建出的 k-D Tree 的樹高最多為 \(\log n+O(1)\)

現在,構建 k-D Tree 時間複雜度的瓶頸在於快速選出一個維度上的中位數,並將在該維度上的值小於該中位數的置於中位數的左邊,其餘置於右邊。如果每次都使用 sort 函數對該維度進行排序,時間複雜度是 \(O(n\log^2 n)\) 的。事實上,單次找出 \(n\) 個元素中的中位數並將中位數置於排序後正確的位置的複雜度可以達到 \(O(n)\)

我們來回顧一下快速排序的思想。每次我們選出一個數,將小於該數的置於該數的左邊,大於該數的置於該數的右邊,保證該數在排好序後正確的位置上,然後遞歸排序左側和右側的值。這樣的期望複雜度是 \(O(n\log n)\) 的。但是由於 k-D Tree 只要求要中位數在排序後正確的位置上,所以我們只需要遞歸排序包含中位數的 一側。可以證明,這樣的期望複雜度是 \(O(n)\) 的。在 algorithm 庫中,有一個實現相同功能的函數 nth_element(),要找到 s[l]s[r] 之間的值按照排序規則 cmp 排序後在 s[mid] 位置上的值,並保證 s[mid] 左邊的值小於 s[mid],右邊的值大於 s[mid],只需寫 nth_element(s+l,s+mid,s+r+1,cmp)

藉助這種思想,構建 k-D Tree 時間複雜度是 \(O(n\log n)\) 的。

高維空間上的操作

在查詢高維矩形區域內的所有點的一些信息時,記錄每個結點子樹內每一維度上的座標的最大值和最小值。如果當前子樹對應的矩形與所求矩形沒有交點,則不繼續搜索其子樹;如果當前子樹對應的矩形完全包含在所求矩形內,返回當前子樹內所有點的權值和;否則,判斷當前點是否在所求矩形內,更新答案並遞歸在左右子樹中查找答案。

實現
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
int query(int p) {
  if (!p) return 0;
  bool flag{false};
  for (int k : {0, 1}) flag |= (!(l.x[k] <= t[p].L[k] && t[p].R[k] <= h.x[k]));
  if (!flag) return t[p].sum;
  for (int k : {0, 1})
    if (t[p].R[k] < l.x[k] || h.x[k] < t[p].L[k]) return 0;
  int ans{0};
  flag = false;
  for (int k : {0, 1}) flag |= (!(l.x[k] <= t[p].x[k] && t[p].x[k] <= h.x[k]));
  if (!flag) ans = t[p].v;
  return ans += query(t[p].l) + query(t[p].r);
}

複雜度分析

先考慮二維的,在查詢矩形 \(R\) 時,我們將 k-D Tree 上的結點分為三類:

  1. \(R\) 無交。
  2. 完全被 \(R\) 包含。
  3. 部分被 \(R\) 包含。

顯然單次查詢的複雜度是第 3 類點的個數。注意到第三類點的矩形要麼完全包含 \(R\),要麼互不包含,而前者顯然只有 \(O(h)=O(\log n)\) 個,現在我們來分析後者的個數。

首先,我們不妨令矩形的所有邊偏移 \(\epsilon\),使得查詢矩形不穿過已經有的任何點。這樣顯然是不影響矩形的查詢所涵蓋的點集的。

注意到互不包含的第 3 類點所對應的矩形,一定有 \(R\) 的一條邊穿過之。所以我們只需要計算 \(R\) 的每條邊穿過的矩形個數,即任意一條線段最多經過多少個點對應的矩形。

考慮對於某一個結點 \(u\),它有四個孫子,且它到每一個孫子都在兩個維度上各進行了一次劃分。經過觀察可以發現,按照這種方法將一個矩形劃分成四個子矩形,一條與座標軸平行的線段最多經過兩個區域,即從 \(u\) 出發的查詢,最多向下進入兩個孫子仍有第 3 類點(如果線段剛好與分割邊界重合則不一定,但是我們偏移查詢矩形邊界的操作使得這種情況不存在)。

而因為建樹的時候,每個點是其整個子樹在當前劃分維度上的中位數,所以子樹大小必定減半。於是,設 \(u\) 的子樹大小為 \(n\),我們能寫出如下遞歸式:

\[ T(n)=2T(n/4)+O(1) \]

由主定理得 \(T(n)=O(\sqrt{n})\)

將遞歸式推廣到 \(k\) 維,即 \(T(n)=2^{k-1}T(n/2^k)+O(1)\),於是 \(T(n)=O(n^{1-\frac1k})\)(將 \(k\) 視為常數)。

插入/刪除

如果維護的這個 \(k\) 維點集是可變的,即可能會插入或刪除一些點,此時 k-D Tree 的平衡性無法保證。由於 k-D Tree 的構造,不能支持旋轉,類似與 FHQ Treap 的隨機優先級也不能保證其複雜度。對此,有兩種比較常見的維護方法。

Note

很多選手會使用替罪羊樹結構來維護。但是注意到在剛才的複雜度分析中,要求兒子的子樹大小嚴格減半,即樹高必須為嚴格的 \(\log n+O(1)\),而替罪羊樹只滿足樹高 \(O(\log n)\),故查詢複雜度無法保證。

根號重構

插入的時候,先存下來要插入的點,每 \(B\) 次插入進行一次重構。

刪除打個標記即可。如果要求較為嚴格,可以維護樹內有多少個被刪除了,達到 \(B\) 則重構。

修改複雜度均攤 \(O(n\log n/B)\),查詢 \(O(B+n^{1-\frac1k})\),若二者數量同階則 \(B=O(\sqrt{n\log n})\) 最優(修改 \(O(\sqrt{n\log n})\),查詢 \(O(\sqrt{n\log n}+n^{1-\frac1k})\))。

二進制分組

考慮維護若干棵 \(2\) 的自然數次冪的 k-D Tree,滿足這些樹的大小之和為 \(n\)

插入的時候,新增一棵大小為 \(1\) 的 k-D Tree,然後不斷將相同大小的樹合併(直接拍扁重構)。實現的時候可以只重構一次。

容易發現需要合併的樹的大小一定從 \(2^0\) 開始且指數連續。複雜度類似二進制加法,是均攤 \(O(n\log^2 n)\) 的,因為重構本身帶 \(\log\)

查詢的時候,直接分別在每顆樹上查詢,複雜度為 \(O\left(\sum_{i\geq0} (\frac n{2^i})^{1-\frac1k}\right)=O(n^{1-\frac1k})\)

例題

洛谷 P4148 簡單題

在一個初始值全為 \(0\)\(n\times n\) 的二維矩陣上,進行 \(q\) 次操作,每次操作為以下兩種之一:

  1. 1 x y A:將座標 \((x,y)\) 上的數加上 \(A\)
  2. 2 x1 y1 x2 y2:輸出以 \((x1,y1)\) 為左下角,\((x2,y2)\) 為右上角的矩形內(包括矩形邊界)的數字和。

強制在線。內存限制 20M。保證答案及所有過程量在 int 範圍內。

\(1\le n\le 500000 , 1\le q\le 200000\)

20M 的空間卡掉了所有樹套樹,強制在線卡掉了 CDQ 分治,只能使用 k-D Tree。

以下是二進制分組的參考代碼。

參考代碼
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#include <bits/stdc++.h>
using namespace std;
constexpr int N(2e5), LG{18};

struct pt {
  int x[2];
  int v, sum;
  int l, r;
  int L[2], R[2];
} t[N + 5], l, h;

int rt[LG];
int b[N + 5], cnt;

void upd(int p) {
  t[p].sum = t[t[p].l].sum + t[t[p].r].sum + t[p].v;
  for (int k : {0, 1}) {
    t[p].L[k] = t[p].R[k] = t[p].x[k];
    if (t[p].l) {
      t[p].L[k] = min(t[p].L[k], t[t[p].l].L[k]);
      t[p].R[k] = max(t[p].R[k], t[t[p].l].R[k]);
    }
    if (t[p].r) {
      t[p].L[k] = min(t[p].L[k], t[t[p].r].L[k]);
      t[p].R[k] = max(t[p].R[k], t[t[p].r].R[k]);
    }
  }
}

int build(int l, int r, int dep = 0) {
  int p{l + r >> 1};
  nth_element(b + l, b + p, b + r + 1,
              [dep](int x, int y) { return t[x].x[dep] < t[y].x[dep]; });
  int x{b[p]};
  if (l < p) t[x].l = build(l, p - 1, dep ^ 1);
  if (p < r) t[x].r = build(p + 1, r, dep ^ 1);
  upd(x);
  return x;
}

void append(int &p) {
  if (!p) return;
  b[++cnt] = p;
  append(t[p].l);
  append(t[p].r);
  p = 0;
}

int query(int p) {
  if (!p) return 0;
  bool flag{false};
  for (int k : {0, 1}) flag |= (!(l.x[k] <= t[p].L[k] && t[p].R[k] <= h.x[k]));
  if (!flag) return t[p].sum;
  for (int k : {0, 1})
    if (t[p].R[k] < l.x[k] || h.x[k] < t[p].L[k]) return 0;
  int ans{0};
  flag = false;
  for (int k : {0, 1}) flag |= (!(l.x[k] <= t[p].x[k] && t[p].x[k] <= h.x[k]));
  if (!flag) ans = t[p].v;
  return ans += query(t[p].l) + query(t[p].r);
}

int main() {
  int n;
  cin >> n;
  int lst{0};
  n = 0;
  while (true) {
    int op;
    cin >> op;
    if (op == 1) {
      int x, y, A;
      cin >> x >> y >> A;
      x ^= lst;
      y ^= lst;
      A ^= lst;
      t[++n] = {{x, y}, A};
      b[cnt = 1] = n;
      for (int sz{0};; ++sz)
        if (!rt[sz]) {
          rt[sz] = build(1, cnt);
          break;
        } else
          append(rt[sz]);
    } else if (op == 2) {
      cin >> l.x[0] >> l.x[1] >> h.x[0] >> h.x[1];
      l.x[0] ^= lst;
      l.x[1] ^= lst;
      h.x[0] ^= lst;
      h.x[1] ^= lst;
      lst = 0;
      for (int i{0}; i < LG; ++i) lst += query(rt[i]);
      cout << lst << "\n";
    } else
      break;
  }
  return 0;
}

鄰域查詢

Warning

使用 k-D Tree 單次查詢最近點的時間複雜度最壞還是 \(O(n)\) 的,但不失為一種優秀的騙分算法,使用時請注意。在這裏對鄰域查詢的講解僅限於加強對 k-D Tree 結構的認識。

例題 luogu P1429 平面最近點對(加強版)

給定平面上的 \(n\) 個點 \((x_i,y_i)\),找出平面上最近兩個點對之間的 歐幾里得距離

\(2\le n\le 200000 , 0\le x_i,y_i\le 10^9\)

首先建出關於這 \(n\) 個點的 2-D Tree。

枚舉每個結點,對於每個結點找到不等於該結點且距離最小的點,即可求出答案。每次暴力遍歷 2-D Tree 上的每個結點的時間複雜度是 \(O(n)\) 的,需要剪枝。我們可以維護一個子樹中的所有結點在每一維上的座標的最小值和最大值。假設當前已經找到的最近點對的距離是 \(ans\),如果查詢點到子樹內所有點都包含在內的長方形的 最近 距離大於等於 \(ans\),則在這個子樹內一定沒有答案,搜索時不進入這個子樹。

此外,還可以使用一種啓發式搜索的方法,即若一個結點的兩個子樹都有可能包含答案,先在與查詢點距離最近的一個子樹中搜索答案。可以認為,查詢點到子樹對應的長方形的最近距離就是此題的估價函數

參考代碼
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
using namespace std;
const int maxn = 200010;
int n, d[maxn], lc[maxn], rc[maxn];
double ans = 2e18;

struct node {
  double x, y;
} s[maxn];

double L[maxn], R[maxn], D[maxn], U[maxn];

double dist(int a, int b) {
  return (s[a].x - s[b].x) * (s[a].x - s[b].x) +
         (s[a].y - s[b].y) * (s[a].y - s[b].y);
}

bool cmp1(node a, node b) { return a.x < b.x; }

bool cmp2(node a, node b) { return a.y < b.y; }

void maintain(int x) {
  L[x] = R[x] = s[x].x;
  D[x] = U[x] = s[x].y;
  if (lc[x])
    L[x] = min(L[x], L[lc[x]]), R[x] = max(R[x], R[lc[x]]),
    D[x] = min(D[x], D[lc[x]]), U[x] = max(U[x], U[lc[x]]);
  if (rc[x])
    L[x] = min(L[x], L[rc[x]]), R[x] = max(R[x], R[rc[x]]),
    D[x] = min(D[x], D[rc[x]]), U[x] = max(U[x], U[rc[x]]);
}

int build(int l, int r) {
  if (l > r) return 0;
  if (l == r) {
    maintain(l);
    return l;
  }
  int mid = (l + r) >> 1;
  double avx = 0, avy = 0, vax = 0, vay = 0;  // average variance
  for (int i = l; i <= r; i++) avx += s[i].x, avy += s[i].y;
  avx /= (double)(r - l + 1);
  avy /= (double)(r - l + 1);
  for (int i = l; i <= r; i++)
    vax += (s[i].x - avx) * (s[i].x - avx),
        vay += (s[i].y - avy) * (s[i].y - avy);
  if (vax >= vay)
    d[mid] = 1, nth_element(s + l, s + mid, s + r + 1, cmp1);
  else
    d[mid] = 2, nth_element(s + l, s + mid, s + r + 1, cmp2);
  lc[mid] = build(l, mid - 1), rc[mid] = build(mid + 1, r);
  maintain(mid);
  return mid;
}

double f(int a, int b) {
  double ret = 0;
  if (L[b] > s[a].x) ret += (L[b] - s[a].x) * (L[b] - s[a].x);
  if (R[b] < s[a].x) ret += (s[a].x - R[b]) * (s[a].x - R[b]);
  if (D[b] > s[a].y) ret += (D[b] - s[a].y) * (D[b] - s[a].y);
  if (U[b] < s[a].y) ret += (s[a].y - U[b]) * (s[a].y - U[b]);
  return ret;
}

void query(int l, int r, int x) {
  if (l > r) return;
  int mid = (l + r) >> 1;
  if (mid != x) ans = min(ans, dist(x, mid));
  if (l == r) return;
  double distl = f(x, lc[mid]), distr = f(x, rc[mid]);
  if (distl < ans && distr < ans) {
    if (distl < distr) {
      query(l, mid - 1, x);
      if (distr < ans) query(mid + 1, r, x);
    } else {
      query(mid + 1, r, x);
      if (distl < ans) query(l, mid - 1, x);
    }
  } else {
    if (distl < ans) query(l, mid - 1, x);
    if (distr < ans) query(mid + 1, r, x);
  }
}

int main() {
  scanf("%d", &n);
  for (int i = 1; i <= n; i++) scanf("%lf%lf", &s[i].x, &s[i].y);
  build(1, n);
  for (int i = 1; i <= n; i++) query(1, n, i);
  printf("%.4lf\n", sqrt(ans));
  return 0;
}
例題 「CQOI2016」K 遠點對

給定平面上的 \(n\) 個點 \((x_i,y_i)\),求歐幾里得距離下的第 \(k\) 遠無序點對之間的距離。

\(n\le 100000 , 1\le k\le 100 , 0\le x_i,y_i<2^{31}\)

和上一道例題類似,從最近點對變成了 \(k\) 遠點對,估價函數改成了查詢點到子樹對應的長方形區域的最遠距離。用一個小根堆來維護當前找到的前 \(k\) 遠點對之間的距離,如果當前找到的點對距離大於堆頂,則彈出堆頂並插入這個距離,同樣的,使用堆頂的距離來剪枝。

由於題目中強調的是無序點對,即交換前後兩點的順序後仍是相同的點對,則每個有序點對會被計算兩次,那麼讀入的 \(k\) 要乘以 \(2\)

參考代碼
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#include <algorithm>
#include <cstring>
#include <iostream>
#include <queue>
using namespace std;
const int maxn = 100010;
long long n, k;
priority_queue<long long, vector<long long>, greater<long long> > q;

struct node {
  long long x, y;
} s[maxn];

bool cmp1(node a, node b) { return a.x < b.x; }

bool cmp2(node a, node b) { return a.y < b.y; }

long long lc[maxn], rc[maxn], L[maxn], R[maxn], D[maxn], U[maxn];

void maintain(int x) {
  L[x] = R[x] = s[x].x;
  D[x] = U[x] = s[x].y;
  if (lc[x])
    L[x] = min(L[x], L[lc[x]]), R[x] = max(R[x], R[lc[x]]),
    D[x] = min(D[x], D[lc[x]]), U[x] = max(U[x], U[lc[x]]);
  if (rc[x])
    L[x] = min(L[x], L[rc[x]]), R[x] = max(R[x], R[rc[x]]),
    D[x] = min(D[x], D[rc[x]]), U[x] = max(U[x], U[rc[x]]);
}

int build(int l, int r) {
  if (l > r) return 0;
  int mid = (l + r) >> 1;
  double av1 = 0, av2 = 0, va1 = 0, va2 = 0;  // average variance
  for (int i = l; i <= r; i++) av1 += s[i].x, av2 += s[i].y;
  av1 /= (r - l + 1);
  av2 /= (r - l + 1);
  for (int i = l; i <= r; i++)
    va1 += (av1 - s[i].x) * (av1 - s[i].x),
        va2 += (av2 - s[i].y) * (av2 - s[i].y);
  if (va1 > va2)
    nth_element(s + l, s + mid, s + r + 1, cmp1);
  else
    nth_element(s + l, s + mid, s + r + 1, cmp2);
  lc[mid] = build(l, mid - 1);
  rc[mid] = build(mid + 1, r);
  maintain(mid);
  return mid;
}

long long sq(long long x) { return x * x; }

long long dist(int a, int b) {
  return max(sq(s[a].x - L[b]), sq(s[a].x - R[b])) +
         max(sq(s[a].y - D[b]), sq(s[a].y - U[b]));
}

void query(int l, int r, int x) {
  if (l > r) return;
  int mid = (l + r) >> 1;
  long long t = sq(s[mid].x - s[x].x) + sq(s[mid].y - s[x].y);
  if (t > q.top()) q.pop(), q.push(t);
  long long distl = dist(x, lc[mid]), distr = dist(x, rc[mid]);
  if (distl > q.top() && distr > q.top()) {
    if (distl > distr) {
      query(l, mid - 1, x);
      if (distr > q.top()) query(mid + 1, r, x);
    } else {
      query(mid + 1, r, x);
      if (distl > q.top()) query(l, mid - 1, x);
    }
  } else {
    if (distl > q.top()) query(l, mid - 1, x);
    if (distr > q.top()) query(mid + 1, r, x);
  }
}

int main() {
  cin >> n >> k;
  k *= 2;
  for (int i = 1; i <= k; i++) q.push(0);
  for (int i = 1; i <= n; i++) cin >> s[i].x >> s[i].y;
  build(1, n);
  for (int i = 1; i <= n; i++) query(1, n, i);
  cout << q.top() << endl;
  return 0;
}

習題

「SDOI2010」捉迷藏

「Violet」天使玩偶/SJY 擺棋子

「國家集訓隊」JZPFAR

「BOI2007」Mokia 摩基亞

luogu P4475 巧克力王國

「CH 弱省胡策 R2」TATT