HDOJ 3938 Portal (离线并查集)

时间:2021-07-21 11:10:54

题意

给出一个带边权的图,令两个点之间的路径的费用为中途经过的边的最大值,对每个查询求有多少对点路径费用小于等于给定的L。

思路

用类似kruskal的思想,每个点都设置一个sum数组表示它所相连的所有满足小于当前L的边相连的点有多少个(其实数组名用size更合适)。
然后对于每次查询L,因为边我们也是升序排序的,所以对于所有小于L的边都有sum[find(edge[i].u)]*sum[find(edge[i].v)]对点满足条件,(iff find(edge[i].u) != find(edge[i].v))

代码

#include <stdio.h>
#include <string.h>
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
#include <stack>
#include <set>
#include <map>
#include <string>
#include <math.h>
#include <stdlib.h>
using namespace std;
#define LL long long
#define Lowbit(x) ((x)&(-x))
#define lson l, mid, rt << 1
#define rson mid + 1, r, rt << 1|1
#define MP(a, b) make_pair(a, b)
const int INF = 0x3f3f3f3f;
const int Mod = 1000000007;
const int maxn = 50000 + 7;
const double eps = 1e-8;
const double PI = acos(-1.0);
typedef pair<int, int> pii;

struct Edge
{
    int u, v, w;
}edge[maxn];

struct Q
{
    int l, id;
}query[maxn];

int n, m, q;
int fa[maxn];
LL sum[maxn], ans[maxn];

void init()
{
    for (int i = 1; i <= n; i++)
        fa[i] = i, sum[i] = 1;
    memset(ans, 0, sizeof(ans));
}

int find(int x)
{
    return fa[x] == x ? x : fa[x] = find(fa[x]);
}

void merge(int u, int v)
{
    fa[u] = v;
    sum[v] += sum[u];
}

int main()
{
    //freopen("H:\\in.txt","r",stdin);
    //freopen("H:\\out.txt","w",stdout);
    while (scanf("%d%d%d", &n, &m, &q) != EOF)
    {
        init();
        for (int i = 0; i < m; i++)
            scanf("%d%d%d", &edge[i].u, &edge[i].v, &edge[i].w);
        for (int i = 0; i < q; i++)
            scanf("%d", &query[i].l), query[i].id = i;
        sort(edge, edge + m, [](Edge a, Edge b){return a.w < b.w;});
        sort(query, query + q, [](Q a, Q b){return a.l < b.l;});
        int cnt = 0;
        for (int i = 0; i < q; i++)
        {
            int id = query[i].id;
            if (i != 0) ans[id] = ans[query[i-1].id];
            while (cnt < m && edge[cnt].w <= query[i].l)
            {
                int ra = find(edge[cnt].u);
                int rb = find(edge[cnt].v);
                if (ra != rb)
                {
                    ans[id] += sum[ra] * sum[rb];
                    merge(ra, rb);
                }
                cnt++;
            }
        }
        for (int i = 0; i < q; i++)
            printf("%I64d\n", ans[i]);
    }
    return 0;
}