跳转至

AC 自動機

AC 自動機是 以 Trie 的結構為基礎,結合 KMP 的思想 建立的自動機,用於解決多模式匹配等任務。

引入

很多人在第一次看到這個東西的時侯是非常興奮的。不過這個自動機叫作 Automaton,不是 Automation,這裏的 AC 也不是 Accepted,而是 Aho–Corasick(Alfred V. Aho, Margaret J. Corasick. 1975),讓萌新失望啦。切入正題。似乎在初學自動機相關的內容時,許多人難以建立對自動機的初步印象,尤其是在自學的時侯。而這篇文章就是為你們打造的。筆者在自學 AC 自動機後花費兩天時間製作若干的 gif,呈現出一個相對直觀的自動機形態。儘管這個圖似乎不太可讀,但這絕對是在作者自學的時侯,畫得最認真的 gif 了。另外有些小夥伴問這個 gif 拿什麼畫的。筆者用 Windows 畫圖軟件製作。

解釋

簡單來説,建立一個 AC 自動機有兩個步驟:

  1. 基礎的 Trie 結構:將所有的模式串構成一棵 Trie。
  2. KMP 的思想:對 Trie 樹上所有的結點構造失配指針。

然後就可以利用它進行多模式匹配了。

字典樹構建

AC 自動機在初始時會將若干個模式串丟到一個 Trie 裏,然後在 Trie 上建立 AC 自動機。這個 Trie 就是普通的 Trie,該怎麼建怎麼建。

這裏需要仔細解釋一下 Trie 的結點的含義,儘管這很小兒科,但在之後的理解中極其重要。Trie 中的結點表示的是某個模式串的前綴。我們在後文也將其稱作狀態。一個結點表示一個狀態,Trie 的邊就是狀態的轉移。

形式化地説,對於若干個模式串 \(s_1,s_2\dots s_n\),將它們構建一棵字典樹後的所有狀態的集合記作 \(Q\)

失配指針

AC 自動機利用一個 fail 指針來輔助多模式串的匹配。

狀態 \(u\) 的 fail 指針指向另一個狀態 \(v\),其中 \(v\in Q\),且 \(v\)\(u\) 的最長後綴(即在若干個後綴狀態中取最長的一個作為 fail 指針)。這裏簡單對比一下這裏的 fail 指針與 KMP 中的 next 指針:

  1. 共同點:兩者同樣是在失配的時候用於跳轉的指針。
  2. 不同點:next 指針求的是最長 Border(即最長的相同前後綴),而 fail 指針指向所有模式串的前綴中匹配當前狀態的最長後綴。

因為 KMP 只對一個模式串做匹配,而 AC 自動機要對多個模式串做匹配。有可能 fail 指針指向的結點對應着另一個模式串,兩者前綴不同。

沒看懂上面的對比不要急,你只需要知道,AC 自動機的失配指針指向當前狀態的最長後綴狀態即可。

AC 自動機在做匹配時,同一位上可匹配多個模式串。

構建指針

下面介紹構建 fail 指針的 基礎思想:(強調!基礎思想!基礎!)

構建 fail 指針,可以參考 KMP 中構造 Next 指針的思想。

考慮字典樹中當前的結點 \(u\)\(u\) 的父結點是 \(p\)\(p\) 通過字符 c 的邊指向 \(u\),即 \(trie[p,\mathtt{c}]=u\)。假設深度小於 \(u\) 的所有結點的 fail 指針都已求得。

  1. 如果 \(\text{trie}[\text{fail}[p],\mathtt{c}]\) 存在:則讓 u 的 fail 指針指向 \(\text{trie}[\text{fail}[p],\mathtt{c}]\)。相當於在 \(p\)\(\text{fail}[p]\) 後面加一個字符 c,分別對應 \(u\)\(fail[u]\)
  2. 如果 \(\text{trie}[\text{fail}[p],\mathtt{c}]\) 不存在:那麼我們繼續找到 \(\text{trie}[\text{fail}[\text{fail}[p]],\mathtt{c}]\)。重複 1 的判斷過程,一直跳 fail 指針直到根結點。
  3. 如果真的沒有,就讓 fail 指針指向根結點。

如此即完成了 \(\text{fail}[u]\) 的構建。

例子

