跳转至

動態 DP

前置知識:矩陣樹鏈剖分

動態 DP 問題是貓錕在 WC2018 講的黑科技,一般用來解決樹上的帶有點權(邊權)修改操作的 DP 問題。

例子

以這道模板題為例子講解一下動態 DP 的過程。

例題 洛谷 P4719【模板】動態 DP

給定一棵 \(n\) 個點的樹,點帶點權。有 \(m\) 次操作,每次操作給定 \(x,y\) 表示修改點 \(x\) 的權值為 \(y\)。你需要在每次操作之後求出這棵樹的最大權獨立集的權值大小。

廣義矩陣乘法

定義廣義矩陣乘法 \(A\times B=C\) 為:

\[ C_{i,j}=\max_{k=1}^{n}(A_{i,k}+B_{k,j}) \]

相當於將普通的矩陣乘法中的乘變為加,加變為 \(\max\) 操作。

同時廣義矩陣乘法滿足結合律,所以可以使用矩陣快速冪。

不帶修改操作

\(f_{i,0}\) 表示不選擇 \(i\) 的最大答案,\(f_{i,1}\) 表示選擇 \(i\) 的最大答案。

則有 DP 方程:

\[ \begin{cases}f_{i,0}=\sum_{son}\max(f_{son,0},f_{son,1})\\f_{i,1}=w_i+\sum_{son}f_{son,0}\end{cases} \]

答案就是 \(\max(f_{root,0},f_{root,1})\).

帶修改操作

首先將這棵樹進行樹鏈剖分,假設有這樣一條重鏈:

\(g_{i,0}\) 表示不選擇 \(i\) 且只允許選擇 \(i\) 的輕兒子所在子樹的最大答案,\(g_{i,1}\) 表示不考慮 \(son_i\) 的情況下選擇 \(i\) 的最大答案,\(son_i\) 表示 \(i\) 的重兒子。

假設我們已知 \(g_{i,0/1}\) 那麼有 DP 方程:

\[ \begin{cases}f_{i,0}=g_{i,0}+\max(f_{son_i,0},f_{son_i,1})\\f_{i,1}=g_{i,1}+f_{son_i,0}\end{cases} \]

答案是 \(\max(f_{root,0},f_{root,1})\).

可以構造出矩陣:

\[ \begin{bmatrix} g_{i,0} & g_{i,0}\\ g_{i,1} & -\infty \end{bmatrix}\times \begin{bmatrix} f_{son_i,0}\\f_{son_i,1} \end{bmatrix}= \begin{bmatrix} f_{i,0}\\f_{i,1} \end{bmatrix} \]

注意,我們這裏使用的是廣義乘法規則。

可以發現,修改操作時只需要修改 \(g_{i,1}\) 和每條往上的重鏈即可。

具體思路

  1. DFS 預處理求出 \(f_{i,0/1}\)\(g_{i,0/1}\).

  2. 對這棵樹進行樹鏈剖分(注意,因為我們對一個點進行詢問需要計算從該點到該點所在的重鏈末尾的區間矩陣乘,所以對於每一個點記錄 \(End_i\) 表示 \(i\) 所在的重鏈末尾節點編號),每一條重鏈建立線段樹,線段樹維護 \(g\) 矩陣和 \(g\) 矩陣區間乘積。

  3. 修改時首先修改 \(g_{i,1}\) 和線段樹中 \(i\) 節點的矩陣,計算 \(top_i\) 矩陣的變化量,修改到 \(fa_{top_i}\) 矩陣。

  4. 查詢時就是 1 到其所在的重鏈末尾的區間乘,最後取一個 \(\max\) 即可。

代碼實現
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#include <bits/stdc++.h>
using namespace std;

typedef long long LL;

const int maxn = 500010;
const int INF = 0x3f3f3f3f;

int Begin[maxn], Next[maxn], To[maxn], e, n, m;
int sz[maxn], son[maxn], top[maxn], fa[maxn], dis[maxn], p[maxn], id[maxn],
    End[maxn];
// p[i]表示i树剖后的编号,id[p[i]] = i
int cnt, tot, a[maxn], f[maxn][2];

struct matrix {
  int g[2][2];

  matrix() { memset(g, 0, sizeof(g)); }

  matrix operator*(const matrix &b) const  // 重载矩阵乘
  {
    matrix c;
    for (int i = 0; i <= 1; i++)
      for (int j = 0; j <= 1; j++)
        for (int k = 0; k <= 1; k++)
          c.g[i][j] = max(c.g[i][j], g[i][k] + b.g[k][j]);
    return c;
  }
} Tree[maxn], g[maxn];  // Tree[]是建出来的线段树,g[]是维护的每个点的矩阵

