跳转至

樹哈希

判斷一些樹是否同構的時,我們常常把這些樹轉成哈希值儲存起來,以降低複雜度。

樹哈希是很靈活的,可以設計出各種各樣的哈希方式;但是如果隨意設計,很有可能是錯誤的,可能被卡。以下介紹一類容易實現且不易被卡的方法。

方法

這類方法需要一個多重集的哈希函數。以某個結點為根的子樹的哈希值,就是以它的所有兒子為根的子樹的哈希值構成的多重集的哈希值,即:

\[ h_x = f(\{ h_i \mid i \in son(x) \}) \]

其中 \(h_x\) 表示以 \(x\) 為根的子樹的哈希值,\(f\) 是多重集的哈希函數。

以代碼中使用的哈希函數為例:

\[ f(S) = \left( c + \sum_{x \in S} g(x) \right) \bmod m \]

其中 \(c\) 為常數,一般使用 \(1\) 即可。\(m\) 為模數,一般使用 \(2^{32}\)\(2^{64}\) 進行自然溢出,也可使用大素數。\(g\) 為整數到整數的映射,代碼中使用 xor shift,也可以選用其他的函數,但是不建議使用多項式。為了預防出題人對着 xor hash 卡,還可以在映射前後異或一個隨機常數。

這種哈希十分好寫。如果需要換根,第二次 DP 時只需把子樹哈希減掉即可。

例題

UOJ #763. 樹哈希

這是一道模板題。不用多説,以 \(1\) 為根跑一遍 DFS 就好了。

參考代碼
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#include <cctype>
#include <chrono>
#include <cstdio>
#include <random>
#include <set>
#include <vector>

typedef unsigned long long ull;

const ull mask = std::chrono::steady_clock::now().time_since_epoch().count();

ull shift(ull x) {
  x ^= mask;
  x ^= x << 13;
  x ^= x >> 7;
  x ^= x << 17;
  x ^= mask;
  return x;
}

const int N = 1e6 + 10;

int n;
ull hash[N];
std::vector<int> edge[N];
std::set<ull> trees;

void getHash(int x, int p) {
  hash[x] = 1;
  for (int i : edge[x]) {
    if (i == p) {
      continue;
    }
    getHash(i, x);
    hash[x] += shift(hash[i]);
  }
  trees.insert(hash[x]);
}

int main() {
  scanf("%d", &n);
  for (int i = 1; i < n; i++) {
    int u, v;
    scanf("%d%d", &u, &v);
    edge[u].push_back(v);
    edge[v].push_back(u);
  }
  getHash(1, 0);
  printf("%lu", trees.size());
}

[BJOI2015] 樹的同構

這道題所説的同構是指無根樹的,而上面所介紹的方法是針對有根樹的。因此只有當根一樣時,同構的兩棵無根樹哈希值才相同。由於數據範圍較小,我們可以暴力求出以每個點為根時的哈希值,排序後比較。

如果數據範圍較大,我們也可以使用換根 DP,遍歷樹兩遍,求出以每個點為根時的哈希值。我們還可以利用上面的多重集哈希函數:把以每個結點為根時的哈希值都存進多重集,再把多重集的哈希值算出來,進行比較(做法一)。

還可以通過找重心的方式來優化複雜度。一棵樹的重心最多隻有兩個,只需把以它(們)為根時的哈希值求出來即可。接下來,既可以分別比較這些哈希值(做法二),也可以在有一個重心時取它的哈希值作為整棵樹的哈希值,有兩個時則取其中較小(大)的。

做法一
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include <chrono>
#include <cstdio>
#include <map>
#include <random>
#include <set>
#include <vector>

typedef unsigned long long ull;

const int N = 60, M = 998244353;
const ull mask = std::chrono::steady_clock::now().time_since_epoch().count();

ull shift(ull x) {
  x ^= mask;
  x ^= x << 13;
  x ^= x >> 7;
  x ^= x << 17;
  x ^= mask;
  return x;
}

std::vector<int> edge[N];
ull sub[N], root[N];
std::map<ull, int> trees;

void getSub(int x) {
  sub[x] = 1;
  for (int i : edge[x]) {
    getSub(i);
    sub[x] += shift(sub[i]);
  }
}

void getRoot(int x) {
  for (int i : edge[x]) {
    root[i] = sub[i] + shift(root[x] - shift(sub[i]));
    getRoot(i);
  }
}

int main() {
  int m;
  scanf("%d", &m);
  for (int t = 1; t <= m; t++) {
    int n, rt = 0;
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
      int fa;
      scanf("%d", &fa);
      if (fa) {
        edge[fa].push_back(i);
      } else {
        rt = i;
      }
    }
    getSub(rt);
    root[rt] = sub[rt];
    getRoot(rt);
    ull hash = 1;
    for (int i = 1; i <= n; i++) {
      hash += shift(root[i]);
    }
    if (!trees.count(hash)) {
      trees[hash] = t;
    }
    printf("%d\n", trees[hash]);
    for (int i = 1; i <= n; i++) {
      edge[i].clear();
    }
  }
}
做法二
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include <chrono>
#include <cstdio>
#include <map>
#include <random>
#include <set>
#include <vector>

typedef unsigned long long ull;
typedef std::pair<ull, ull> Hash2;

const int N = 60, M = 998244353;
const ull mask = std::chrono::steady_clock::now().time_since_epoch().count();

ull shift(ull x) {
  x ^= mask;
  x ^= x << 13;
  x ^= x >> 7;
  x ^= x << 17;
  x ^= mask;
  return x;
}

