跳转至

A*

本頁面將簡要介紹 A * 算法。

定義

A * 搜索算法(英文:A*search algorithm,A * 讀作 A-star),簡稱 A * 算法,是一種在圖形平面上,對於有多個節點的路徑求出最低通過成本的算法。它屬於圖遍歷(英文:Graph traversal)和最佳優先搜索算法(英文:Best-first search),亦是 BFS 的改進。

過程

定義起點 \(s\),終點 \(t\),從起點(初始狀態)開始的距離函數 \(g(x)\),到終點(最終狀態)的距離函數 \(h(x)\)\(h^{\ast}(x)\)1,以及每個點的估價函數 \(f(x)=g(x)+h(x)\)

A * 算法每次從優先隊列中取出一個 \(f\) 最小的元素,然後更新相鄰的狀態。

如果 \(h\leq h*\),則 A * 算法能找到最優解。

上述條件下,如果 \(h\) 滿足三角形不等式,則 A * 算法不會將重複結點加入隊列。

\(h=0\) 時,A * 算法變為 Dijkstra;當 \(h=0\) 並且邊權為 \(1\) 時變為 BFS

例題

八數碼

題目大意:在 \(3\times 3\) 的棋盤上,擺有八個棋子,每個棋子上標有 \(1\)\(8\) 的某一數字。棋盤中留有一個空格,空格用 \(0\) 來表示。空格周圍的棋子可以移到空格中,這樣原來的位置就會變成空格。給出一種初始佈局和目標佈局(為了使題目簡單,設目標狀態如下),找到一種從初始佈局到目標佈局最少步驟的移動方法。

1
2
3
    123
    804
    765
解題思路

\(h\) 函數可以定義為,不在應該在的位置的數字個數。

容易發現 \(h\) 滿足以上兩個性質,此題可以使用 A * 算法求解。

參考代碼
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <queue>
#include <set>
using namespace std;
const int dx[4] = {1, -1, 0, 0}, dy[4] = {0, 0, 1, -1};
int fx, fy;
char ch;

struct matrix {
  int a[5][5];

  bool operator<(matrix x) const {
    for (int i = 1; i <= 3; i++)
      for (int j = 1; j <= 3; j++)
        if (a[i][j] != x.a[i][j]) return a[i][j] < x.a[i][j];
    return false;
  }
} f, st;

int h(matrix a) {
  int ret = 0;
  for (int i = 1; i <= 3; i++)
    for (int j = 1; j <= 3; j++)
      if (a.a[i][j] != st.a[i][j]) ret++;
  return ret;
}

struct node {
  matrix a;
  int t;

  bool operator<(node x) const { return t + h(a) > x.t + h(x.a); }
} x;

priority_queue<node> q;  // 搜索队列
set<matrix> s;           // 防止搜索队列重复

int main() {
  st.a[1][1] = 1;  // 定义标准表
  st.a[1][2] = 2;
  st.a[1][3] = 3;
  st.a[2][1] = 8;
  st.a[2][2] = 0;
  st.a[2][3] = 4;
  st.a[3][1] = 7;
  st.a[3][2] = 6;
  st.a[3][3] = 5;
  for (int i = 1; i <= 3; i++)  // 输入
    for (int j = 1; j <= 3; j++) {
      scanf(" %c", &ch);
      f.a[i][j] = ch - '0';
    }
  q.push({f, 0});
  while (!q.empty()) {
    x = q.top();
    q.pop();
    if (!h(x.a)) {  // 判断是否与标准矩阵一致
      printf("%d\n", x.t);
      return 0;
    }
    for (int i = 1; i <= 3; i++)
      for (int j = 1; j <= 3; j++)
        if (!x.a.a[i][j]) fx = i, fy = j;  // 查找空格子(0号点)的位置
    for (int i = 0; i < 4; i++) {  // 对四种移动方式分别进行搜索
      int xx = fx + dx[i], yy = fy + dy[i];
      if (1 <= xx && xx <= 3 && 1 <= yy && yy <= 3) {
        swap(x.a.a[fx][fy], x.a.a[xx][yy]);
        if (!s.count(x.a))
          s.insert(x.a),
              q.push({x.a, x.t + 1});  // 这样移动后,将新的情况放入搜索队列中
        swap(x.a.a[fx][fy], x.a.a[xx][yy]);  // 如果不这样移动的情况
      }
    }
  }
  return 0;
}