下面放一張 GIF 幫助大家理解。對字符串 i he his she hers 組成的字典樹構建 fail 指針:

  1. 黃色結點:當前的結點 \(u\)
  2. 綠色結點:表示已經 BFS 遍歷完畢的結點,
  3. 橙色的邊:fail 指針。
  4. 紅色的邊:當前求出的 fail 指針。

AC_automation_gif_b_3.gif

我們重點分析結點 6 的 fail 指針構建:

AC_automation_6_9.png

找到 6 的父結點 5,\(\text{fail}[5]=10\)。然而 10 結點沒有字母 s 連出的邊;繼續跳到 10 的 fail 指針,\(\text{fail}[10]=0\)。發現 0 結點有字母 s 連出的邊,指向 7 結點;所以 \(\text{fail}[6]=7\)。最後放一張建出來的圖:

finish

字典樹與字典圖

我們直接上代碼吧。字典樹插入的代碼就不分析了(後面完整代碼裏有),先來看構建函數 build(),該函數的目標有兩個,一個是構建 fail 指針,一個是構建自動機。參數如下:

  1. tr[u,c]:有兩種理解方式。我們可以簡單理解為字典樹上的一條邊,即 \(\text{trie}[u,c]\);也可以理解為從狀態(結點)\(u\) 後加一個字符 c 到達的狀態(結點),即一個狀態轉移函數 \(\text{trans}(u,c)\)。下文中我們將用第二種理解方式繼續講解。
  2. 隊列 q:用於 BFS 遍歷字典樹。
  3. fail[u]:結點 \(u\) 的 fail 指針。
