「WC2010」重建计划(长链剖分/点分治)

时间:2021-11-28 20:01:53

「WC2010」重建计划(长链剖分/点分治)

题目描述

有一棵大小为 \(n\) 的树,给定 \(L, R\) ,要求找到一条长度在 \([L, R]\) 的路径,并且路径上边权的平均值最大

\(1 \leq n,L,R \leq 10^5\)

解题思路 :

前几天沉迷初赛来写几道数据结构恢复一下代码能力,坑填完之后可能就要开始啃思维题了QwQ。

这个题貌似长链剖分和点分复杂度都是 \(O(nlog^2n)\) 的,点分好久都没碰了,长链剖分也只有暑假里口胡了几个多校的题而已,先讲做法吧

这个题很显然可以分数规划,二分答案后问题转化为每条边边权变为\(w_i - mid\) ,判断能不能找到一条长度在 \([L, R]\) 且边权和非负的路径,实际上可以求一条满足条件且边权和最大的路。

点分治:

比较显然的做法是对于每一个分治中心用一个数据结构来维护到点分中心的每一种长度的路径的最值,然后暴力拼合路径即可,这样做的话总复杂度是 \(O(nlog^3n)\) ,感觉不太能松的过去。

简单观察发现,对于当前长度 \(i\) ,随着 \(i\) 递增可行的区间是单调左移的,所以可以用一个单调队列来处理每一个分治中心的答案。不过要注意的是,总的区间大小等价于当前处理到的最长的路径长度,如果先处理一条很长的路径,剩下的点数很多的话复杂度就退化到 \(O(n^2logn)\) 了。

解决方法是对每一个儿子按照其最长路径长度排序(也就是最大深度),这样处理每一个儿子的复杂度不会超过其最长路径长度,复杂度就是正确的 \(O(nlog^2n)\) 了,类似的套路也可以在 \(\text{BZOJ Normal}\) 见到

/*program by mangoyang*/
#include<bits/stdc++.h>
#define inf (0x7f7f7f7f)
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))
typedef long long ll;
using namespace std;
template <class T>
inline void read(T &x){
    int f = 0, ch = 0; x = 0;
    for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = 1;
    for(; isdigit(ch); ch = getchar()) x = x * 10 + ch - 48;
    if(f) x = -x;
}
const int N = 200005;
const double eps = 1e-4;

int a[N], b[N], nxt[N], head[N], cnt, n, L, R;

inline void add(int x, int y, int z){
    a[++cnt] = y, b[cnt] = z, nxt[cnt] = head[x], head[x] = cnt;
}

namespace Tree{
    double f[N], res, mid;
    vector<int> g[N];
    vector<double> c[N];
    int sz[N], q[N], vis[N], all, mn, firt, root;

    inline void getroot(int u, int fa){
        int msize = 0; sz[u] = 1;
        for(int p = head[u]; p; p = nxt[p]){
            int v = a[p];
            if(v == fa || vis[v]) continue;
            getroot(v, u), sz[u] += sz[v];
            if(sz[v] > msize) msize = sz[v];
        }
        msize = max(msize, all - sz[u]);
        if(msize <= mn) mn = msize, root = u;
    }
    inline void build(int u){
        int ls = all; vis[u] = 1;
        for(int p = head[u]; p; p = nxt[p]){
            int v = a[p];
            if(vis[v]) continue;
            mn = all = sz[v] > sz[u] ? ls - sz[u] : sz[v];
            getroot(v, u), g[u].push_back(root), build(root);
        }
    }
    inline void realmain(){
        mn = all = n, getroot(1, 0), build(firt = root);
    }   
    inline bool cmp(int A, int B){ 
        return c[A].size() < c[B].size(); 
    }
    inline void getdis(int u, int fa, int x, int dep, double dis){
        if(c[x].size() < dep + 1) c[x].push_back(dis);
        else c[x][dep] = max(c[x][dep], dis);
        for(int p = head[u]; p; p = nxt[p]){
            int v = a[p];
            if(!vis[v] && v != fa) 
                getdis(v, u, x, dep + 1, dis + b[p] - mid);
        }
    }
    inline void solve(int u){
        vis[u] = 1;     
        vector<int> now; now.clear(); int len = 0;
        for(int p = head[u]; p; p = nxt[p]){
            int v = a[p];
            if(vis[v]) continue;
            c[v].clear(), getdis(v, u, v, 0, b[p] - mid);
            now.push_back(v), len = Max(len, c[v].size());
        }
        sort(now.begin(), now.end(), cmp);
        for(int i = 1; i <= len; i++) f[i] = -inf;
        int mx = 0;
        for(int pos = 0; pos < now.size(); pos++){
            int x = now[pos], h = 1, t = 0, p = 0;
            for(int i = c[x].size() - 1; ~i; i--){
                while(h <= t && q[h] + i + 1 < L) h++;
                while(p <= mx && i + p + 1 < L) p++;
                while(p <= mx && i + p + 1 <= R){
                    while(h <= t && f[p] >= f[q[t]]) t--;
                    q[++t] = p, p++;
                }
                if(h <= t) res = max(res, c[x][i] + f[q[h]]); 
            }
            for(int i = 0; i < c[x].size(); i++) 
                f[i+1] = max(f[i+1], c[x][i]);
            mx = c[x].size();
        }
        for(int i = 0; i < g[u].size(); i++) solve(g[u][i]);    
    }
    inline bool check(double x){
        mid = x, res = -inf;
        memset(vis, 0, sizeof(vis)), solve(firt);
        return res >= eps;
    }   
}