注:對於 k 短路問題,原題已經可以構造出數據使得 A* 算法無法通過,故本題思路僅供參考,A* 算法非正解,正解為可持久化可並堆做法,請移步 k 短路問題

k 短路

按順序求一個有向圖上從結點 \(s\) 到結點 \(t\) 的所有路徑最小的前任意多(不妨設為 \(k\))個。

解題思路

很容易發現,這個問題很容易轉化成用 A * 算法解決問題的標準程式。

初始狀態為處於結點 \(s\),最終狀態為處於結點 \(t\),距離函數為從 \(s\) 到當前結點已經走過的距離,估價函數為從當前結點到結點 \(t\) 至少要走過的距離,也就是當前結點到結點 \(t\) 的最短路。

就這樣,我們在預處理的時候反向建圖,計算出結點 \(t\) 到所有點的最短路,然後將初始狀態塞入優先隊列,每次取出 \(f(x)=g(x)+h(x)\) 最小的一項,計算出其所連結點的信息並將其也塞入隊列。當你第 \(k\) 次走到結點 \(t\) 時,也就算出了結點 \(s\) 到結點 \(t\)\(k\) 短路。

由於設計的距離函數和估價函數,每個狀態需要存儲兩個參數,當前結點 \(x\) 和已經走過的距離 \(v\)

我們可以在此基礎上加一點小優化:由於只需要求出第 \(k\) 短路,所以當我們第 \(k+1\) 次或以上走到該結點時,直接跳過該狀態。因為前面的 \(k\) 次走到這個點的時候肯定能因此構造出 \(k\) 條路徑,所以之後再加邊更無必要。

參考代碼
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <queue>
using namespace std;
const int maxn = 5010;
const int maxm = 400010;
const double inf = 2e9;
int n, m, k, u, v, cur, h[maxn], nxt[maxm], p[maxm], cnt[maxn], ans;
int cur1, h1[maxn], nxt1[maxm], p1[maxm];
double e, ww, w[maxm], f[maxn];
double w1[maxm];
bool tf[maxn];

void add_edge(int x, int y, double z) {  // 正向建图函数
  cur++;
  nxt[cur] = h[x];
  h[x] = cur;
  p[cur] = y;
  w[cur] = z;
}

void add_edge1(int x, int y, double z) {  // 反向建图函数
  cur1++;
  nxt1[cur1] = h1[x];
  h1[x] = cur1;
  p1[cur1] = y;
  w1[cur1] = z;
}

struct node {  // 使用A*时所需的结构体
  int x;
  double v;

  bool operator<(node a) const { return v + f[x] > a.v + f[a.x]; }
};

priority_queue<node> q;

struct node2 {  // 计算t到所有结点最短路时所需的结构体
  int x;
  double v;

  bool operator<(node2 a) const { return v > a.v; }
} x;

priority_queue<node2> Q;

int main() {
  scanf("%d%d%lf", &n, &m, &e);
  while (m--) {
    scanf("%d%d%lf", &u, &v, &ww);
    add_edge(u, v, ww);   // 正向建图
    add_edge1(v, u, ww);  // 反向建图
  }
  for (int i = 1; i < n; i++) f[i] = inf;
  Q.push({n, 0});
  while (!Q.empty()) {  // 计算t到所有结点的最短路
    x = Q.top();
    Q.pop();
    if (tf[x.x]) continue;
    tf[x.x] = true;
    f[x.x] = x.v;
    for (int j = h1[x.x]; j; j = nxt1[j]) Q.push({p1[j], x.v + w1[j]});
  }
  k = (int)e / f[1];
  q.push({1, 0});
  while (!q.empty()) {  // 使用A*算法
    node x = q.top();
    q.pop();
    cnt[x.x]++;
    if (x.x == n) {
      e -= x.v;
      if (e < 0) {
        printf("%d\n", ans);
        return 0;
      }
      ans++;
    }
    for (int j = h[x.x]; j; j = nxt[j])
      if (cnt[p[j]] <= k && x.v + w[j] <= e) q.push({p[j], x.v + w[j]});
  }
  printf("%d\n", ans);
  return 0;
}

參考資料與註釋


  1. 此處的 h 意為 heuristic。詳見 啓發式搜索 - 維基百科A*search algorithm - Wikipedia 的 Bounded relaxation 一節。