跳转至

序列自動機

在閲讀本文之前,請先閲讀 自動機

定義

序列自動機是接受且僅接受一個字符串的子序列的自動機。

本文中用 \(s\) 代指這個字符串。

狀態

\(s\) 包含 \(n\) 個字符,那麼序列自動機包含 \(n+1\) 個狀態。

\(t\)\(s\) 的一個子序列,那麼 \(\delta(start, t)\)\(t\)\(s\) 中第一次出現時末端的位置。

也就是説,一個狀態 \(i\) 表示前綴 \(s[1..i]\) 的子序列與前綴 \(s[1..i-1]\) 的子序列的差集。

序列自動機上的所有狀態都是接受狀態。

轉移

由狀態定義可以得到,\(\delta(u, c)=\min\{i|i>u,s[i]=c\}\),也就是字符 \(c\) 下一次出現的位置。

為什麼是「下一次」出現的位置呢?因為若 \(i>j\),後綴 \(s[i..|s|]\) 的子序列是後綴 \(s[j..|s|]\) 的子序列的子集,一定是選儘量靠前的最優。

實現

從後向前掃描,過程中維護每個字符最前的出現位置:

\[ \begin{array}{ll} 1 & \textbf{Input. } \text{A string } S\\ 2 & \textbf{Output. } \text{The state transition of the sequence automaton of }S \\ 3 & \textbf{Method. } \\ 4 & \textbf{for }c\in\Sigma\\ 5 & \qquad next[c]\gets null\\ 6 & \textbf{for }i\gets|S|\textbf{ downto }1\\ 7 & \qquad next[S[i]]\gets i\\ 8 & \qquad \textbf{for }c\in\Sigma\\ 9 & \qquad\qquad \delta(i-1,c)\gets next[c]\\ 10 & \textbf{return }\delta \end{array} \]

這樣構建的複雜度是 \(O(n|\Sigma|)\)

例題

「HEOI2015」最短不公共子串

給你兩個由小寫英文字母組成的串 \(A\)\(B\),求:

  1. \(A\) 的一個最短的子串,它不是 \(B\) 的子串;
  2. \(A\) 的一個最短的子串,它不是 \(B\) 的子序列;
  3. \(A\) 的一個最短的子序列,它不是 \(B\) 的子串;
  4. \(A\) 的一個最短的子序列,它不是 \(B\) 的子序列。

    \(1\le |A|, |B|\le 2000\)

題解

這題的 (1) 和 (3) 兩問需要後綴自動機,而且做法類似,在這裏只講解 (2) 和 (4) 兩問。

(2) 比較簡單,枚舉 A 的子串輸入進 B 的序列自動機,若不接受則計入答案。

(4) 需要 DP。令 \(f(i, j)\) 表示在 A 的序列自動機中處於狀態 \(i\),在 B 的序列自動機中處於狀態 \(j\),需要再添加多少個字符能夠不是公共子序列。

\(f(i, null)=0\)

\(f(i, j)=\min\limits_{\delta_A(i,c)\ne null}f(\delta_A(i, c), \delta_B(j, c))+1\)

參考代碼
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>

using namespace std;

const int N = 2005;

char s[N], t[N];
int na[N][26], nb[N][26], nxt[26];
int n, m, a[N], b[N], tot = 1, p = 1, f[N][N << 1];

struct SAM {
  int par, ch[26], len;
} sam[N << 1];

void insert(int x) {
  int np = ++tot;  // 新节点
  sam[np].len = sam[p].len + 1;
  while (p && !sam[p].ch[x]) {
    sam[p].ch[x] = np;
    p = sam[p].par;
  }
  if (p == 0)
    sam[np].par = 1;
  else {
    int q = sam[p].ch[x];
    if (sam[q].len == sam[p].len + 1)
      sam[np].par = q;
    else {
      int nq = ++tot;
      sam[nq].len = sam[p].len + 1;
      memcpy(sam[nq].ch, sam[q].ch, sizeof(sam[q].ch));
      sam[nq].par = sam[q].par;
      sam[q].par = sam[np].par = nq;
      while (p && sam[p].ch[x] == q) {
        sam[p].ch[x] = nq;
        p = sam[p].par;
      }
    }
  }
  p = np;
}

int main() {
  scanf("%s%s", s + 1, t + 1);

  n = strlen(s + 1);
  m = strlen(t + 1);

  for (int i = 1; i <= n; ++i) a[i] = s[i] - 'a';
  for (int i = 1; i <= m; ++i) b[i] = t[i] - 'a';

  for (int i = 1; i <= m; ++i) insert(b[i]);

  // nxt[S[i]]<-i
  for (int i = 0; i < 26; ++i) nxt[i] = n + 1;
  for (int i = n; i >= 0; --i) {
    memcpy(na[i], nxt, sizeof(nxt));
    nxt[a[i]] = i;
  }

  for (int i = 0; i < 26; ++i) nxt[i] = m + 1;
  for (int i = m; i >= 0; --i) {
    memcpy(nb[i], nxt, sizeof(nxt));
    nxt[b[i]] = i;
  }

  // 四种情况计算答案
  //  1
  int ans = N;
  for (int l = 1; l <= n; ++l) {
    for (int r = l, u = 1; r <= n; ++r) {
      u = sam[u].ch[a[r]];
      if (!u) {
        ans = min(ans, r - l + 1);
        break;
      }
    }
  }

  printf("%d\n", ans == N ? -1 : ans);

  // 2
  ans = N;

  for (int l = 1; l <= n; ++l) {
    for (int r = l, u = 0; r <= n; ++r) {
      u = nb[u][a[r]];
      if (u == m + 1) {
        ans = min(ans, r - l + 1);
        break;
      }
    }
  }

  printf("%d\n", ans == N ? -1 : ans);

  // 3
  for (int i = n; i >= 0; --i) {
    for (int j = 1; j <= tot; ++j) {
      f[i][j] = N;
      for (int c = 0; c < 26; ++c) {
        int u = na[i][c];
        int v = sam[j].ch[c];
        if (u <= n) f[i][j] = min(f[i][j], f[u][v] + 1);
      }
    }
  }

  printf("%d\n", f[0][1] == N ? -1 : f[0][1]);

  // 4
  memset(f, 0, sizeof(f));

  for (int i = n; i >= 0; --i) {
    for (int j = 0; j <= m; ++j) {
      f[i][j] = N;
      for (int c = 0; c < 26; ++c) {
        int u = na[i][c];
        int v = nb[j][c];
        if (u <= n) f[i][j] = min(f[i][j], f[u][v] + 1);
      }
    }
  }

  printf("%d\n", f[0][0] == N ? -1 : f[0][0]);

  return 0;
}