树形DP 学习笔记

时间:2022-01-31 13:59:11

树形DP学习笔记

ps: 本文内容与蓝书一致

树的重心

  • 概念: 一颗树中的一个节点其最大子树的节点树最小
  • 解法:对与每个节点求他儿子的\(size\) ,上方子树的节点个数为\(n-size_u\) ,求对于每个节点子树的最大值,找出最小的那个就好了;

(我觉得就不需要code了)


树的直径

  • 概念:一颗带权树的最长路径
  • 解法:维护一个节点到叶子节点的最大距离\(d1[i]\)和次大距离\(d2[i]\) ,最大距离就是$max {d1[i]+d2[i] } $

code

#include<iostream>
#include<cstdio>
using namespace std;
const int N=1e4+5;
int n;
struct pp
{
int to,next;
}w[2*N];
int head[N],cnt;
int d1[N],d2[N];
int ans;
void add(int x,int y)
{
cnt++;
w[cnt].next=head[x];
w[cnt].to=y;
head[x]=cnt;
}
void dfs(int x,int fa)
{
for(int i=head[x];i;i=w[i].next)
{
int t=w[i].to;
if(t!=fa)
{
dfs(t,x);
if(d1[t]+1>d1[x])
{
d2[x]=d1[x];
d1[x]=d1[t]+1;
}
else if(d1[t]+1>d2[x]) d2[x]=d1[t]+1;
}
}
return ;
}
void find_ans(int x,int fa)
{
ans=max(ans,d1[x]+d2[x]);
for(int i=head[x];i;i=w[i].next)
{
int t=w[i].to;
if(t!=fa) find_ans(t,x);
}
return;
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("diam.in","r",stdin);
freopen("diam.out","w",stdout);
#endif
scanf("%d",&n);
for(int i=1;i<n;i++)
{
int x,y;
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
dfs(1,0);
find_ans(1,0);
printf("%d",ans);
return 0;
}

例题

P4480 逃学的小孩

  • 大概思路:求出树的直径以及其左右端点,再设\(d[i]\)为树上节点\(i\)到左右端点距离更小的那个,然后求出\(max \{d[i]\}\),然后以这个值加上直径就是\(ans\) ;

code

#include<iostream>
#include<cstdio>
#include<cstring>
#define ll long long
using namespace std;
const int N=2e5+5;
struct pp
{
int next,to;
ll qu;
}w[N*2];
int head[N],cnt;
int n,m;
bool v[N];
ll d1[N],d2[N],dl[N],dr[N];
int f1[N],f2[N];
int r,l;
ll ans,mans;
void add(int x,int y,int z)
{
w[++cnt].next=head[x];
w[cnt].qu=z;
w[cnt].to=y;
head[x]=cnt;
}
int read()
{
int f=1;
char ch;
while((ch=getchar())<'0'||ch>'9') if(ch=='-') f=-1;
int res=ch-'0';
while((ch=getchar())>='0'&&ch<='9') res=res*10+ch-'0';
return res*f;
}
void dfs1(int x)
{
if(v[x]) return ;
v[x]=1;
for(int i=head[x];i;i=w[i].next)
{
int t=w[i].to;
if(!v[t])
{
dfs1(t);
if(d1[t]+w[i].qu>d1[x])
{
f2[x]=f1[x];
f1[x]=f1[t];
d2[x]=d1[x];
d1[x]=d1[t]+w[i].qu;
}
else if(d1[t]+w[i].qu>d2[x]) d2[x]=d1[t]+w[i].qu,f2[x]=f1[t];
} }
return;
}
void find_ans(int x)
{
if(v[x]) return;
v[x]=1;
if(ans<d1[x]+d2[x])
{
ans=d1[x]+d2[x];
l=f1[x];
r=f2[x];
}
for(int i=head[x];i;i=w[i].next) find_ans(w[i].to);
}
void dfs2(int x)
{
if(v[x]) return;
v[x]=1;
for(int i=head[x];i;i=w[i].next)
{
int t=w[i].to;
if(!v[t])
{
dl[t]=dl[x]+w[i].qu;
dfs2(t);
}
}
return;
}
void dfs3(int x)
{
if(v[x])return;
v[x]=1; for(int i=head[x];i;i=w[i].next)
{
int t=w[i].to;
if(!v[t])
{
dr[t]=dr[x]+w[i].qu;
dfs3(t);
}
}
return;
}
void dfs_ans(int x)
{
if(v[x]) return;
v[x]=1;
mans=max(mans,min(dl[x],dr[x]));
for(int i=head[x];i;i=w[i].next) dfs_ans(w[i].to);
return;
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("Chris.in","r",stdin);
freopen("Chris.out","w",stdout);
#endif
n=read();
m=read();
for(int i=1;i<=m;i++)
{
int x,y,z;
x=read(),y=read(),z=read();
add(x,y,z);
add(y,x,z);
}
for(int i=1;i<=n;i++) f1[i]=i;
dfs1(1);
memset(v,0,sizeof(v));
find_ans(1);
memset(v,0,sizeof(v));
dfs2(l);
memset(v,0,sizeof(v));
dfs3(r);
memset(v,0,sizeof(v));
dfs_ans(1);
printf("%lld",ans+mans);
return 0;
}

树的中心

  • 概念:给出一颗带权树,求一个节点,使得此节点到树中其他节点的最远距离最小;

  • 解法:如果是一颗没有负边权的树,那直接找到直径的中点就好;

    但是这里我们考虑有负边权的情况:

    有两种情况:

    1. 从\(u\)点向上的最长路径,设为\(up[u]\);
    2. 从\(u\)点向下,即\(u\)到叶节点的最远距离,设为\(d1[u]\)(次远设为\(d2[u]\));

    \(d1[u]\)和\(d2[u]\)都会求,问题是\(up[u]\)该怎么求?

    还是分类讨论,设\(u\)的父亲为\(x\),\(d1[x]\)来自于子节点\(v\);那对于\(u\):

    1. 如果\(u!=v\),那么\(up[u]=max\{d1[x],up[x]\}+dis[x][t]\);
    2. 如果\(u==v\),那么\(up[u]=max\{d2[x],up[x]\}+dis[x][t]\),这也是为什么要维护\(d2[x]\)的原因;

code

#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
const int N=1e5+5;
struct pp
{
int next,to;
}w[2*N];
int n,k;
int head[N],cnt;
int d1[N],d2[N],pre[N],u[N];
int root,far;
int read()
{
int f=1;
char ch;
while((ch=getchar())<'0'||ch>'9') if(ch=='-') f=-1;
int res=ch-'0';
while((ch=getchar())>='0'&&ch<='9') res=res*10+ch-'0';
return res*f;
}
void add(int x,int y)
{
cnt++;
w[cnt].next=head[x];
w[cnt].to=y;
head[x]=cnt;
}
bool cmp(int x,int y) {return x>y;}
void dfs1(int x,int fa)
{
for(int i=head[x];i;i=w[i].next)
{
int t=w[i].to;
if(t!=fa)
{
dfs1(t,x);
if(d1[t]+1>d1[x])
{
pre[x]=t;
d2[x]=d1[x];
d1[x]=d1[t]+1;
}
else if(d1[t]+1>d2[x]) d2[x]=d1[t]+1;
}
}
return;
}
void dfs2(int x,int fa)
{
int minx=min(u[x],d1[x]);
if(far<minx)
{
root=x;
far=minx;
}
for(int i=head[x];i;i=w[i].next)
{
int t=w[i].to;
if (t!=fa)
{
if(pre[x]!=t) u[t]=max(d1[x],u[x])+1;
else u[t]=max(d2[x],u[x])+1;
dfs2(t,x);
}
}
return ;
}
int main()
{
n=read(),k=read();
for(int i=1;i<n;i++)
{
int x,y;
x=read(),y=read();
add(x,y);
add(y,x);
}
dfs1(1,0);
dfs2(1,0);
printf("%d",root);
return 0;
}

例题

P5536核心城市

  • 思路:显然其中一定会有一个城市为这颗树的中心;那找出这个中心,把这颗无根树变为以它为根的有根树;再求出除根节点以外的每个节点所能到达的最大深度\(deepfar[i]\),这就是这个节点最远所能到达的距离;然后\(sort\)一下\(deepfar[]\),答案就是\(deepfar[k+1]+1\);

code

#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
const int N=1e5+5;
struct pp
{
int next,to;
}w[2*N];
int n,k;
int head[N],cnt;
int d1[N],d2[N],pre[N],u[N];
int fardeep[N];
int root,far;
int read()
{
int f=1;
char ch;
while((ch=getchar())<'0'||ch>'9') if(ch=='-') f=-1;
int res=ch-'0';
while((ch=getchar())>='0'&&ch<='9') res=res*10+ch-'0';
return res*f;
}
void add(int x,int y)
{
cnt++;
w[cnt].next=head[x];
w[cnt].to=y;
head[x]=cnt;
}
bool cmp(int x,int y) {return x>y;}
void dfs1(int x,int fa)
{
for(int i=head[x];i;i=w[i].next)
{
int t=w[i].to;
if(t!=fa)
{
dfs1(t,x);
if(d1[t]+1>d1[x])
{
pre[x]=t;
d2[x]=d1[x];
d1[x]=d1[t]+1;
}
else if(d1[t]+1>d2[x]) d2[x]=d1[t]+1;
}
}
return;
}
void dfs2(int x,int fa)
{
for(int i=head[x];i;i=w[i].next)
{
int t=w[i].to;
if (t!=fa)
{
if(pre[x]!=t) u[t]=max(d1[x],u[x])+1;
else u[t]=max(d2[x],u[x])+1;
dfs2(t,x);
}
}
return ;
}
void dfs3(int x,int fa)
{
int minx=min(u[x],d1[x]);
if(far<minx)
{
root=x;
far=minx;
}
for(int i=head[x];i;i=w[i].next) if(w[i].to!=fa) dfs3(w[i].to,x);
return;
}
void dfs4(int x,int fa)
{
for(int i=head[x];i;i=w[i].next)
{
int t=w[i].to;
if(t!=fa)
{
dfs4(w[i].to,x);
fardeep[x]=max(fardeep[x],fardeep[t]+1);
}
}
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("XR-3.in","r",stdin);
freopen("XR-3.out","w",stdout);
#endif
n=read(),k=read();
for(int i=1;i<n;i++)
{
int x,y;
x=read(),y=read();
add(x,y);
add(y,x);
}
dfs1(1,0);
dfs2(1,0);
dfs3(1,0);
dfs4(root,0);
sort(fardeep+1,fardeep+1+n,cmp);
printf("%d",fardeep[k+1]+1);
return 0;
}

上面都是有关树的一些经典题型,下面才是今天的主角——树型DP


背包类树型DP

(我觉得把,其实左右子树类树型DP可以归为这一类)

例题

选课

书上的是时间复杂度为\(n^3\)的算法,这里介绍一个优化,可以讲其降为\(n^2\);

没有优化前,DP方程为:

\[ dp[u][j]=max\{dp[u][j],dp[u][j-k-1]+dp[v][k]\}+v[u]
\]

这样对于每个节点都要\(n^2\)暴力枚举\(j\)和\(k\);

经过优化,我们的DP方程就变为了:

\[ \begin{cases}
dp[v][j]=dp[u][j]+v[v]\\
dp[u][j]=max\{dp[u][j],dp[v][j-1]\}
\end{cases}
\]

这也是再泛化物品优化下,树型背包的基本DP方程;这样我们只需要\(O(n)\)枚举\(j\)就好了;

code

#include<iostream>
#include<algorithm>
#include<queue>
#include<cstdio>
#include<cstring>
using namespace std; int n,m;
struct edge
{
int next,to;
}e[1000];
int rt,head[1000],tot,val[1000],dp[1000][1000];
void add(int x,int y)
{
e[++tot].next=head[x];
head[x]=tot;
e[tot].to=y;
}
void dfs(int u,int t)
{
if (t<=0) return ;
for (int i=head[u]; i; i=e[i].next)
{
int v = e[i].to;
for (int j=0; j<t; ++j) //这里j从o开始到tot-1,因为v的子树可以选择的节点是u的子树的节点数减一;
dp[v][j] = dp[u][j]+val[v];
dfs(v,t-1);
for (int j=1; j<=t; ++j)
dp[u][j] = max(dp[u][j],dp[v][j-1]);//u必须选,所以u选择j个点v只能选择j-1个点;
}
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
{
int a;
scanf("%d%d",&a,&val[i]);
if(a)
add(a,i);
if(!a)add(0,i);
}
dfs(0,m);
printf("%d",dp[0][m]);
}

选择类树型DP

基本DP方程:

\[v\in{son(u)}
\begin{cases}
f[u][0]=\sum f[v][1] \\
f[u][1]=min\{f[v][1],f[v][0]\}+1
\end{cases}
\]

例题

P2016战略游戏

直接套DP方程就好了;

code

#include<iostream>
#include<cstdio>
using namespace std;
int n;
int dp[1605][2];
struct pp
{
int next,to;
}w[1600<<1];
int head[1600],cnt;
void add(int x,int y)
{
cnt++;
w[cnt].to=y;
w[cnt].next=head[x];
head[x]=cnt;
}
void dfs(int x,int fa)
{
dp[x][1]=1;
for(int i=head[x];i;i=w[i].next)
{
int t=w[i].to;
if(t==fa) continue;
dfs(t,x);
dp[x][0]+=dp[t][1];
dp[x][1]+=min(dp[t][0],dp[t][1]);
}
return;
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++)
{
int a,k;
scanf("%d%d",&a,&k);
for(int i=1;i<=k;i++)
{
int b;
scanf("%d",&b);
add(a,b);
add(b,a);
}
}
dfs(0,0);
printf("%d",min(dp[0][1],dp[0][0]));
return 0;
}

