将每个人跑步的路径拆分成x->lca,lca->y两条路径分别考虑:
对于在点i的观察点,这个人(s->t)能被观察到的充要条件为:
1.直向上的路径:w[i]=dep[s]-dep[i],移项得w[i]+dep[i]=dep[s]
2.直向下的路径:w[i]=dep[s]-dep[lca]+dep[i]-dep[lca],移项得w[i]-dep[i]=dep[s]-2*dep[lca]。
问题转化为,对每个点i,统计它的子树中有多少个点x满足dep[x]=w[i]+dep[i]或dep[x]-2*dep[lca]=w[i]-dep[i],这是经典的线段树合并问题。
注意到并不是子树中所有满足条件的点都能被统计,因为有的点还没到观察点就往下跑了(lca深度大于当前观察点),差分解决。
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define lson ls[x],L,mid
#define rson rs[x],mid+1,R
#define rep(i,l,r) for (int i=l; i<=r; i++)
#define For(i,x) for (int i=h[x],k; i; i=nxt[i])
typedef long long ll;
using namespace std; const int N=,M=;
int n,m,u,v,cnt,s,t,w[N],d[N],ans[N],fa[N][],h[N],nxt[N<<],to[N<<];
void add(int u,int v){ to[++cnt]=v; nxt[cnt]=h[u]; h[u]=cnt; } struct T{
int nd,v[M],ls[M],rs[M],rt[N];
void ins(int &x,int L,int R,int pos,int k){
if (!x) x=++nd;
if (L==R){ v[x]+=k; return; }
int mid=(L+R)>>;
if (pos<=mid) ins(lson,pos,k); else ins(rson,pos,k);
} int merge(int x,int y,int L,int R){
if (!x || !y) return x+y;
if (L==R) { v[x]+=v[y]; return x; }
int mid=(L+R)>>;
ls[x]=merge(ls[x],ls[y],L,mid);
rs[x]=merge(rs[x],rs[y],mid+,R);
return x;
} int que(int x,int L,int R,int pos){
if (!x) return ;
if (L==R) return v[x];
int mid=(L+R)>>;
if (pos<=mid) return que(lson,pos); else return que(rson,pos);
}
}T1,T2; void dfs(int x){
rep(i,,) fa[x][i]=fa[fa[x][i-]][i-];
For(i,x) if ((k=to[i])!=fa[x][]) fa[k][]=x,d[k]=d[x]+,dfs(k);
} void dfs2(int x){
For(i,x) if ((k=to[i])!=fa[x][]){
dfs2(k);
T1.rt[x]=T1.merge(T1.rt[x],T1.rt[k],,n);
T2.rt[x]=T2.merge(T2.rt[x],T2.rt[k],,*n);
}
ans[x]+=(w[x]+d[x]>= && w[x]+d[x]<=n) ? T1.que(T1.rt[x],,n,w[x]+d[x]) : ;
ans[x]+=(w[x]-d[x]>=-n && w[x]-d[x]<=n) ? T2.que(T2.rt[x],,*n,w[x]-d[x]+n) : ;
} int Lca(int x,int y){
if (d[x]<d[y]) swap(x,y);
int t=d[x]-d[y];
for (int i=; ~i; i--) if (t&(<<i)) x=fa[x][i];
if (x==y) return x;
for (int i=; ~i; i--) if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
return fa[x][];
} int main(){
freopen("running.in","r",stdin);
freopen("running.out","w",stdout);
scanf("%d%d",&n,&m);
rep(i,,n) scanf("%d%d",&u,&v),add(u,v),add(v,u);
rep(i,,n) scanf("%d",&w[i]);
dfs();
rep(i,,m){
scanf("%d%d",&s,&t); int lca=Lca(s,t);
T1.ins(T1.rt[s],,n,d[s],); T1.ins(T1.rt[fa[lca][]],,n,d[s],-);
T2.ins(T2.rt[t],,*n,d[s]-*d[lca]+n,);
T2.ins(T2.rt[fa[lca][]],,*n,d[s]-*d[lca]+n,-);
if (d[s]-d[lca]==w[lca]) ans[lca]--;
}
dfs2();
rep(i,,n) printf("%d ",ans[i]); puts("");
return ;
}