跳转至

記憶化搜索

定義

記憶化搜索是一種通過記錄已經遍歷過的狀態的信息,從而避免對同一狀態重複遍歷的搜索實現方式。

因為記憶化搜索確保了每個狀態只訪問一次,它也是一種常見的動態規劃實現方式。

引入

[NOIP2005] 採藥

山洞裏有 \(M\) 株不同的草藥,採每一株都需要一些時間 \(t_i\),每一株也有它自身的價值 \(v_i\)。給你一段時間 \(T\),在這段時間裏,你可以採到一些草藥。讓採到的草藥的總價值最大。

\(1 \leq T \leq 10^3\)\(1 \leq t_i,v_i,M \leq 100\)

樸素的 DFS 做法

很容易實現這樣一個樸素的搜索做法:在搜索時記錄下當前準備選第幾個物品、剩餘的時間是多少、已經獲得的價值是多少這三個參數,然後枚舉當前物品是否被選,轉移到相應的狀態。

實現
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
int n, t;
int tcost[103], mget[103];
int ans = 0;

void dfs(int pos, int tleft, int tans) {
  if (tleft < 0) return;
  if (pos == n + 1) {
    ans = max(ans, tans);
    return;
  }
  dfs(pos + 1, tleft, tans);
  dfs(pos + 1, tleft - tcost[pos], tans + mget[pos]);
}

