洛谷 4178 Tree——点分治

时间:2024-01-14 16:13:32

题目:https://www.luogu.org/problemnew/show/P4178

点分治。如果把每次的 dis 和 K-dis 都离散化,用树状数组找,是O(n*logn*logn),会T7个点。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const int N=4e4+;
int n,hd[N],xnt,to[N<<],nxt[N<<],w[N<<],f[N<<],siz[N],ans,mn,rt;
ll dis[N],tis[N],tp[N<<],tnt,K;
bool vis[N],sj[N];
void add(int x,int y,ll z)
{
to[++xnt]=y;nxt[xnt]=hd[x];w[xnt]=z;hd[x]=xnt;
to[++xnt]=x;nxt[xnt]=hd[y];w[xnt]=z;hd[y]=xnt;
}
void getrt(int cr,int fa,int s)
{
siz[cr]=;int mx=;
for(int i=hd[cr],v;i;i=nxt[i]) if(!vis[v=to[i]]&&v!=fa)
{
getrt(v,cr,s);siz[cr]+=siz[v];mx=max(mx,siz[v]);
}
mx=max(mx,s-siz[cr]);
if(mx<mn)mn=mx,rt=cr;
}
void add(int x){for(;x<=tnt;x+=(x&-x))f[x]++;}
int query(int x){int ret=;for(;x;x-=(x&-x))ret+=f[x];return ret;}
void dfs(int cr,int fa,ll lj)
{
dis[cr]=lj;sj[cr]=;
for(int i=hd[cr],v;i;i=nxt[i]) if(!vis[v=to[i]]&&v!=fa)
dfs(v,cr,lj+w[i]);
}
int calc(int cr,ll w)
{
memset(sj,,sizeof sj);tnt=;dfs(cr,,w);
for(int i=;i<=n;i++) if(sj[i]&&dis[i]<=K)
{
tis[i]=K-dis[i];tp[++tnt]=dis[i];tp[++tnt]=tis[i];
// printf("dis[%d]=%lld tis[%d]=%lld\n",i,dis[i],i,tis[i]);
}
sort(tp+,tp+tnt+);tnt=unique(tp+,tp+tnt+)-tp-;
int ret=;
for(int i=;i<=n;i++) if(sj[i]&&dis[i]<=K)
{
dis[i]=lower_bound(tp+,tp+tnt+,dis[i])-tp;
tis[i]=lower_bound(tp+,tp+tnt+,tis[i])-tp;
// printf("dis[%d]=%lld tis[%d]=%lld\n",i,dis[i],i,tis[i]);
ret+=query(tis[i]);add(dis[i]);
}
memset(f,,sizeof f);
return ret;
}
void solve(int cr,int s)
{
// printf("rt=%d\n",cr);
vis[cr]=;
ans+=calc(cr,);
// printf("cr=%d ans=%d\n",cr,ans);
for(int i=hd[cr],v;i;i=nxt[i]) if(!vis[v=to[i]])
{
ans-=calc(v,w[i]);
int ts=(siz[cr]>siz[v]?siz[v]:s-siz[cr]);//-siz[cr]!!!
mn=N;getrt(v,,ts);solve(rt,ts);
}
}
int main()
{
scanf("%d",&n);int x,y;ll z;
for(int i=;i<n;i++)
{
scanf("%d%d%lld",&x,&y,&z);add(x,y,z);
}
scanf("%lld",&K);
mn=N;getrt(,,n);solve(rt,n);
printf("%d\n",ans);
return ;
}

应当排序后枚举两个指针。(代码中两种方法时间一样)

如果把 ts=s-siz[cr] 写成 ts=s-siz[v] ,就会T7个点(?)!!!

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const int N=4e4+;
int n,hd[N],xnt,to[N<<],nxt[N<<],w[N<<],siz[N],mn,rt,sta[N],top,K,ans;
bool vis[N];
void add(int x,int y,int z)
{
to[++xnt]=y;nxt[xnt]=hd[x];w[xnt]=z;hd[x]=xnt;
to[++xnt]=x;nxt[xnt]=hd[y];w[xnt]=z;hd[y]=xnt;
}
void getrt(int cr,int fa,int s)
{
siz[cr]=;int mx=;
for(int i=hd[cr],v;i;i=nxt[i]) if(!vis[v=to[i]]&&v!=fa)
{
getrt(v,cr,s);siz[cr]+=siz[v];mx=max(mx,siz[v]);
}
mx=max(mx,s-siz[cr]);
if(mx<mn)mn=mx,rt=cr;
}
void dfs(int cr,int fa,int lj)
{
sta[++top]=lj;
for(int i=hd[cr],v;i;i=nxt[i]) if(!vis[v=to[i]]&&v!=fa)
dfs(v,cr,lj+w[i]);
}
int calc(int cr,int w)
{
int ret=;dfs(cr,,w);
// l=1;r=0;
// sort(sta+l,sta+r+1);
// while(l<=r)
// if(sta[l]+sta[r]<=K)ret+=r-l,l++;
// else r--;
sort(sta+,sta+top+);int p=top;
for(int i=;i<=top;i++)
{
while(sta[p]+sta[i]>K&&p)p--;if(!p)break;
ret+=p-(p>=i);
}
top=;
// printf("cr=%d ret=%d\n",cr,ret);
return ret>>;
}
void solve(int cr,int s)
{
// printf("rt=%d\n",cr);
vis[cr]=;
ans+=calc(cr,);
// printf("cr=%d ans=%d\n",cr,ans);
for(int i=hd[cr],v;i;i=nxt[i]) if(!vis[v=to[i]])
{
ans-=calc(v,w[i]);
int ts=(siz[cr]>siz[v]?siz[v]:s-siz[cr]);//s-siz[cr]!!!
mn=N;getrt(v,,ts);solve(rt,ts);
}
}
int main()
{
scanf("%d",&n);
for(int i=,x,y,z;i<n;i++)
{
scanf("%d%d%d",&x,&y,&z);add(x,y,z);
}
scanf("%d",&K);
mn=N;getrt(,,n);solve(rt,n);
printf("%d\n",ans);
return ;
}