實現
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
void build() {
  for (int i = 0; i < 26; i++)
    if (tr[0][i]) q.push(tr[0][i]);
  while (q.size()) {
    int u = q.front();
    q.pop();
    for (int i = 0; i < 26; i++) {
      if (tr[u][i])
        fail[tr[u][i]] = tr[fail[u]][i], q.push(tr[u][i]);
      else
        tr[u][i] = tr[fail[u]][i];
    }
  }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
def build():
    for i in range(0, 26):
        if tr[0][i] == 1:
            q.append(tr[0][i])
    while len(q) > 0:
        u = q[0]
        q.pop()
        for i in range(0, 26):
            if tr[u][i] == 1:
                fail[tr[u][i]] = tr[fail[u]][i]
                q.append(tr[u][i])
            else:
                tr[u][i] = tr[fail[u]][i]

解釋

解釋一下上面的代碼:build 函數將結點按 BFS 順序入隊,依次求 fail 指針。這裏的字典樹根結點為 0,我們將根結點的子結點一一入隊。若將根結點入隊,則在第一次 BFS 的時候,會將根結點兒子的 fail 指針標記為本身。因此我們將根結點的兒子一一入隊,而不是將根結點入隊。

然後開始 BFS:每次取出隊首的結點 u(\(\text{fail}[u]\) 在之前的 BFS 過程中已求得),然後遍歷字符集(這裏是 0-25,對應 a-z,即 \(u\) 的各個子節點):

  1. 如果 \(\text{trans}[u][\mathtt{i}]\) 存在,我們就將 \(\text{trans}[u][\mathtt{i}]\) 的 fail 指針賦值為 \(\text{trans}[\text{fail}[u]][\mathtt{i}]\)。這裏似乎有一個問題。根據之前的講解,我們應該用 while 循環,不停的跳 fail 指針,判斷是否存在字符 i 對應的結點,然後賦值,但是這裏通過特殊處理簡化了這些代碼。
  2. 否則,令 \(\text{trans}[u][\mathtt{i}]\) 指向 \(\text{trans}[\text{fail}[u]][\mathtt{i}]\) 的狀態。

這裏的處理是,通過 else 語句的代碼修改字典樹的結構。沒錯,它將不存在的字典樹的狀態鏈接到了失配指針的對應狀態。在原字典樹中,每一個結點代表一個字符串 \(S\),是某個模式串的前綴。而在修改字典樹結構後,儘管增加了許多轉移關係,但結點(狀態)所代表的字符串是不變的。

\(\text{trans}[S][\mathtt{c}]\) 相當於是在 \(S\) 後添加一個字符 c 變成另一個狀態 \(S'\)。如果 \(S'\) 存在,説明存在一個模式串的前綴是 \(S'\),否則我們讓 \(\text{trans}[S][\mathtt{c}]\) 指向 \(\text{trans}[\text{fail}[S]][\mathtt{c}]\)。由於 \(\text{fail}[S]\) 對應的字符串是 \(S\) 的後綴,因此 \(\text{trans}[\text{fail}[S]][\mathtt{c}]\) 對應的字符串也是 \(S'\) 的後綴。

換言之在 Trie 上跳轉的時侯,我們只會從 \(S\) 跳轉到 \(S'\),相當於匹配了一個 \(S'\);但在 AC 自動機上跳轉的時侯,我們會從 \(S\) 跳轉到 \(S'\) 的後綴,也就是説我們匹配一個字符 c,然後捨棄 \(S\) 的部分前綴。捨棄前綴顯然是能匹配的。那麼 fail 指針呢?它也是在捨棄前綴啊!試想一下,如果文本串能匹配 \(S\),顯然它也能匹配 \(S\) 的後綴。所謂的 fail 指針其實就是 \(S\) 的一個後綴集合。

tr 數組還有另一種比較簡單的理解方式:如果在位置 \(u\) 失配,我們會跳轉到 \(\text{fail}[u]\) 的位置。所以我們可能沿着 fail 數組跳轉多次才能來到下一個能匹配的位置。所以我們可以用 tr 數組直接記錄記錄下一個能匹配的位置,這樣就能節省下很多時間。

這樣修改字典樹的結構,使得匹配轉移更加完善。同時它將 fail 指針跳轉的路徑做了壓縮(就像並查集的路徑壓縮),使得本來需要跳很多次 fail 指針變成跳一次。

過程

我們將之前的 GIF 圖改一下:

AC_automation_gif_b_pro3.gif

  1. 藍色結點:BFS 遍歷到的結點 u
  2. 藍色的邊:當前結點下,AC 自動機修改字典樹結構連出的邊。
  3. 黑色的邊:AC 自動機修改字典樹結構連出的邊。
  4. 紅色的邊:當前結點求出的 fail 指針
  5. 黃色的邊:fail 指針
  6. 灰色的邊:字典樹的邊

可以發現,眾多交錯的黑色邊將字典樹變成了 字典圖。圖中省略了連向根結點的黑邊(否則會更亂)。我們重點分析一下結點 5 遍歷時的情況。我們求 \(\text{trans}[5][s]=6\) 的 fail 指針:

AC_automation_b_7.png

本來的策略是找 fail 指針,於是我們跳到 \(\text{fail}[5]=10\) 發現沒有 s 連出的字典樹的邊,於是跳到 \(\text{fail}[10]=0\),發現有 \(\text{trie}[0][s]=7\),於是 \(\text{fail}[6]=7\);但是有了黑邊、藍邊,我們跳到 \(\text{fail}[5]=10\) 之後直接走 \(\text{trans}[10][s]=7\) 就走到 \(7\) 號結點了。

這就是 build 完成的兩件事:構建 fail 指針和建立字典圖。這個字典圖也會在查詢的時候起到關鍵作用。

多模式匹配

接下來分析匹配函數 query()

實現

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
int query(char *t) {
  int u = 0, res = 0;
  for (int i = 1; t[i]; i++) {
    u = tr[u][t[i] - 'a'];  // 轉移
    for (int j = u; j && e[j] != -1; j = fail[j]) {
      res += e[j], e[j] = -1;
    }
  }
  return res;
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
def query(t):
    u, res = 0, 0
    i = 1
    while t[i] == False:
        u = tr[u][t[i] - ord('a')]
        j = u
        while j == True and e[j] != -1:
            res += e[j]
            e[j] = -1
            j = fail[j]
        i += 1
    return res

解釋

這裏 \(u\) 作為字典樹上當前匹配到的結點,res 即返回的答案。循環遍歷匹配串,\(u\) 在字典樹上跟蹤當前字符。利用 fail 指針找出所有匹配的模式串,累加到答案中。然後清零。在上文中我們分析過,字典樹的結構其實就是一個 trans 函數,而構建好這個函數後,在匹配字符串的過程中,我們會捨棄部分前綴達到最低限度的匹配。fail 指針則指向了更多的匹配狀態。最後上一份圖。對於剛才的自動機:

AC_automation_b_13.png

我們從根結點開始嘗試匹配 ushersheishis,那麼 \(p\) 的變化將是:

AC_automation_gif_c.gif

  1. 紅色結點:\(p\) 結點
  2. 粉色箭頭:\(p\) 在自動機上的跳轉,
  3. 藍色的邊:成功匹配的模式串
  4. 藍色結點:示跳 fail 指針時的結點(狀態)。

效率優化

題目請參考洛谷 P5357【模板】AC 自動機(二次加強版)

因為我們的 AC 自動機中,每次匹配,會一直向 fail 邊跳來找到所有的匹配,但是這樣的效率較低,在某些題目中會被卡 T。

那麼我們如何優化呢?首先我們需要了解 fail 指針的一個性質:一個 AC 自動機中,如果只保留 fail 邊,那麼剩餘的圖一定是一棵樹。

這是顯然的,因為 fail 不會成環,且深度一定比現在低,所以得證。

而我們 AC 自動機的匹配就可以轉化為在 fail 樹上的鏈求和問題。

所以我們只需要優化一下這部分就可以了。

我們這裏提供兩種思路。

拓撲排序優化建圖

我們浪費的時間在哪裏呢?在每次都要跳 fail。如果我們可以預先記錄,最後一併求和,那麼效率就會優化。

於是我們按照 fail 樹建圖(不用真的建,只需要記錄入度):

建圖
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
void getfail()  // 實際上也可以叫 build
{
  for (int i = 0; i < 26; i++) trie[0].son[i] = 1;
  q.push(1);
  trie[1].fail = 0;
  while (!q.empty()) {
    int u = q.front();
    q.pop();
    int Fail = trie[u].fail;
    for (int i = 0; i < 26; i++) {
      int v = trie[u].son[i];
      if (!v) {
        trie[u].son[i] = trie[Fail].son[i];
        continue;
      }
      trie[v].fail = trie[Fail].son[i];
      indeg[trie[Fail].son[i]]++;  // 修改點在這裏,增加了入度記錄
      q.push(v);
    }
  }
}

然後我們在查詢的時候就可以只為找到節點的 ans 打上標記,在最後再用拓撲排序求出答案。

查詢
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
void query(char *s) {
  int u = 1, len = strlen(s);
  for (int i = 0; i < len; i++) u = trie[u].son[s[i] - 'a'], trie[u].ans++;
}

void topu() {
  for (int i = 1; i <= cnt; i++)
    if (!indeg[i]) q.push(i);
  while (!q.empty()) {
    int fr = q.front();
    q.pop();
    vis[trie[fr].flag] = trie[fr].ans;
    int u = trie[fr].fail;
    trie[u].ans += trie[fr].ans;
    if (!(--indeg[u])) q.push(u);
  }
}

主函數里這麼寫:

1
2
3
4
5
6
7
8
int main() {
  // do_something();
  scanf("%s", s);
  query(s);
  topu();
  for (int i = 1; i <= n; i++) cout << vis[rev[i]] << std::endl;
  // do_another_thing();
}
完整代碼

Luogu P5357【模板】AC 自動機(二次加強版)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
// Code by rickyxrc | https://www.luogu.com.cn/record/115706921
#include <bits/stdc++.h>
#define maxn 8000001
using namespace std;
char s[maxn];
int n, cnt, vis[maxn], rev[maxn], indeg[maxn], ans;

struct trie_node {
  int son[27];
  int fail;
  int flag;
  int ans;

  void init() {
    memset(son, 0, sizeof(son));
    fail = flag = 0;
  }
} trie[maxn];

queue<int> q;

void init() {
  for (int i = 0; i <= cnt; i++) trie[i].init();
  for (int i = 1; i <= n; i++) vis[i] = 0;
  cnt = 1;
  ans = 0;
}

void insert(char *s, int num) {
  int u = 1, len = strlen(s);
  for (int i = 0; i < len; i++) {
    int v = s[i] - 'a';
    if (!trie[u].son[v]) trie[u].son[v] = ++cnt;
    u = trie[u].son[v];
  }
  if (!trie[u].flag) trie[u].flag = num;
  rev[num] = trie[u].flag;
  return;
}

void getfail(void) {
  for (int i = 0; i < 26; i++) trie[0].son[i] = 1;
  q.push(1);
  trie[1].fail = 0;
  while (!q.empty()) {
    int u = q.front();
    q.pop();
    int Fail = trie[u].fail;
    for (int i = 0; i < 26; i++) {
      int v = trie[u].son[i];
      if (!v) {
        trie[u].son[i] = trie[Fail].son[i];
        continue;
      }
      trie[v].fail = trie[Fail].son[i];
      indeg[trie[Fail].son[i]]++;
      q.push(v);
    }
  }
}

void topu() {
  for (int i = 1; i <= cnt; i++)
    if (!indeg[i]) q.push(i);
  while (!q.empty()) {
    int fr = q.front();
    q.pop();
    vis[trie[fr].flag] = trie[fr].ans;
    int u = trie[fr].fail;
    trie[u].ans += trie[fr].ans;
    if (!(--indeg[u])) q.push(u);
  }
}

void query(char *s) {
  int u = 1, len = strlen(s);
  for (int i = 0; i < len; i++) u = trie[u].son[s[i] - 'a'], trie[u].ans++;
}

int main() {
  scanf("%d", &n);
  init();
  for (int i = 1; i <= n; i++) scanf("%s", s), insert(s, i);
  getfail();
  scanf("%s", s);
  query(s);
  topu();
  for (int i = 1; i <= n; i++) cout << vis[rev[i]] << std::endl;
  return 0;
}

子樹求和

和拓撲排序的思路接近,我們預先將子樹求和,詢問時直接累加和值即可。

完整代碼請見總結模板 3。

AC 自動機上 DP

這部分將以 P2292 [HNOI2004] L 語言 為例題講解。

一看題,不難想到一個 naive 的思路:建立 AC 自動機,在 AC 自動機上對於所有 fail 指針的子串轉移,最後取最大值得到答案。

主要代碼如下(若不熟悉代碼中的類型定義可以跳到末尾的完整代碼):

查詢部分主要代碼
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
void query(char *s) {
  int u = 1, len = strlen(s), l = 0;
  for (int i = 0; i < len; i++) {
    int v = s[i] - 'a';
    int k = trie[u].son[v];
    while (k > 1) {
      if (trie[k].flag && (dp[i - trie[k].len] || i - trie[k].len == -1))
        dp[i] = dp[i - trie[k].len] + trie[k].len;
      k = trie[k].fail;
    }
    u = trie[u].son[v];
  }
}

主函數里取 max 即可。

1
for (int i = 0, e = strlen(T); i < e; i++) mx = std::max(mx, dp[i]);

但是這樣的思路複雜度不是線性(因為要跳每個節點的 fail),會被 subtask#2 卡到 T,所以我們需要一個優化的思路。

我們再看看題目的特殊性質,我們發現所有單詞的長度只有 \(20\),所以可以想到狀態壓縮優化。

具體怎麼優化呢?我們發現,目前的時間瓶頸主要在跳 fail 這一步,如果我們可以將這一步優化到 \(O(1)\),就可以保證整個問題在嚴格線性的時間內被解出。

那我們就將前 \(20\) 位字母中,可能的子串長度存下來,並壓縮到狀態中,存在每個子節點中。

那麼我們在 buildfail 的時候就可以這麼寫:

構建 fail 指針
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
void getfail(void) {
  for (int i = 0; i < 26; i++) trie[0].son[i] = 1;
  q.push(1);
  trie[1].fail = 0;
  while (!q.empty()) {
    int u = q.front();
    q.pop();
    int Fail = trie[u].fail;
    // 對狀態的更新在這裏
    trie[u].stat = trie[Fail].stat;
    if (trie[u].flag) trie[u].stat |= 1 << trie[u].depth;
    for (int i = 0; i < 26; i++) {
      int v = trie[u].son[i];
      if (!v)
        trie[u].son[i] = trie[Fail].son[i];
      else {
        trie[v].depth = trie[u].depth + 1;
        trie[v].fail = trie[Fail].son[i];
        q.push(v);
      }
    }
  }
}

然後查詢時就可以去掉跳 fail 的循環,將代碼簡化如下:

查詢
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
int query(char *s) {
  int u = 1, len = strlen(s), mx = 0;
  unsigned st = 1;
  for (int i = 0; i < len; i++) {
    int v = s[i] - 'a';
    u = trie[u].son[v];
    // 因為往下跳了一位每一位的長度都+1
    st <<= 1;
    // 這裏的 & 值是狀壓 dp 的使用,代表兩個長度集的交非空
    if (trie[u].stat & st) st |= 1, mx = i + 1;
  }
  return mx;
}

我們的 trie[u].stat 維護的是從 u 節點開始,整條 fail 鏈上的長度集(因為長度集小於 32 所以不影響),而 st 則維護的是查詢字符串走到現在,前 32 位(因為狀態壓縮自然溢出)的長度集。

& 值不為 0,則代表兩個長度集的交集非空,我們此時就找到了一個匹配。

完整代碼

P2292 [HNOI2004] L 語言

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
// Code by rickyxrc | https://www.luogu.com.cn/record/115806238
#include <stdio.h>
#include <string.h>

#include <queue>
#define maxn 3000001
char T[maxn];
int n, cnt, vis[maxn], ans, m, dp[maxn];

struct trie_node {
  int son[26];
  int fail, flag, depth;
  unsigned stat;

  void init() {
    memset(son, 0, sizeof(son));
    fail = flag = depth = 0;
  }
} trie[maxn];

std::queue<int> q;

void init() {
  for (int i = 0; i <= cnt; i++) trie[i].init();
  for (int i = 1; i <= n; i++) vis[i] = 0;
  cnt = 1;
  ans = 0;
}

void insert(char *s, int num) {
  int u = 1, len = strlen(s);
  for (int i = 0; i < len; i++) {
    // trie[u].depth = i + 1;
    int v = s[i] - 'a';
    if (!trie[u].son[v]) trie[u].son[v] = ++cnt;
    u = trie[u].son[v];
  }
  trie[u].flag = num;
  // trie[u].stat = 1;
  // printf("set %d stat %d\n", u-1, 1);
  return;
}

void getfail(void) {
  for (int i = 0; i < 26; i++) trie[0].son[i] = 1;
  q.push(1);
  trie[1].fail = 0;
  while (!q.empty()) {
    int u = q.front();
    q.pop();
    int Fail = trie[u].fail;
    trie[u].stat = trie[Fail].stat;
    if (trie[u].flag) trie[u].stat |= 1 << trie[u].depth;
    for (int i = 0; i < 26; i++) {
      int v = trie[u].son[i];
      if (!v)
        trie[u].son[i] = trie[Fail].son[i];
      else {
        trie[v].depth = trie[u].depth + 1;
        trie[v].fail = trie[Fail].son[i];
        q.push(v);
      }
    }
  }
}

int query(char *s) {
  int u = 1, len = strlen(s), mx = 0;
  unsigned st = 1;
  for (int i = 0; i < len; i++) {
    int v = s[i] - 'a';
    u = trie[u].son[v];
    st <<= 1;
    if (trie[u].stat & st) st |= 1, mx = i + 1;
  }
  return mx;
}

int main() {
  scanf("%d%d", &n, &m);
  init();
  for (int i = 1; i <= n; i++) {
    scanf("%s", T);
    insert(T, i);
  }
  getfail();
  for (int i = 1; i <= m; i++) {
    scanf("%s", T);
    printf("%d\n", query(T));
  }
}

總結

希望大家看懂了文章。

時間複雜度:定義 \(|s_i|\) 是模板串的長度,\(|S|\) 是文本串的長度,\(|\Sigma|\) 是字符集的大小(常數,一般為 26)。如果連了 trie 圖,時間複雜度就是 \(O(\sum|s_i|+n|\Sigma|+|S|)\),其中 \(n\) 是 AC 自動機中結點的數目,並且最大可以達到 \(O(\sum|s_i|)\)。如果不連 trie 圖,並且在構建 fail 指針的時候避免遍歷到空兒子,時間複雜度就是 \(O(\sum|s_i|+|S|)\)

模板 1

Luogu P3808【模板】AC 自動機(簡單版)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#include <bits/stdc++.h>
using namespace std;
const int N = 1e6 + 6;
int n;

namespace AC {
int tr[N][26], tot;
int e[N], fail[N];

void insert(char *s) {
  int u = 0;
  for (int i = 1; s[i]; i++) {
    if (!tr[u][s[i] - 'a']) tr[u][s[i] - 'a'] = ++tot;  // 如果没有则插入新节点
    u = tr[u][s[i] - 'a'];                              // 搜索下一个节点
  }
  e[u]++;  // 尾为节点 u 的串的个数
}

queue<int> q;

void build() {
  for (int i = 0; i < 26; i++)
    if (tr[0][i]) q.push(tr[0][i]);
  while (q.size()) {
    int u = q.front();
    q.pop();
    for (int i = 0; i < 26; i++) {
      if (tr[u][i]) {
        fail[tr[u][i]] =
            tr[fail[u]][i];  // fail数组:同一字符可以匹配的其他位置
        q.push(tr[u][i]);
      } else
        tr[u][i] = tr[fail[u]][i];
    }
  }
}

int query(char *t) {
  int u = 0, res = 0;
  for (int i = 1; t[i]; i++) {
    u = tr[u][t[i] - 'a'];  // 转移
    for (int j = u; j && e[j] != -1; j = fail[j]) {
      res += e[j], e[j] = -1;
    }
  }
  return res;
}
}  // namespace AC

char s[N];

int main() {
  scanf("%d", &n);
  for (int i = 1; i <= n; i++) scanf("%s", s + 1), AC::insert(s);
  scanf("%s", s + 1);
  AC::build();
  printf("%d", AC::query(s));
  return 0;
}
模板 2

Luogu P3796【模板】AC 自動機(加強版)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#include <bits/stdc++.h>
using namespace std;
const int N = 156, L = 1e6 + 6;

namespace AC {
const int SZ = N * 80;
int tot, tr[SZ][26];
int fail[SZ], idx[SZ], val[SZ];
int cnt[N];  // 记录第 i 个字符串的出现次数

void init() {
  memset(fail, 0, sizeof(fail));
  memset(tr, 0, sizeof(tr));
  memset(val, 0, sizeof(val));
  memset(cnt, 0, sizeof(cnt));
  memset(idx, 0, sizeof(idx));
  tot = 0;
}

void insert(char *s, int id) {  // id 表示原始字符串的编号
  int u = 0;
  for (int i = 1; s[i]; i++) {
    if (!tr[u][s[i] - 'a']) tr[u][s[i] - 'a'] = ++tot;
    u = tr[u][s[i] - 'a'];  // 转移
  }
  idx[u] = id;  // 以 u 为结尾的字符串编号为 idx[u]
}

queue<int> q;

void build() {
  for (int i = 0; i < 26; i++)
    if (tr[0][i]) q.push(tr[0][i]);
  while (q.size()) {
    int u = q.front();
    q.pop();
    for (int i = 0; i < 26; i++) {
      if (tr[u][i]) {
        fail[tr[u][i]] =
            tr[fail[u]][i];  // fail数组:同一字符可以匹配的其他位置
        q.push(tr[u][i]);
      } else
        tr[u][i] = tr[fail[u]][i];
    }
  }
}

int query(char *t) {  // 返回最大的出现次数
  int u = 0, res = 0;
  for (int i = 1; t[i]; i++) {
    u = tr[u][t[i] - 'a'];
    for (int j = u; j; j = fail[j]) val[j]++;
  }
  for (int i = 0; i <= tot; i++)
    if (idx[i]) res = max(res, val[i]), cnt[idx[i]] = val[i];
  return res;
}
}  // namespace AC

int n;
char s[N][100], t[L];

int main() {
  while (~scanf("%d", &n)) {
    if (n == 0) break;
    AC::init();  // 数组清零
    for (int i = 1; i <= n; i++)
      scanf("%s", s[i] + 1), AC::insert(s[i], i);  // 需要记录该字符串的序号
    AC::build();
    scanf("%s", t + 1);
    int x = AC::query(t);
    printf("%d\n", x);
    for (int i = 1; i <= n; i++)
      if (AC::cnt[i] == x) printf("%s\n", s[i] + 1);
  }
  return 0;
}
模版 3

Luogu P5357【模板】AC 自動機(二次加強版)

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#include <deque>
#include <iostream>

void promote() {
  std::ios::sync_with_stdio(0);
  std::cin.tie(0);
  std::cout.tie(0);
  return;
}

typedef char chr;
typedef std::deque<int> dic;

const int maxN = 2e5;
const int maxS = 2e5;
const int maxT = 2e6;

int n;
chr s[maxS + 10];
chr t[maxT + 10];
int cnt[maxN + 10];

struct AhoCorasickAutomaton {
  struct Node {
    int son[30];
    int val;
    int fail;
    int head;
    dic index;
  } node[maxS + 10];

  struct Edge {
    int head;
    int next;
  } edge[maxS + 10];

  int root;
  int ncnt;
  int ecnt;

  void Insert(chr *str, int i) {
    int u = root;
    for (int i = 1; str[i]; i++) {
      if (node[u].son[str[i] - 'a' + 1] == 0)
        node[u].son[str[i] - 'a' + 1] = ++ncnt;
      u = node[u].son[str[i] - 'a' + 1];
    }
    node[u].index.push_back(i);
    return;
  }

  void Build() {
    dic q;
    for (int i = 1; i <= 26; i++)
      if (node[root].son[i]) q.push_back(node[root].son[i]);
    while (!q.empty()) {
      int u = q.front();
      q.pop_front();
      for (int i = 1; i <= 26; i++) {
        if (node[u].son[i]) {
          node[node[u].son[i]].fail = node[node[u].fail].son[i];
          q.push_back(node[u].son[i]);
        } else {
          node[u].son[i] = node[node[u].fail].son[i];
        }
      }
    }
    return;
  }

  void Query(chr *str) {
    int u = root;
    for (int i = 1; str[i]; i++) {
      u = node[u].son[str[i] - 'a' + 1];
      node[u].val++;
    }
    return;
  }

  void addEdge(int tail, int head) {
    ecnt++;
    edge[ecnt].head = head;
    edge[ecnt].next = node[tail].head;
    node[tail].head = ecnt;
    return;
  }

  void DFS(int u) {
    for (int e = node[u].head; e; e = edge[e].next) {
      int v = edge[e].head;
      DFS(v);
      node[u].val += node[v].val;
    }
    for (auto i : node[u].index) cnt[i] += node[u].val;
    return;
  }

  void FailTree() {
    for (int u = 1; u <= ncnt; u++) addEdge(node[u].fail, u);
    DFS(root);
    return;
  }
} ACM;

int main() {
  std::cin >> n;
  for (int i = 1; i <= n; i++) {
    std::cin >> (s + 1);
    ACM.Insert(s, i);
  }
  ACM.Build();
  std::cin >> (t + 1);
  ACM.Query(t);
  ACM.FailTree();
  for (int i = 1; i <= n; i++) std::cout << cnt[i] << '\n';
  return 0;
}

拓展

確定有限狀態自動機

如果大家理解了上面的講解,那麼作為拓展延伸,文末我們簡單介紹一下 自動機KMP 自動機。(現在你再去看自動機的定義就會好懂很多啦)

有限狀態自動機(Deterministic Finite Automaton,DFA)是由

  1. 狀態集合 \(Q\)
  2. 字符集 \(\Sigma\)
  3. 狀態轉移函數 \(\delta:Q\times \Sigma \to Q\),即 \(\delta(q,\sigma)=q',\ q,q'\in Q,\sigma\in \Sigma\)
  4. 一個開始狀態 \(s\in Q\)
  5. 一個接收的狀態集合 \(F\subseteq Q\)

組成的五元組 \((Q,\Sigma,\delta,s,F)\)

那這東西你用 AC 自動機理解,狀態集合就是字典樹(圖)的結點;字符集就是 az(或者更多);狀態轉移函數就是 \(\text{trans}(u,c)\) 的函數(即 \(\text{trans}[u][c]\));開始狀態就是字典樹的根結點;接收狀態就是你在字典樹中標記的字符串結尾結點組成的集合。

KMP 自動機

KMP 自動機就是一個不斷讀入待匹配串,每次匹配時走到接受狀態的 DFA。如果共有 \(m\) 個狀態,第 \(i\) 個狀態表示已經匹配了前 \(i\) 個字符。那麼我們定義 \(\text{trans}_{i,c}\) 表示狀態 \(i\) 讀入字符 \(c\) 後到達的狀態,\(\text{next}_{i}\) 表示 prefix function,則有:

\[ \text{trans}_{i,c} = \begin{cases} i + 1, & \text{if }b_{i} = c \\[2ex] \text{trans}_{\text{next}_{i},c}, & \text{otherwise} \end{cases} \]

(約定 \(\text{next}_{0}=0\)

我們發現 \(\text{trans}_{i}\) 只依賴於之前的值,所以可以跟 KMP 一起求出來。(一些細節:走到接受狀態之後立即轉移到該狀態的 next)

時間和空間複雜度:\(O(m|\Sigma|)\)

對比之下,AC 自動機其實就是 Trie 上的自動機。(雖然一開始丟給你這句話可能不知所措)