动态dp学习笔记

时间:2022-01-31 13:59:47

我们经常会遇到一些问题,是一些dp的模型,但是加上了什么待修改强制在线之类的,十分毒瘤,如果能有一个模式化的东西解决这类问题就会非常好。

给定一棵n个点的树,点带点权。

有m次操作,每次操作给定x,y,表示修改点x的权值为y。

你需要在每次操作之后求出这棵树的最大权独立集的权值大小。

如果不带修改,那就是一个最简单是树形dp问题。

我们设一个dp[i][0],dp[i][1]表示以i为根的子树

动态dp能够使用的一个前提就是它的转移是线性的,这样我们就可以用矩阵乘法实现快速转移了。

注意:这里的矩阵乘法是广义的,中间运算不一定是乘法,最后也不一定是求和,只要能满足矩阵乘法的性质就可以了。

重链剖分

这也是动态dp比较关键的内容,因为问题在树上,树的每个节点都可能有多个儿子节点,直接算贡献比较麻烦。

所以用重链剖分只保留一个儿子,其他的儿子放在一起统一计算,这样我们就把一个树上问题转化成了序列上的问题。

比如这道题,我们把树轻重链划分完后。

我们把轻子树的答案算完后直接加入状态中,然后答案就变成了一条重链的矩阵连乘积,用线段树维护矩阵的乘积即可。

每次修改时,根据重链剖分,答案包含这个点的位置最多有log个,所以每次就对这些位置修改就好了 。

代码

#include<iostream>
#include<cstdio>
#include<cstring>
#define N 100002
using namespace std;
typedef long long ll;
int tot,head[N],size[N],deep[N],fa[N],son[N],top[N],dp[N][],dfn[N],tag[N],ed[N],a[N],cntt,ls[N<<],rs[N<<],n,m,root;
inline ll rd(){
ll x=;char c=getchar();bool f=;
while(!isdigit(c)){if(c=='-')f=;c=getchar();}
while(isdigit(c)){x=(x<<)+(x<<)+(c^);c=getchar();}
return f?-x:x;
}
struct edge{int n,to;}e[N<<];
inline void add(int u,int v){e[++tot].n=head[u];e[tot].to=v;head[u]=tot;}
struct matrix{
int a[][];
matrix(){memset(a,-0x3f,sizeof(a));}
matrix operator *(const matrix &b)const{
matrix c;
for(int i=;i<;++i)
for(int j=;j<;++j)
for(int k=;k<;++k)
c.a[i][j]=max(c.a[i][j],a[i][k]+b.a[k][j]);
return c;
}
}data[N],tr[N<<];
void dfs1(int u){
size[u]=;
for(int i=head[u];i;i=e[i].n)if(e[i].to!=fa[u]){
int v=e[i].to;deep[v]=deep[u]+;fa[v]=u;
dfs1(v);
size[u]+=size[v];
if(size[v]>size[son[u]])son[u]=v;
}
}
void dfs2(int u){
dfn[u]=++dfn[];tag[dfn[]]=u;
if(!top[u])top[u]=u;
ed[top[u]]=max(ed[top[u]],dfn[u]);
data[u].a[][]=data[u].a[][]=;
data[u].a[][]=a[u];
dp[u][]=a[u];
if(son[u]){
top[son[u]]=top[u],dfs2(son[u]);
dp[u][]+=max(dp[son[u]][],dp[son[u]][]);
dp[u][]+=dp[son[u]][];
}
for(int i=head[u];i;i=e[i].n)if(e[i].to!=fa[u]&&e[i].to!=son[u]){
int v=e[i].to;dfs2(v);
dp[u][]+=max(dp[v][],dp[v][]);
dp[u][]+=dp[v][];
data[u].a[][]+=max(dp[v][],dp[v][]);
data[u].a[][]+=max(dp[v][],dp[v][]);
data[u].a[][]+=dp[v][];
}
}
void build(int &cnt,int l,int r){
if(!cnt)cnt=++cntt;
if(l==r){tr[cnt]=data[tag[l]];return;}
int mid=(l+r)>>;
build(ls[cnt],l,mid);build(rs[cnt],mid+,r);
tr[cnt]=tr[ls[cnt]]*tr[rs[cnt]];
}
void upd(int cnt,int l,int r,int x){
if(l==r){tr[cnt]=data[tag[x]];return;}
int mid=(l+r)>>;
if(mid>=x)upd(ls[cnt],l,mid,x);
else upd(rs[cnt],mid+,r,x);
tr[cnt]=tr[ls[cnt]]*tr[rs[cnt]];
}
matrix query(int cnt,int l,int r,int L,int R){
if(l>=L&&r<=R)return tr[cnt];
int mid=(l+r)>>;
if(mid>=L&&mid<R)return query(ls[cnt],l,mid,L,R)*query(rs[cnt],mid+,r,L,R);
else if(mid>=L)return query(ls[cnt],l,mid,L,R);
else return query(rs[cnt],mid+,r,L,R);
}
void _upd(int u,int vall){
data[u].a[][]+=vall-a[u];
a[u]=vall;
matrix now,pre;
while(u){
pre=query(,,n,dfn[top[u]],ed[top[u]]);
upd(,,n,dfn[u]);
now=query(,,n,dfn[top[u]],ed[top[u]]);
u=fa[top[u]];
data[u].a[][]+=max(now.a[][],now.a[][])-max(pre.a[][],pre.a[][]);
data[u].a[][]=data[u].a[][];
data[u].a[][]+=now.a[][]-pre.a[][];
}
}
int main(){
n=rd();m=rd();
for(int i=;i<=n;++i)a[i]=rd();
int u,v;
for(int i=;i<n;++i){
u=rd();v=rd();
add(u,v);add(v,u);
}
dfs1();dfs2();
build(root,,n);
while(m--){
u=rd();v=rd();
_upd(u,v);
matrix nowans=query(,,n,dfn[],ed[]);
printf("%d\n",max(nowans.a[][],max(nowans.a[][],nowans.a[][])));
}
return ;
}