int main(){
    read(n), read(L), read(R);
    for(int i = 1, x, y, z; i < n; i++){
        read(x), read(y), read(z);
        add(x, y, z), add(y, x, z);
    }
    Tree::realmain();
    double l = 0, r = 1000000, ans = 0;
    while(l + eps < r){
        double mid = (l + r) / 2.0;
        if(Tree::check(mid)) l = mid, ans = mid; else r = mid;
    }
    printf("%.3lf", ans);
    return 0;
}

长链剖分

本质上点分做的事情是每次合并多个以深度为下标的数组求最值,观察发现这个题可以长链剖分。每一次保留重链的向上路径,并把轻链的信息一一合并上去。

但是这里就不能用单调队列优化了,因为重链是一开始就要被保留的,如果用单调队列的话复杂度会被同样的东西卡掉,所以必须要用线段树来维护答案。

具体实现的话只需要维护一下 \(tag\) 标记实现在重链的链头加点,每次在线段树里面查指定范围内的最值合并即可,复杂度也是 \(O(nlog^2n)\)

/*program by mangoyang*/
#include<bits/stdc++.h>
#define inf (0x7f7f7f7f)
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))
typedef long long ll;
using namespace std;
template <class T>
inline void read(T &x){
    int f = 0, ch = 0; x = 0;
    for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = 1;
    for(; isdigit(ch); ch = getchar()) x = x * 10 + ch - 48;
    if(f) x = -x;
}
const int N = 200005;
const double eps = 1e-4;
double tag[N], f[N], ans, mid;
int a[N], b[N], nxt[N], head[N], cnt; 
int dfn[N], dep[N], ms[N], w[N], tot, n, L, R;

inline void add(int x, int y, int z){
    a[++cnt] = y, b[cnt] = z, nxt[cnt] = head[x], head[x] = cnt;
}
struct SegmentTree{
    #define lson (u << 1)
    #define rson (u << 1 | 1)
    double s[N<<2];
    inline void clear(){
        for(int i = 0; i < (N << 2); i++) s[i] = -inf; 
    }
    inline void modify(int u, int l, int r, int pos, double x){ 
        if(l == r) return (void) (s[u] = max(s[u], x));
        int mid = l + r >> 1; 
        if(pos <= mid) modify(lson, l, mid, pos, x);
        else modify(rson, mid + 1, r, pos, x); s[u] = Max(s[lson], s[rson]);
    }
    inline double query(int u, int l, int r, int L, int R){
        if(l >= L && r <= R) return s[u];
        int mid = l + r >> 1; double res = -inf;
        if(L <= mid) res = max(res, query(lson, l, mid, L, R));
        if(mid < R) res = max(res, query(rson, mid + 1, r, L, R));
        return res;
    }
}Seg;
inline void dfs(int u, int fa){
    for(int p = head[u]; p; p = nxt[p]){
        int v = a[p];
        if(v == fa) continue;
        dfs(v, u);
        if(dep[v] >= dep[ms[u]]) ms[u] = v, w[u] = b[p];
        if(dep[v] + 1 > dep[u]) dep[u] = dep[v] + 1;
    }
}
inline void split(int u, int fa){
    if(!dfn[u]) dfn[u] = ++tot; int pu = dfn[u];
    if(ms[u]) split(ms[u], u), tag[pu] = tag[pu+1] + w[u] - mid;
    Seg.modify(1, 1, n, pu, f[pu] = -tag[pu]);
    if(L <= dep[u]){
        double tmp = Seg.query(1, 1, n, pu + L, pu + min(dep[u], R));
        ans = max(ans, tmp + tag[pu]);
    }
    for(int p = head[u]; p; p = nxt[p]){
        int v = a[p], pv = dfn[v];
        if(v == fa || v == ms[u]) continue;
        split(v, u); 
        for(int i = 0; i <= dep[v]; i++){
            int l = pu + max(0, L - i - 1), r = pu + min(dep[u], R - i - 1); 
            double tmp = Seg.query(1, 1, n, l, r);
            ans = max(ans, tmp + tag[pv] + tag[pu] + f[pv+i] + b[p] - mid);
        }
        for(int i = 0; i <= dep[v]; i++){
            double tmp = tag[pv] + f[pv+i] + b[p] - mid - tag[pu];
            if(tmp > f[pu+i+1]) Seg.modify(1, 1, n, pu + i + 1, f[pu+i+1] = tmp);
        }
    }
}
inline bool check(){
    Seg.clear();
    ans = -inf, split(1, 0); return ans >= eps; 
}
int main(){
    read(n), read(L), read(R);
    for(int i = 1, x, y, z; i < n; i++){
        read(x), read(y), read(z);
        add(x, y, z), add(y, x, z);
    }
    dfs(1, 0);
    double l = 0, r = 1000000, realans = 0; 
    while(l + eps < r){
        mid = (l + r) / 2.0;
        if(check()) l = mid, realans = mid; else r = mid;
    }
    printf("%.3lf", realans);
    return 0;
}