解题思路:
【题意】
给你一棵有根树,一个定值k,以及树上每个结点的值a[i]
对于有序对(u,v),如果(1)u是v的祖先,且(2)a[u]*a[v]<=k,则称该有序对(u,v)是弱的
问树中有多少对有序对(u,v)是弱的
【类型】
离散化+dfs+树状数组
【分析】
对于要求(1),u是v的祖先,我们可以采取dfs
遍历到v时,它上方的所有结点必定都是满足第一条件的u
熟悉dfs过程的应该能理解这一点,不理解的可以借助下述图片稍微理解一下
从上图中,我们可以大致看出dfs过程是从树根开始向树叶访问的
对于某结点v,它的祖先u肯定是先于它被访问的,不然也不可能到达结点v
正如上图,结点10的祖先是结点1,2,4,8,不管哪个祖先,一旦有一个没被访问,也不可能达到结点10
此外,在退出某个子树的时候,该子树下结点的影响会被消除,这样就能保证所有有影响的都是祖先
要求(2),a[u]*a[v]<=k,那么到v的时候,所有小于等于k/a[v]的u都满足,可以想到树状数组
结点的值a[i]最大10亿,要用树状数组的话肯定要离散化
离散化的时候要把k/a[v]加进去一起离散,保证大小关系不变
另外,当a[i]=0时,会出现除以0错误,所以我们要特判该情况
显然a[i]=0的话,任何满足要求(1)的结点都可以构成弱的有序对
所以将该条件下的k/a[i]的结果直接设置为inf
#include<bits/stdc++.h> using namespace std; #define INF (1ll<<60)-1 #define LL long long #define N 100005 LL k, ans, a[N], b[N*2], sum[N<<4]; int deep[N], head[N], tol, m; struct Edge { int v, nxt; }edge[N]; void init() { ans = tol = 0; memset(sum, 0, sizeof(sum)); memset(deep, 0, sizeof(deep)); memset(head, -1, sizeof(head)); } void build(int rt, int left, int right) { if(left == right) { sum[rt] = 0; return ; } int mid = (left+right)>>1; build(rt<<1, left, mid); build(rt<<1|1, mid+1, right); sum[rt] = sum[rt<<1] + sum[rt<<1|1]; } void addedge(int u, int v) { edge[tol].v = v; edge[tol].nxt = head[u]; head[u] = tol++; } int query(int rt, int left, int right, int l, int r) { if(l<=left&&r>=right) return sum[rt]; int mid = (left + right) >> 1; if(r <= mid) return query(rt<<1, left, mid, l, r); else if(l > mid) return query(rt<<1|1, mid+1, right, l, r); else return query(rt<<1, left, mid, l, r) + query(rt<<1|1, mid+1, right, l, r); } void update(int rt, int left, int right, int pos, int val) { if(left == right) { sum[rt] += val; return ; } int mid = (left+right)>>1; if(pos <= mid) update(rt<<1, left, mid, pos, val); else update(rt<<1|1, mid+1, right, pos, val); sum[rt] = sum[rt<<1]+sum[rt<<1|1]; } void dfs(int u) { LL lim; if(a[u] == 0) lim = INF; else lim = k/a[u]; int l = lower_bound(b+1, b+m+1, lim) - b; int pos = lower_bound(b+1, b+m+1, a[u]) - b; ans += query(1, 1, m, 1, l); update(1, 1, m, pos, 1); for(int i = head[u]; i != -1; i = edge[i].nxt) dfs(edge[i].v); update(1, 1, m, pos, -1); } void solve() { int n; init(); scanf("%d%I64d", &n, &k); for(int i = 1; i <= n;i++) { scanf("%I64d", &a[i]); b[i] = a[i]; if(a[i]!=0) b[i+n] = k/a[i]; else b[i+n] = INF; } sort(b+1, b+n*2+1); m = unique(b+1, b+n*2+1) - (b+1); build(1, 1, m); for(int i = 1; i < n; i++) { int u, v; scanf("%d%d", &u, &v); addedge(u, v); deep[v]++; } for(int i = 1; i <= n; i++) { if(deep[i] == 0) { dfs(i); break; } } printf("%I64d\n", ans); } int main() { int t; scanf("%d", &t); while(t--) { solve(); } return 0; }