void PushUp(int root) { Tree[root] = Tree[root << 1] * Tree[root << 1 | 1]; }

void Build(int root, int l, int r) {
  if (l == r) {
    Tree[root] = g[id[l]];
    return;
  }
  int Mid = l + r >> 1;
  Build(root << 1, l, Mid);
  Build(root << 1 | 1, Mid + 1, r);
  PushUp(root);
}

matrix Query(int root, int l, int r, int L, int R) {
  if (L <= l && r <= R) return Tree[root];
  int Mid = l + r >> 1;
  if (R <= Mid) return Query(root << 1, l, Mid, L, R);
  if (Mid < L) return Query(root << 1 | 1, Mid + 1, r, L, R);
  return Query(root << 1, l, Mid, L, R) *
         Query(root << 1 | 1, Mid + 1, r, L, R);
  // 注意查询操作的书写
}

void Modify(int root, int l, int r, int pos) {
  if (l == r) {
    Tree[root] = g[id[l]];
    return;
  }
  int Mid = l + r >> 1;
  if (pos <= Mid)
    Modify(root << 1, l, Mid, pos);
  else
    Modify(root << 1 | 1, Mid + 1, r, pos);
  PushUp(root);
}

void Update(int x, int val) {
  g[x].g[1][0] += val - a[x];
  a[x] = val;
  // 首先修改x的g矩阵
  while (x) {
    matrix last = Query(1, 1, n, p[top[x]], End[top[x]]);
    // 查询top[x]的原本g矩阵
    Modify(1, 1, n,
           p[x]);  // 进行修改(x点的g矩阵已经进行修改但线段树上的未进行修改)
    matrix now = Query(1, 1, n, p[top[x]], End[top[x]]);
    // 查询top[x]的新g矩阵
    x = fa[top[x]];
    g[x].g[0][0] +=
        max(now.g[0][0], now.g[1][0]) - max(last.g[0][0], last.g[1][0]);
    g[x].g[0][1] = g[x].g[0][0];
    g[x].g[1][0] += now.g[0][0] - last.g[0][0];
    // 根据变化量修改fa[top[x]]的g矩阵
  }
}

void add(int u, int v) {
  To[++e] = v;
  Next[e] = Begin[u];
  Begin[u] = e;
}

void DFS1(int u) {
  sz[u] = 1;
  int Max = 0;
  f[u][1] = a[u];
  for (int i = Begin[u]; i; i = Next[i]) {
    int v = To[i];
    if (v == fa[u]) continue;
    dis[v] = dis[u] + 1;
    fa[v] = u;
    DFS1(v);
    sz[u] += sz[v];
    if (sz[v] > Max) {
      Max = sz[v];
      son[u] = v;
    }
    f[u][1] += f[v][0];
    f[u][0] += max(f[v][0], f[v][1]);
    // DFS1过程中同时求出f[i][0/1]
  }
}

void DFS2(int u, int t) {
  top[u] = t;
  p[u] = ++cnt;
  id[cnt] = u;
  End[t] = cnt;
  g[u].g[1][0] = a[u];
  g[u].g[1][1] = -INF;
  if (!son[u]) return;
  DFS2(son[u], t);
  for (int i = Begin[u]; i; i = Next[i]) {
    int v = To[i];
    if (v == fa[u] || v == son[u]) continue;
    DFS2(v, v);
    g[u].g[0][0] += max(f[v][0], f[v][1]);
    g[u].g[1][0] += f[v][0];
    // g矩阵根据f[i][0/1]求出
  }
  g[u].g[0][1] = g[u].g[0][0];
}

int main() {
  scanf("%d%d", &n, &m);
  for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
  for (int i = 1; i <= n - 1; i++) {
    int u, v;
    scanf("%d%d", &u, &v);
    add(u, v);
    add(v, u);
  }
  dis[1] = 1;
  DFS1(1);
  DFS2(1, 1);
  Build(1, 1, n);
  for (int i = 1; i <= m; i++) {
    int x, val;
    scanf("%d%d", &x, &val);
    Update(x, val);
    matrix ans = Query(1, 1, n, 1, End[1]);  // 查询1所在重链的矩阵乘
    printf("%d\n", max(ans.g[0][0], ans.g[1][0]));
  }
  return 0;
}

習題