HDU5877WeakPair(线段树+离散化+DFS)

时间:2021-11-03 19:26:02

解题思路:

【题意】

给你一棵有根树,一个定值k,以及树上每个结点的值a[i]

对于有序对(u,v),如果(1)u是v的祖先,且(2)a[u]*a[v]<=k,则称该有序对(u,v)是弱的

问树中有多少对有序对(u,v)是弱的

【类型】

离散化+dfs+树状数组

【分析】

对于要求(1),u是v的祖先,我们可以采取dfs

遍历到v时,它上方的所有结点必定都是满足第一条件的u

熟悉dfs过程的应该能理解这一点,不理解的可以借助下述图片稍微理解一下

HDU5877WeakPair(线段树+离散化+DFS)

从上图中,我们可以大致看出dfs过程是从树根开始向树叶访问的

对于某结点v,它的祖先u肯定是先于它被访问的,不然也不可能到达结点v

正如上图,结点10的祖先是结点1,2,4,8,不管哪个祖先,一旦有一个没被访问,也不可能达到结点10

此外,在退出某个子树的时候,该子树下结点的影响会被消除,这样就能保证所有有影响的都是祖先

HDU5877WeakPair(线段树+离散化+DFS)

HDU5877WeakPair(线段树+离散化+DFS)

要求(2),a[u]*a[v]<=k,那么到v的时候,所有小于等于k/a[v]的u都满足,可以想到树状数组

结点的值a[i]最大10亿,要用树状数组的话肯定要离散化

离散化的时候要把k/a[v]加进去一起离散,保证大小关系不变

另外,当a[i]=0时,会出现除以0错误,所以我们要特判该情况

显然a[i]=0的话,任何满足要求(1)的结点都可以构成弱的有序对

所以将该条件下的k/a[i]的结果直接设置为inf

#include<bits/stdc++.h>
using namespace std;
#define INF (1ll<<60)-1
#define LL long long
#define N 100005
LL k, ans, a[N], b[N*2], sum[N<<4];
int deep[N], head[N], tol, m;
struct Edge
{
    int v, nxt;
}edge[N];
void init()
{
    ans = tol = 0;
    memset(sum, 0, sizeof(sum));
    memset(deep, 0, sizeof(deep));
    memset(head, -1, sizeof(head));
}
void build(int rt, int left, int right)
{
    if(left == right)
    {
        sum[rt] = 0;
        return ;
    }
    int mid = (left+right)>>1;
    build(rt<<1, left, mid);
    build(rt<<1|1, mid+1, right);
    sum[rt] = sum[rt<<1] + sum[rt<<1|1];
}
void addedge(int u, int v)
{
    edge[tol].v = v;
    edge[tol].nxt = head[u];
    head[u] = tol++;
}
int query(int rt, int left, int right, int l, int r)
{
    if(l<=left&&r>=right) return sum[rt];
    int mid = (left + right) >> 1;
    if(r <= mid) return query(rt<<1, left, mid, l, r);
    else if(l > mid) return query(rt<<1|1, mid+1, right, l, r);
    else return query(rt<<1, left, mid, l, r) + query(rt<<1|1, mid+1, right, l, r);
}
void update(int rt, int left, int right, int pos, int val)
{
    if(left == right)
    {
        sum[rt] += val;
        return ;
    }
    int mid = (left+right)>>1;
    if(pos <= mid) update(rt<<1, left, mid, pos, val);
    else update(rt<<1|1, mid+1, right, pos, val);
    sum[rt] = sum[rt<<1]+sum[rt<<1|1];
}
void dfs(int u)
{
    LL lim;
    if(a[u] == 0) lim = INF;
    else lim = k/a[u];
    int l = lower_bound(b+1, b+m+1, lim) - b;
    int pos = lower_bound(b+1, b+m+1, a[u]) - b;
    ans += query(1, 1, m, 1, l);
    update(1, 1, m, pos, 1);
    for(int i = head[u]; i != -1; i = edge[i].nxt) dfs(edge[i].v);
    update(1, 1, m, pos, -1);
}
void solve()
{
    int n;
    init();
    scanf("%d%I64d", &n, &k);
    for(int i = 1; i <= n;i++)
    {
        scanf("%I64d", &a[i]);
        b[i] = a[i];
        if(a[i]!=0)
            b[i+n] = k/a[i];
        else
            b[i+n] = INF;
    }
    sort(b+1, b+n*2+1);
    m = unique(b+1, b+n*2+1) - (b+1);
    build(1, 1, m);
    for(int i = 1; i < n; i++)
    {
        int u, v;
        scanf("%d%d", &u, &v);
        addedge(u, v);
        deep[v]++;
    }
    for(int i = 1; i <= n; i++)
    {
        if(deep[i] == 0)
        {
            dfs(i);
            break;
        }
    }
    printf("%I64d\n", ans);
}
int main()
{
    int t;
    scanf("%d", &t);
    while(t--)
    {
        solve();
    }
    return 0;
}