树形 DP

简介 #

树形 DP 以树形结构为研究对象. 通常设 $f[u]$ 为树中 $u$ 号节点的值,利用树形关系推出其它节点的值. DP 过程多为 记忆化搜索.

例 1 #

给定一棵 $n$ 个点,$m$ 条边的树,顶点编号为 $1\sim n$,且以 $1$ 号节点为根. 以 $i$ 号节点为根的子树有几个节点?

$f[i]$:以 $i$ 号节点为根的子树的节点数.

$Son[i]$:$i$ 号节点的子节点集合.

$$f[i]=1+\sum_{v\in Son[i]}f[v]$$

计算顺序为 $f[$子节点$]→f[$父节点$]$. 使用记忆化搜索.

vector<int> son[]; // son[u] : 节点 u 的子节点集合

void dfs(int u) { // 求以 u 为根的子树中节点个数
    f[u] = 1;
    for (int i = 0; i < son[u].size(); i ++) {
        int v = son[u][i]; // 节点 u 的第 i 个子节点
        dfs(v);
        f[u] += f[v];
    }
}

例 2 #

公司有 $n$ 个人,编号为 $1\cdots n$,其中 $1$ 号员工是 boss. 现要举⾏⼀场晚会,如果邀请了某个⼈,那么他的上司不会来(他上司的上司,上司的上司的上司 $\cdots$ 都可以来).

每个⼈都有⼀个欢乐值,给出公司所有人的上下级关系,求⼀个邀请⽅案,使欢乐值的和最⼤.

$f[i,j]$:从员工 $i$ 和所有下属中邀请部分人参会的最大欢乐值. 当 $j=0$ 时 $i$ 号员工不参会,$j=1$ 时参会.

$Son[i]$:员工 $i$ 的下属集合.

  • 若 $i$ 号员工参会,他的直接下属都不来:

$$f[i,1]=H_i+\sum_{v\in Son(i)}f[v,0]$$

  • 若 $i$ 号员工参会,他的直接下属爱来不来,于是取最大值:

$$f[i,0]=\sum_{v\in Son(i)}\max\{f[v,0],f[v,1]\}$$

时间复杂度为 $O(n)$,最终答案为 $\max\{f[1,0],f[1,1]\}$.

vector<int> son[]; // son[u] : 员工 u 的下属集合

void dfs(int u) { // 求出 u 号员工对应的 f[u][0] 和 f[u][1]
    f[u][1] = h[u];
    for (int i = 0; i < son[u].size(); i ++) {
        int v = son[u][i]; // 员工 u 的第 i 个下属
        dfs(v);
        f[u][1] += f[v][0];
        f[u][0] += max(f[v][0], f[v][1]);
    }
}

树形 DP + 背包 DP #

处理某些问题时,需要结合树形 DP 和背包 DP 的思想.

现有 $n$ 门课程,第 $i$ 门课程的学分为 $s_i$,每门课程有 $0$ 或 $1$ 门先修课.有先修课的课程需要先学完先修课,才能学习该课程.求学习 $m$ 门课程能获得的最多学分.

将每门课程看作树中的节点,$a→b$ 代表 $a$ 比 $b$ 先修:

flowchart
2-->1
2-->4
2-->7
    7-->5
    7-->6
3

为了方便解决问题,新增 $0$ 号节点,使其指向所有无先修课的课程:

flowchart
0-->2
0-->3
    2-->1
    2-->4
    2-->7
        7-->5
        7-->6
    3

$f[u,j]$ 表示以 $u$ 为根节点,选 $j$ 个节点,获得的最大学分.

$dfs(u)$ 的功能是算出以 $u$ 为根节点时,分别选 $0\sim m$ 个节点时能获得的最大学分.

执行 $dfs(u)$ 时,枚举 $u$ 的子节点 $v$,在内层循环枚举选取的节点数 $j=m→1$:

  • 将 $j$ 个节点分成两组,一组 $k$ 个,另一组 $j-k$ 个;

  • 将第一组 $k$ 个节点放在以 $v$ 为根节点的子树中,最大学分为 $f[v][k]$;

  • 将第二组 $j-k$ 个节点放在以 $u$ 为根节点的子树中,但不放在以 $v$ 为根节点的子树中,最大学分为 $f[u][j-k]$.

$$f[u,j]=\max_{v\in Son(u)}{f[u,j],f[u,j-k]+f[v,k]}$$

初始条件 $f[i,1]=s[i]$ 边界条件 $f[0,m]$
时间复杂度 $O(n^3)$

在主程序中执行 $dfs(0)$ 后输出 $f[0,m]$ 即可.

void dfs(int u) {
    f[u][1] = s[u];
    for(int i = 0; i < son[u].size(); i ++) {
        int v = son[u][i];
        dfs(v);
        for(int j = m; j >= 1; j --)
            for(int k = j - 1; k > 0; k --)
                f[u][j] = max(f[u][j], f[v][k] + f[u][j - k]);
    }
}