题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=5877
题意:
weak pair的要求:
1.u是v的祖先(注意不一定是父亲)
2.val[u]*val[k] <=k;
题解:
1.将val(以及k/val)离散化,可用map 或者 用数组。 只要能将val与树状数组或线段树的下标形成映射就可以了。
2.从根节点开始搜索(题目中的树,不是线段树或树状数组的树),先统计当前节点与祖先能形成多少对weak pair,然后将其插入到树状数组或线段树中。
3.递归其子树。递归完子树后,再把当前节点从树状数组或线段树中删去。因为:根据递归的特性,如果不删除,这个值将会残留在c数组, 那么他的堂兄弟,堂叔伯,堂侄子等(后一步 递归的)会误认为这个值是他的祖先的。所以要及时删除。
类似的题(边查询边更新):http://blog.csdn.net/dolfamingo/article/details/71001021
树状数组(map离散):
#include<cstdio>//hdu5877 树状数组 map离散 dfs
#include<cstring>
#include<cstdlib>
#include<cmath>
#include<algorithm>
#include<map>
#include<vector>
#define LL long long
#define INF 2e18 using namespace std; LL sval[200020], val[200020];
int n,fa[100010],c[200010];
map<LL,int> m;
vector<int> son[100010];
LL k,ans; int lowbit(int x)
{
return x & (-x);
} void add(int x, int d)
{
for(; x<=2*n; x += lowbit(x))
{
c[x] += d;
}
} int sumc(int x)
{
int s = 0;
for(;x>0; x -= lowbit(x))
{
s += c[x];
}
return s;
} void dfs(int rt)//c数组中的下标与val[i],k/val[i]映射 且k/v[i]的下标-i的下边等于n(自己定)
{
ans += sumc(m[val[n+rt]]);//统计<=k/val[rt]的个数,为什么不直接 m[k/val[rt]]? 因为val[rt]可能为0
add(m[val[rt]],1);
for(int i = 0; i<son[rt].size(); i++)
{
dfs(son[rt][i]);
}
add(m[val[rt]],-1);
} int main()
{
int T;
scanf("%d",&T);
while(T--)
{
scanf("%d%lld",&n,&k); for(int i = 1; i<=n; i++)
{
//存k/val[i]是因为 val[i]*(k/val[i])<= k. 到时可以直接从树状数组中统计<=k/val[i]的个数
scanf("%lld",&val[i]);
if(val[i])
val[n+i] = k/val[i];
else //0为特殊情况
val[n+i] = INF;
} //sval的作用是将v值按从小到大,一一与c数组的下标形成映射
for(int i = 1; i<=2*n; i++)
sval[i] = val[i];
sort(sval+1,sval+2*n+1); int cnt = 0;
m.clear();
for(int i = 1; i<=2*n; i++)
{
//map的作用是将v值与c数组的下标形成映射
if(!m[sval[i]]) m[sval[i]] = ++cnt;
} for(int i = 1; i<=n; i++)
fa[i] = 0, son[i].clear();
for(int i = 1,u,v; i<n; i++)
{
scanf("%d%d",&u,&v);
son[u].push_back(v);
fa[v] = u;
} ans = 0;
memset(c,0,sizeof(c));
for(int i = 1; i<=n; i++)
{
if(!fa[i])
{
dfs(i);
break;
}
}
printf("%lld\n",ans); }
return 0;
}
线段树(map离散):
#include<cstdio>//hdu5877 线段树 map离散 dfs
#include<cstring>//注意区分题目的树和线段树的树
#include<cstdlib>
#include<cmath>
#include<algorithm>
#include<vector>
#include<map>
#define LL long long
#define INF 2e18 using namespace std; int n,len;//n是题目给出的树的结点个数,len是线段树的线段长度。
LL k,ans;
LL val[200100],sval[200100];//val记录原始值,sval记录经过排序,删除重复操作的值,用于线段树的操作
int fa[100100],sum[800100];
vector<int>son[100100];
map<LL,int>m; int query(int root, int le, int ri, int x, int y)
{
if(x<=le && y>=ri)
return sum[root]; int mid = (le+ri)/2, ret = 0;
if(x<=mid) ret += query(root*2,le,mid,x,y);
if(y>=mid+1) ret += query(root*2+1,mid+1,ri,x,y);
return ret;
} void update(int root, int le, int ri, int pos, int d)
{
if(le==ri)
{
sum[root] += d;
return;
} int mid = (le+ri)/2;
if(pos<=mid) update(root*2,le,mid,pos,d);
else update(root*2+1,mid+1,ri,pos,d);
sum[root] = sum[root*2] + sum[root*2+1];
} void dfs(int rt)
{
int last = m[val[n+rt]];
int pos = m[val[rt]]; ans += query(1,1,len,1,last); update(1,1,len,pos,1);
for(int i = 0; i<son[rt].size(); i++)
{
dfs(son[rt][i]);
}
update(1,1,len,pos,-1);
} int main()
{
int T;
scanf("%d",&T);
while(T--)
{
scanf("%d%lld",&n,&k);
for(int i = 1; i<=n; i++)
{
scanf("%lld",&val[i]);
if(val[i])
val[n+i] = k/val[i];
else
val[n+i] = INF;
} for(int i = 1; i<=2*n; i++)
sval[i] = val[i];
sort(sval+1,sval+2*n+1); m.clear();
len = 0;
for(int i = 1; i<=2*n; i++)
{
if(!m[sval[i]]) m[sval[i]] = ++len;
} for(int i = 1; i<=n; i++)
fa[i] = 0, son[i].clear();
for(int i = 1,u,v; i<n; i++)
{
scanf("%d%d",&u,&v);
son[u].push_back(v);
fa[v] = u;
} ans = 0;
memset(sum,0,sizeof(sum));
for(int i = 1; i<=n; i++)
{
if(!fa[i])
{
dfs(i);
break;
}
} printf("%lld\n",ans);
}
return 0;
}
线段树(数组离散):
#include<cstdio>//hdu5877 线段树 dfs 普通数组进行离散
#include<cstring>//注意区分题目的树和线段树的树
#include<cstdlib>
#include<cmath>
#include<algorithm>
#include<vector>
#define LL long long
#define INF 2e18
using namespace std; int n,m;//n是题目给出的树的结点个数,m是线段树的线段长度。
LL k,ans;
LL val[200100],sval[200100];//val记录原始值,sval记录经过排序,删除重复操作的值,用于线段树的操作
int fa[100100],sum[800100];
vector<int>son[100100]; int query(int root, int le, int ri, int x, int y)
{
if(x<=le && y>=ri)
return sum[root]; int mid = (le+ri)/2, ret = 0;
if(x<=mid) ret += query(root*2,le,mid,x,y);
if(y>=mid+1) ret += query(root*2+1,mid+1,ri,x,y);
return ret;
} void update(int root, int le, int ri, int pos, int d)
{
if(le==ri)
{
sum[root] += d;
return;
} int mid = (le+ri)/2;
if(pos<=mid) update(root*2,le,mid,pos,d);
else update(root*2+1,mid+1,ri,pos,d);
sum[root] = sum[root*2] + sum[root*2+1];
} void dfs(int rt)
{
int last = lower_bound(sval+1, sval+m+1,val[n+rt]) - sval;
int pos = lower_bound(sval+1,sval+m+1,val[rt]) - sval; ans += query(1,1,m,1,last); update(1,1,m,pos,1);
for(int i = 0; i<son[rt].size(); i++)
{
dfs(son[rt][i]);
}
update(1,1,m,pos,-1);
} int main()
{
int T;
scanf("%d",&T);
while(T--)
{
scanf("%d%lld",&n,&k);
for(int i = 1; i<=n; i++)
{
scanf("%lld",&val[i]);
if(val[i])
val[n+i] = k/val[i];
else
val[n+i] = INF;
} for(int i = 1; i<=2*n; i++)
sval[i] = val[i];
sort(sval+1,sval+2*n+1);
m = unique(sval+1,sval+2*n+1) - (sval+1); for(int i = 1; i<=n; i++)
fa[i] = 0, son[i].clear();
for(int i = 1,u,v; i<n; i++)
{
scanf("%d%d",&u,&v);
son[u].push_back(v);
fa[v] = u;
} ans = 0;
memset(sum,0,sizeof(sum));
for(int i = 1; i<=n; i++)
{
if(!fa[i])
{
dfs(i);
break;
}
} printf("%lld\n",ans);
}
return 0;
}