noip2018 d2t3 保卫王国 解题报告

时间:2022-06-16 19:47:02

保卫王国

电脑卡懒得把题面挪过来了。


朴素

\[dp_{i,0}=\sum dp_{s,1}\\
dp_{i,1}=\sum \min(dp_{s,0},dp_{s,1})+p_i
\]

然后直接动态dp就行了

我发现lct是最好写的,反正比树剖好写,还比她快

没倍增快,但是看起来倍增挺难写的...


Code:

#include <cstdio>
#include <algorithm>
#define ll long long
const ll inf=1ll<<45;
const int N=1e5+10;
int ch[N][2],par[N];
struct matrix{ll a,b,c,d;}dat[N],sum[N],ret;
int head[N],to[N<<1],Next[N<<1],cnt;
ll dp[N][2],p[N];
int n,m;char str[233];
using std::min;
void add(int u,int v)
{
to[++cnt]=v,Next[cnt]=head[u],head[u]=cnt;
}
void dfs(int now)
{
dp[now][1]=p[now];
for(int v,i=head[now];i;i=Next[i])
if((v=to[i])!=par[now])
{
par[v]=now,dfs(v);
dp[now][0]+=dp[v][1];
dp[now][1]+=min(dp[v][0],dp[v][1]);
}
dat[now]=sum[now]=(matrix){inf,dp[now][0],dp[now][1],dp[now][1]};
}
matrix operator ^(matrix a,matrix b)
{
ret.a=min(a.a+b.a,a.b+b.c);
ret.b=min(a.a+b.b,a.b+b.d);
ret.c=min(a.c+b.a,a.d+b.c);
ret.d=min(a.c+b.b,a.d+b.d);
return ret;
}
#define ls ch[now][0]
#define rs ch[now][1]
#define fa par[now]
void updata(int now)
{
sum[now]=ls?sum[ls]^dat[now]:dat[now];
sum[now]=rs?sum[now]^sum[rs]:sum[now];
}
bool isroot(int now){return ch[fa][0]==now||ch[fa][1]==now;}
int identity(int now){return ch[fa][1]==now;}
void connect(int f,int now,int typ){ch[fa=f][typ]=now;}
void Rotate(int now)
{
int p=fa,typ=identity(now);
connect(p,ch[now][typ^1],typ);
if(isroot(p)) connect(par[p],now,identity(p));
else fa=par[p];
connect(now,p,typ^1);
updata(p),updata(now);
}
void splay(int now)
{
for(;isroot(now);Rotate(now))
if(isroot(fa))
Rotate(identity(now)^identity(fa)?now:fa);
}
void access(int now)
{
for(int las=0;now;las=now,now=fa)
{
splay(now);
if(las)
{
dat[now].b-=sum[las].d;
dat[now].c-=min(sum[las].b,sum[las].d);
}
if(rs)
{
dat[now].b+=sum[rs].d;
dat[now].c+=min(sum[rs].b,sum[rs].d);
}
dat[now].d=dat[now].c;
rs=las;
updata(now);
}
}
void modify(int now,ll w)
{
access(now),splay(now);
dat[now].c+=w-p[now],p[now]=w;
dat[now].d=dat[now].c;
updata(now);
}
int main()
{
scanf("%d%d%s",&n,&m,str);
for(int i=1;i<=n;i++) scanf("%lld",p+i);
for(int u,v,i=1;i<n;i++) scanf("%d%d",&u,&v),add(u,v),add(v,u);
dfs(1);
for(int a,x,b,y,i=1;i<=m;i++)
{
scanf("%d%d%d%d",&a,&x,&b,&y);
ll ans=0,t1=p[a],t2=p[b];
if(x&&y)
{
modify(a,-inf),modify(b,-inf);
ans=(inf<<1)+t1+t2;
}
else if(x&&!y)
{
modify(a,-inf),modify(b,inf<<1);
ans=inf+t1;
}
else if(!x&&y)
{
modify(a,inf<<1),modify(b,-inf);
ans=inf+t2;
}
else
modify(a,inf<<1),modify(b,inf<<1);
ans+=min(sum[b].b,sum[b].d);
if(ans>=inf) puts("-1");
else printf("%lld\n",ans);
modify(a,t1),modify(b,t2);
}
return 0;
}

2019.1.4