普通树型DP

这种树型DP更加灵活,就不像前两种有基本固定的DP方程,所以还是直接来几道例题;(滑稽

例题

LOJ #10157. 皇宫看守

乍一看题,啊哈,模板选择树型DP,开开心心打个代码,恭喜你0分;

仔细一看这道题其实不是什么没有上司的舞会,而是一道覆盖DP题,区别在哪呢?

这道题一条边两端至少要有一个点,可以有两个,而没有上司我舞会是一条边两端至多有一个点,可以没有;

那好,这样的话一个节点u的最少经费就不能像选择DP一样单纯的由儿子选不选的而转移过来,因为他们本来互不冲突,而是必须被覆盖到(这里每个节点的覆盖半径为1),这样对于一个节点u的最少经费就可以由覆盖它的节点转移过来,这样的话就需要考虑三种情况:

首先设\(dp[u][0]\)表示被节点\(u\)被父亲覆盖且\(u\)不选,\(dp[u][1]\)表示被自己的子节点覆盖且\(u\)不选,\(dp[u][2]\)表示被自己覆盖;

所以有状态转移方程:

  • 对于\(dp[u][0]\),因为\(u\)不选,所以对于\(u\)的子节点\(v\),要么被\(son(v)\)所覆盖,要么被\(v\)自己覆盖:
\[dp[u][0]=\sum min\{dp[v][1],dp[v][2]\} +a[f[u]];
\]
  • 对于\(dp[u][1]\),要保证\(u\)必须被一个子节点所覆盖到,还要保证\(u\)的子节点\(v\)在不被父亲覆盖的前提下被覆盖到,那显然\(dp[u][1]\),是由\(dp[v][1]\)和\(dp[v][2]\)转移过来的,但是如何保证\(dp[u][1]\)的转移中一定包含\(dp[v][2]\)呢?

    这时候有个巧妙的办法,设个参数:

    \[d=min\{d,dp[v][2]-min\{dp[v][1],dp[v][2]\}\}
    \]

    \(d\)的初始值为\(0x7fffffff\);

    这样对于\(dp[u][1]\)就有状态转移方程:

    \[dp[u][1]=\sum min\{dp[v][1],dp[v][2]\}+d
    \]
  • 对于\(dp[u][2]\),那很显然它可以由子节点任意三种状态转移过来,但是对于\(dp[v][0]\),它已经加过一遍\(a[u]\),而对于\(dp[u][2]\),只能且必须加一遍\(a[u]\),那怎么办呢?单独特判由\(dp[v][0]\)转移过来的情况,控制\(a[u]\)只加一遍?显然是可以的,但是太麻烦了,那么另外考虑,这里可以看到\(dp[v][0]\)只会往\(dp[u][2]\)上转移,那么可以根据\(dp[u][2]\)需求对\(dp[v][0]\)状态转移方程改一改:

    \[dp[u][0]=\sum min\{dp[v][1],dp[v][2]\}
    \]

    (这里的\(u\)是对于\(v\)来说的)

    感性理解一下就是如果\(dp[u][2]\)不由\(dp[v][0]\)转移过来那要\(dp[v][0]\)也没有什么用,那由\(dp[v][0]\)转移过来,那在\(dp[u][2]\)这加一遍\(a[u]\)就够了,因为\(dp[u][2]\)已经保证了\(u\)被选,所以不需要\(dp[v][0]\)再保证一遍;

    这样对于\(dp[u][2]\),就有状态转移方程:

    \[dp[u][2]=\sum min\{dp[v][1],dp[v][2],dp[v][0]\} +a[u]
    \]

总结下来就有三个状态转移方程:

\[\begin{cases}
dp[u][0]=\sum min\{dp[v][1],dp[v][2]\};\\

dp[u][1]=\sum min\{dp[v][1],dp[v][2]\}+d ;(d=min\{d,dp[v][2]-min\{dp[v][1],dp[v][2]\}\})\\

dp[u][2]=\sum min\{dp[v][1],dp[v][2],dp[v][0]\} +a[u]
\end{cases}
\]

(所以,显然书上的状态转移方程是错的)

不难发现,修改后的\(dp[v][0]\)一定小于等于\(dp[v][1]\);所以写代码的时候我顺手把\(dp[u][2]\)的转移方程改成了:

\[dp[u][2]=\sum min\{dp[v][2],dp[v][0]\} +a[u]
\]

虽然题目早已经解决了,但我还是想再深究一下;这个方程啥意思?

以我的感性理解就是\(v\)既然已经一定会被它爹\(u\)覆盖到,那就可以不需要保证\(v\)一定被它的儿子所覆盖,修改后的\(dp[v][0]\)刚好就是这种情况;

(好了,bb了这么多废话,就一点有用的东西,直接上代码吧)

code

#include <iostream>
#include <cstdio>
using namespace std;
const int N = 1500 + 5;
int dp[N][3];
int v[N], n, root;
struct pp {
int next, to;
} w[N];
int head[N], cnt, du[N];
void add(int x, int y) {
cnt++;
w[cnt].next = head[x];
w[cnt].to = y;
head[x] = cnt;
}
void dfs(int x) {
int d = 0x7fffffff;
for (int i = head[x]; i; i = w[i].next) {
int t = w[i].to;
dfs(t);
dp[x][0] += min(dp[t][1], dp[t][2]);
dp[x][1] += min(dp[t][1], dp[t][2]);
d = min(d, dp[t][2] - min(dp[t][1], dp[t][2]));
dp[x][2] += min(dp[t][2], dp[t][0]);
}
dp[x][1] += d;
dp[x][2] += v[x];
}
int main() {
#ifndef ONLINE_JUDGE
freopen("guard.in", "r", stdin);
freopen("guard.out", "w", stdout);
#endif
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
int x, m;
scanf("%d", &x);
scanf("%d", &v[x]);
scanf("%d", &m);
for (int j = 1; j <= m; j++) {
int y;
scanf("%d", &y);
add(x, y);
du[y]++;
}
}
for (int i = 1; i <= n; i++)
if (!du[i])
root = i;
dfs(root);
printf("%d", min(dp[root][1], dp[root][2]));
return 0;
}

好了,差不多就结束了,虽然写这个一点耗时,但对于我这个蒟蒻来说加深了对于DP的理解,收获也不小,也不算浪费时间了吧(逃);