思路:首先如果颜色相同直接利用以前的答案即可,可以离线排序或是在线hash,然后考虑怎么快速统计答案。
首先如果点a是点b的祖先,那么一定有点b在以点a为根的子树的dfs序区间内的,于是先搞出dfs序。
然后如果颜色a的点数很小,颜色b的点数很大,那么可以考虑枚举a的点数,然后对于每一种颜色开个vector记录一下有哪些点是这种颜色,然后按照它们的dfs序排序,就可以用颜色a中的每个点在颜色b中二分出哪些点属于以该点为根的子树对应的dfs序区间了。复杂度O(size(a)*log(size(b))),size(a)表示颜色a的vector的size()。
然后如果颜色b的点数很小,颜色a的点数很大,那么就枚举b的点数,这时要考虑的问题就成了一个点被多少段区间覆盖了,然后离散化差分预处理,再去二分(我写的是vector的离散化)。复杂度O(size(b)*log(size(a)))
但如果a,b的点数差不多且都很大(也就是几乎为sqrt(n)),那么算法复杂度就会变成O(sqrt(n)*log(n))了,再乘以一个q就会GG,于是只能另寻他法,然后可以发现直接两个指针扫过去,一个扫区间端点另一个扫要询问的点,然后如果扫到一个点就直接统计答案,然后这就变成了O(size(a)+size(b))了。
那这个很大是有多大,很小是有多小呢?
对于第一种算法使用条件是size(b)>x,第二种算法使用条件是size(a)>x,其余则用第三种算法。
对于第一、二种情况,时间复杂度最大是O(n^2logn/x),然后对于第三种则是O(n*x),然后根据基本不等式x=sqrt(nlogn),总时间复杂度为O(n*sqrt(nlogn))。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<cmath>
using namespace std;
#define maxn 200005
#define maxr 30000 int n,r,Q,tot,cnt;
int now[maxn],pre[*maxn],son[*maxn],color[maxn],dfn[maxn],size[maxn];
long long ans[maxn]; inline int read(){
int x=,f=;char ch=getchar();
for (;ch<''||ch>'';ch=getchar()) if (ch=='-') f=-;
for (;ch>=''&&ch<='';ch=getchar()) x=x*+ch-'';
return x*f;
} struct node{
int dfn,bo;
node(){}
node(int a,int b){dfn=a,bo=b;}
bool operator <(const node &a)const{return dfn<a.dfn;}
}; struct query{
int x,y,id;
bool operator <(const query &a)const{return x<a.x||(x==a.x&&y<a.y);}
}q[maxn]; bool cmp(int a,int b){return dfn[a]<dfn[b];} vector<int> col[maxr],val[maxr];
vector<node> v[maxr];
vector<int> fuckpps[maxr]; void add(int a,int b){
son[++tot]=b;
pre[tot]=now[a];
now[a]=tot;
} void link(int a,int b){
add(a,b),add(b,a);
} void dfs(int x,int fa){
dfn[x]=++cnt;
for (int p=now[x];p;p=pre[p])
if (son[p]!=fa) dfs(son[p],x),size[x]+=size[son[p]]+;
} int binary_search(int l,int r,int b,int pos){
int ans=-;
while (l<=r){
int mid=(l+r)>>;
if (pos>=fuckpps[b][mid]) ans=mid,l=mid+;
else r=mid-;
}
return ans+;
} long long solve1(int a,int b){
long long ans=;
for (unsigned int i=;i<col[a].size();i++){
int x=col[a][i],l=binary_search(,fuckpps[b].size()-,b,dfn[x]-),r=binary_search(,fuckpps[b].size()-,b,dfn[x]+size[x]);
ans+=r-l;
}
return ans;
} int binary_search2(int l,int r,int b,int pos){
int ans=-;
while (l<=r){
int mid=(l+r)>>;
if (v[b][mid].dfn<=pos) ans=mid,l=mid+;
else r=mid-;
}
return ans;
} long long solve2(int a,int b){
long long ans=;
for (unsigned int i=;i<col[b].size();i++){
int x=col[b][i],pos=binary_search2(,v[a].size()-,a,dfn[x]);
if (pos!=-) ans+=val[a][pos];
}
return ans;
} long long solve3(int a,int b){
long long ans=;unsigned int i=,j=,tt=;
while (i<v[a].size() && j<col[b].size())
if (v[a][i].dfn<=dfn[col[b][j]]) tt=val[a][i],i++;else ans+=tt,j++;
return ans;
} int main(){
n=read(),r=read(),Q=read();int siz=sqrt(n*log2(n));
for (int i=,x;i<=n;i++){
if (i!=) x=read(),link(i,x);
color[i]=read();col[color[i]].push_back(i);
}
dfs(,);
for (int i=;i<=n;i++) fuckpps[color[i]].push_back(dfn[i]);
for (int i=;i<=r;i++) sort(col[i].begin(),col[i].end(),cmp),sort(fuckpps[i].begin(),fuckpps[i].end());
for (int i=;i<=r;i++){
for (unsigned int j=;j<col[i].size();j++)
v[i].push_back(node(dfn[col[i][j]],)),v[i].push_back(node(dfn[col[i][j]]+size[col[i][j]]+,-));
sort(v[i].begin(),v[i].end());int sum=;
for (unsigned int j=;j<v[i].size();j++){
sum+=v[i][j].bo;
val[i].push_back(sum);
}
}
for (int i=;i<=Q;i++) q[i].x=read(),q[i].y=read(),q[i].id=i;
sort(q+,q+Q+);
for (int i=;i<=Q;i++){
if (q[i].x==q[i-].x && q[i].y==q[i-].y){ans[q[i].id]=ans[q[i-].id];continue;}
if (col[q[i].y].size()+>=siz&&col[q[i].x].size()+<siz) ans[q[i].id]=solve1(q[i].x,q[i].y);
else if (col[q[i].x].size()+>=siz&&col[q[i].y].size()+<siz) ans[q[i].id]=solve2(q[i].x,q[i].y);
else ans[q[i].id]=solve3(q[i].x,q[i].y);
}
for (int i=;i<=Q;i++) printf("%lld\n",ans[i]);
return ;
}