int main() {
  cin >> t >> n;
  for (int i = 1; i <= n; i++) cin >> tcost[i] >> mget[i];
  dfs(1, t, 0);
  cout << ans << endl;
  return 0;
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
tcost = [0] * 103
mget = [0] * 103
ans = 0
def dfs(pos, tleft, tans):
    global ans
    if tleft < 0:
        return
    if pos == n + 1:
        ans = max(ans, tans)
        return
    dfs(pos + 1, tleft, tans)
    dfs(pos + 1, tleft - tcost[pos], tans + mget[pos])
t, n = map(lambda x:int(x), input().split())
for i in range(1, n + 1):
    tcost[i], mget[i] = map(lambda x:int(x), input().split())
dfs(1, t, 0)
print(ans)

這種做法的時間複雜度是指數級別的,並不能通過本題。

優化

上面的做法為什麼效率低下呢?因為同一個狀態會被訪問多次。

如果我們每查詢完一個狀態後將該狀態的信息存儲下來,再次需要訪問這個狀態就可以直接使用之前計算得到的信息,從而避免重複計算。這充分利用了動態規劃中很多問題具有大量重疊子問題的特點,屬於用空間換時間的「記憶化」思想。

具體到本題上,我們在樸素的 DFS 的基礎上,增加一個數組 mem 來記錄每個 dfs(pos,tleft) 的返回值。剛開始把 mem 中每個值都設成 -1(代表沒求解過)。每次需要訪問一個狀態時,如果相應狀態的值在 mem 中為 -1,則遞歸訪問該狀態。否則我們直接使用 mem 中已經存儲過的值即可。

通過這樣的處理,我們確保了每個狀態只會被訪問一次,因此該算法的的時間複雜度為 \(O(TM)\)

實現
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
int n, t;
int tcost[103], mget[103];
int mem[103][1003];

int dfs(int pos, int tleft) {
  if (mem[pos][tleft] != -1)
    return mem[pos][tleft];  // 已經訪問過的狀態,直接返回之前記錄的值
  if (pos == n + 1) return mem[pos][tleft] = 0;
  int dfs1, dfs2 = -INF;
  dfs1 = dfs(pos + 1, tleft);
  if (tleft >= tcost[pos])
    dfs2 = dfs(pos + 1, tleft - tcost[pos]) + mget[pos];  // 狀態轉移
  return mem[pos][tleft] = max(dfs1, dfs2);  // 最後將當前狀態的值存下來
}

int main() {
  memset(mem, -1, sizeof(mem));
  cin >> t >> n;
  for (int i = 1; i <= n; i++) cin >> tcost[i] >> mget[i];
  cout << dfs(1, t) << endl;
  return 0;
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
tcost = [0] * 103
mget = [0] * 103
mem = [[-1 for i in range(1003)] for j in range(103)]
def dfs(pos, tleft):
    if mem[pos][tleft] != -1:
        return mem[pos][tleft]
    if pos == n + 1:
        mem[pos][tleft] = 0
        return mem[pos][tleft]
    dfs1 = dfs2 = -INF
    dfs1 = dfs(pos + 1, tleft)
    if tleft >= tcost[pos]:
        dfs2 = dfs(pos + 1, tleft - tcost[pos]) + mget[pos]
    mem[pos][tleft] = max(dfs1, dfs2)
    return mem[pos][tleft]
t, n = map(lambda x:int(x), input().split())
for i in range(1, n + 1):
    tcost[i], mget[i] = map(lambda x:int(x), input().split())
print(dfs(1, t))

與遞推的聯繫與區別

在求解動態規劃的問題時,記憶化搜索與遞推的代碼,在形式上是高度類似的。這是由於它們使用了相同的狀態表示方式和類似的狀態轉移。也正因為如此,一般來説兩種實現的時間複雜度是一樣的。

下面給出的是遞推實現的代碼(為了方便對比,沒有添加滾動數組優化),通過對比可以發現二者在形式上的類似性。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
const int maxn = 1010;
int n, t, w[105], v[105], f[105][1005];

int main() {
  cin >> n >> t;
  for (int i = 1; i <= n; i++) cin >> w[i] >> v[i];
  for (int i = 1; i <= n; i++)
    for (int j = 0; j <= t; j++) {
      f[i][j] = f[i - 1][j];
      if (j >= w[i])
        f[i][j] = max(f[i][j], f[i - 1][j - w[i]] + v[i]);  // 狀態轉移方程
    }
  cout << f[n][t];
  return 0;
}

在求解動態規劃的問題時,記憶化搜索和遞推,都確保了同一狀態至多隻被求解一次。而它們實現這一點的方式則略有不同:遞推通過設置明確的訪問順序來避免重複訪問,記憶化搜索雖然沒有明確規定訪問順序,但通過給已經訪問過的狀態打標記的方式,也達到了同樣的目的。

與遞推相比,記憶化搜索因為不用明確規定訪問順序,在實現難度上有時低於遞推,且能比較方便地處理邊界情況,這是記憶化搜索的一大優勢。但與此同時,記憶化搜索難以使用滾動數組等優化,且由於存在遞歸,運行效率會低於遞推。因此應該視題目選擇更適合的實現方式。

如何寫記憶化搜索

方法一

  1. 把這道題的 dp 狀態和方程寫出來
  2. 根據它們寫出 dfs 函數
  3. 添加記憶化數組

舉例:

\(dp_{i} = \max\{dp_{j}+1\}\quad (1 \leq j < i \land a_{j}<a_{i})\)(最長上升子序列)

轉為

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
int dfs(int i) {
  if (mem[i] != -1) return mem[i];
  int ret = 1;
  for (int j = 1; j < i; j++)
    if (a[j] < a[i]) ret = max(ret, dfs(j) + 1);
  return mem[i] = ret;
}

int main() {
  memset(mem, -1, sizeof(mem));
  // 讀入部分略去
  int ret = 0;
  for (int j = 1; j <= n; j++) {
    ret = max(ret, dfs(j));
  }
  cout << ret << endl;
}
1
2
3
4
5
6
7
8
9
def dfs(i):
    if mem[i] != -1:
        return mem[i]
    ret = 1
    for j in range(1, i):
        if a[j] < a[i]:
            ret = max(ret, dfs(j) + 1)
    mem[i] = ret
    return mem[i]

方法二

  1. 寫出這道題的暴搜程序(最好是 dfs
  2. 將這個 dfs 改成「無需外部變量」的 dfs
  3. 添加記憶化數組

舉例:本文中「採藥」的例子