int n;
int size[N], weight[N], centroid[2];
std::vector<int> edge[N];
std::map<Hash2, int> trees;

void getCentroid(int x, int fa) {
  size[x] = 1;
  weight[x] = 0;
  for (int i : edge[x]) {
    if (i == fa) {
      continue;
    }
    getCentroid(i, x);
    size[x] += size[i];
    weight[x] = std::max(weight[x], size[i]);
  }
  weight[x] = std::max(weight[x], n - size[x]);
  if (weight[x] <= n / 2) {
    int index = centroid[0] != 0;
    centroid[index] = x;
  }
}

ull getHash(int x, int fa) {
  ull hash = 1;
  for (int i : edge[x]) {
    if (i == fa) {
      continue;
    }
    hash += shift(getHash(i, x));
  }
  return hash;
}

int main() {
  int m;
  scanf("%d", &m);
  for (int t = 1; t <= m; t++) {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
      int fa;
      scanf("%d", &fa);
      if (fa) {
        edge[fa].push_back(i);
        edge[i].push_back(fa);
      }
    }
    getCentroid(1, 0);
    Hash2 hash;
    hash.first = getHash(centroid[0], 0);
    if (centroid[1]) {
      hash.second = getHash(centroid[1], 0);
      if (hash.first > hash.second) {
        std::swap(hash.first, hash.second);
      }
    } else {
      hash.second = hash.first;
    }
    if (!trees.count(hash)) {
      trees[hash] = t;
    }
    printf("%d\n", trees[hash]);
    for (int i = 1; i <= n; i++) {
      edge[i].clear();
    }
    centroid[0] = centroid[1] = 0;
  }
}

HDU 6647 Bracket Sequences on Tree

題目要求遍歷一棵無根樹產生的本質不同括號序列方案數。

首先可以注意到,兩棵不同構的有根樹一定不會生成相同的括號序列。我們先考慮遍歷有根樹能夠產生的本質不同括號序列方案數,假設我們當前考慮的子樹根節點為 \(u\),記 \(f(u)\) 表示這棵子樹的方案數,從 \(u\) 開始往下遍歷,順序可以隨意選擇,產生 \(|son(u)|!\) 種排列,遍歷每個兒子節點 \(v\)\(v\) 的子樹內有 \(f(v)\) 種方案,因此有 \(f(u)=|son(u)|! \cdot \prod_{v \in son(u)} f(v)\)。但是,同構的子樹之間會產生重複,\(f(u)\) 需要除掉每種本質不同子樹出現次數階乘的乘積,類似於多重集合的排列。

通過上述 DP,可以求出根節點的方案數。再通過換根 DP,將父親節點的哈希值和方案信息轉移給兒子,可以求出以每個節點為根時的哈希值和方案數。每種不同的子樹只需要計數一次即可。

參考代碼
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#include <chrono>
#include <cstdio>
#include <map>
#include <random>
#include <set>
#include <vector>

typedef unsigned long long ull;

const int N = 1e5 + 10, M = 998244353;
const ull mask = std::chrono::steady_clock::now().time_since_epoch().count();

struct Tree {
  ull hash, deg, ans;
  std::map<ull, ull> son;

  Tree() { clear(); }

  void add(Tree& o);
  void remove(Tree& o);
  void clear();
};

ull inv(ull x) {
  ull y = M - 2, z = 1;
  while (y) {
    if (y & 1) {
      z = z * x % M;
    }
    x = x * x % M;
    y >>= 1;
  }
  return z;
}

ull shift(ull x) {
  x ^= mask;
  x ^= x << 13;
  x ^= x >> 7;
  x ^= x << 17;
  x ^= mask;
  return x;
}

void Tree::add(Tree& o) {
  ull temp = shift(o.hash);
  hash += temp;
  ans = ans * ++deg % M * inv(++son[temp]) % M * o.ans % M;
}

void Tree::remove(Tree& o) {
  ull temp = shift(o.hash);
  hash -= temp;
  ans = ans * inv(deg--) % M * son[temp]-- % M * inv(o.ans) % M;
}

void Tree::clear() {
  hash = 1;
  deg = 0;
  ans = 1;
  son.clear();
}

std::vector<int> edge[N];
Tree sub[N], root[N];
std::map<ull, ull> trees;

void getSub(int x, int fa) {
  for (int i : edge[x]) {
    if (i == fa) {
      continue;
    }
    getSub(i, x);
    sub[x].add(sub[i]);
  }
}

void getRoot(int x, int fa) {
  for (int i : edge[x]) {
    if (i == fa) {
      continue;
    }
    root[x].remove(sub[i]);
    root[i] = sub[i];
    root[i].add(root[x]);
    root[x].add(sub[i]);
    getRoot(i, x);
  }
  trees[root[x].hash] = root[x].ans;
}

int main() {
  int t, n;
  scanf("%d", &t);
  while (t--) {
    scanf("%d", &n);
    for (int i = 1; i < n; i++) {
      int u, v;
      scanf("%d%d", &u, &v);
      edge[u].push_back(v);
      edge[v].push_back(u);
    }
    getSub(1, 0);
    root[1] = sub[1];
    getRoot(1, 0);
    ull tot = 0;
    for (auto p : trees) {
      tot = (tot + p.second) % M;
    }
    printf("%lld\n", tot);
    for (int i = 1; i <= n; i++) {
      edge[i].clear();
      sub[i].clear();
      root[i].clear();
    }
    trees.clear();
  }
}

參考資料

文中的哈希方法參考並拓展自博客 一種好寫且卡不掉的樹哈希