CF888G Xor-MST 生成树、分治、Trie树合并

时间:2022-12-28 06:39:50

传送门


第一次接触到Boruvka求最小生成树

它的原版本是:初始每一个点构成一个连通块,每一次找到每一个连通块到其他的连通块权值最短的边,然后合并这两个连通块。因为每一次连通块个数至少减半,所以复杂度是\(O((n+m)logn)\)的

虽然它的原版本用途不多,但是思想可以涵盖很多其他题目,比如这道题

可以想到一个做法:将所有权值插入一个\(Trie\)里,在每一个叶子节点维护到达这个节点的数的编号。像上面那样维护若干连通块,每一次计算权值最小的边时,将当前连通块中所有权值从Trie中删去,然后对于连通块中的每个权值在\(Trie\)上找到异或和最小的数字和编号,最后连边、恢复原来的\(Trie\)。

复杂度\(O(nlog^2n)\),但常数太大,哪怕在\(CF\)的神机下大数据也会直接沦陷QAQ

正着不行,就反着考虑。设能够产生贡献的二进制最高位为\(k\),即对于所有数来说,存在第\(k\)位为\(0\)的数,也存在第\(k\)位为\(1\)的数,且对于\(>k\)的数均不满足这一条件。那么最优的连边方法显然是:这一位为\(1\)的数之间连成一个生成树,这一位为\(0\)的数之间连成一个生成树,然后在这两个点集之间连一条边。可以发现这个问题变成了两个子问题,且对于这两个子问题的\(k\)一定会小于当前问题的\(k\),所以可以直接递归下去。

考虑如何计算当前层连的边的贡献。不妨让每一层递归结束时把当前层所有权值对应的\(Trie\)树建好传给上面一层,那么每一层可以获得这一位为\(1\)的所有数的\(Trie\)和这一位为\(0\)的所有数的\(Trie\)。将点数较少的点集中所有点的权值放在点数较多的点集对应的\(Trie\)上跑最小值,就可以得到当前层连边的权值大小。计算完贡献后将两个\(Trie\)用类似线段树合并的方式合并,可以有效避免\(MLE\)。

总复杂度仍然是\(O(nlog^2n)\)但跑得快了不少。

#include<iostream>
#include<cstdio>
#include<cctype>
#include<algorithm>
#include<cstring>
#include<iomanip>
#include<vector>
#include<set>
//This code is written by Itst
using namespace std;

inline int read(){
    int a = 0;
    char c = getchar();
    while(!isdigit(c))
        c = getchar();
    while(isdigit(c)){
        a = a * 10 + c - 48;
        c = getchar();
    }
    return a;
}

const int MAXN = 2e5 + 3;

struct node{
    node *ch[2];
    node(){ch[0] = ch[1] = NULL;}
};

struct Trie{
    node *rt = new node;

    void ins(int x){
        node *cur = rt;
        for(int i = 29 ; i >= 0 ; --i){
            if(cur->ch[(bool)(x & (1 << i))] == NULL)
                cur->ch[(bool)(x & (1 << i))] = new node;
            cur = cur->ch[(bool)(x & (1 << i))];
        }
    }

    int query(int x){
        int ans = 0;
        node *cur = rt;
        for(int i = 29 ; i >= 0 ; --i){
            bool f = x & (1 << i);
            if(cur->ch[f] != NULL)
                cur = cur->ch[f];
            else{
                cur = cur->ch[!f];
                ans += 1 << i;
            }
        }
        return ans;
    }
};
int N;
long long sum;
vector < int > val;

node* merge(node *A , node *B){
    if(A == NULL) return B;
    if(B == NULL) return A;
    A->ch[0] = merge(A->ch[0] , B->ch[0]);
    A->ch[1] = merge(A->ch[1] , B->ch[1]);
    return A;
}

Trie merge(Trie A , Trie B){
    A.rt = merge(A.rt , B.rt);
    return A;
}

Trie solve(vector < int > val , int now){
    if(val.empty()) return Trie();
    if(now < 0){
        Trie t;
        t.ins(val[0]);
        return t;
    }
    vector < int > lft , rht;
    for(auto t : val)
        t & (1 << now) ? rht.push_back(t) : lft.push_back(t);
    Trie L = solve(lft , now - 1) , R = solve(rht , now - 1);
    if(lft.size() < rht.size()){
        swap(lft , rht);
        swap(L , R);
    }
    int minN = 2e9;
    for(auto t : rht) minN = min(minN , L.query(t));
    if(!rht.empty()) sum += minN;
    return merge(L , R);
}

int main(){
#ifndef ONLINE_JUDGE
    freopen("in","r",stdin);
    //freopen("out","w",stdout);
#endif
    N = read();
    for(int i = 1 ; i <= N ; ++i)
        val.push_back(read());
    sort(val.begin() , val.end());
    solve(val , 29);
    cout << sum;
    return 0;
}