NOIP模拟测试10 Problem B:模板:树上启发式合并

时间:2022-02-27 20:46:27

最开始以为是线段树合并,结果邓鸽鸽说线段树合并必死,布星。

热情的邓鸽鸽向我传授了我没有学习过的船新算法 树上启发式合并

学完之后发现就是很普通的启发式合并用到了树上而已

先说做法,给每个节点按时间轴开个动态开点线段树,节点保存种类和球数。易于发现每个点的球与它的子树有关。先把修改操作保存起来,开始处理整棵树。对于每个节点,我们需要判断这个球有没有出现在桶,因为这影响了我们对于种类的统计。于是再开一个数组存每个球最早出现的时间,只有最早出现的球才是有贡献的。

这道题的瓶颈在于答案的统计。可以发现我们算的每个点的答案都是子树答案的叠加,暴力的做法是遍历所有子树,这个复杂度是\(O(N^2)\)的,是30分暴力。

这太暴力力,,,于是就有了树上启发式合并,它没有某谷日报说的那么玄乎,其实和普通的启发式合并差不多。直接说优化方法:我们暴力算出轻儿子的答案,期间都清空辅助数组,只保留答案。重儿子我们就不用清空辅助数组了,算出答案,这样子节点的答案都算出来了,就差父亲的了。我们把轻儿子的答案直接合并到重儿子的答案,就得到了父节点的答案。

放在这道题,每次轻儿子计算后清空记小球出现时间的那个数组,然后算重儿子,之后将轻儿子上的线段树都启发式合并到重儿子的线段树上,父节点直接继承这棵线段树,不断递归这个过程。这个做法复杂度是\(O(Nlog^2N)\)级别的,就可以过了。

真难写,调了好久。

这道题学到的东西蛮多的。首先我之前一直没看动态开点线段树,学前置芝士的时候直接学了。然后我经常口胡启发式合并,但从来没写过,码力不足,这次也真的写了一次。第三,好久没写过这么复杂的题了,属实锻炼码力。

说句闲话,我之前一直WA40,一看Query没return,默认return了堆栈的top。败RP啊。。。。

#include <bits/stdc++.h>

const int N = 1e5 + 233;
int n, k[N], m, q, ecnt, head[N], color[N], disc[N];
int early[N], stk[N * 200], tp, ans[N], root[N];
struct Edge {
    int to, nxt;
} e[N << 1];
int ls[N * 200], rs[N * 200], siz[N * 200], val[N * 200];
std::vector<int> op[N];

inline void add_edge(int f, int to) {
    e[++ecnt] = {to, head[f]}, head[f] = ecnt;
}

int fa[N], sz[N], son[N];

void dfs1(int x, int f) {
    fa[x] = f, sz[x] = 1 + (int) op[x].size();
    for (int i = head[x], y = e[i].to; i; i = e[i].nxt, y = e[i].to) {
        if (y != f) {
            dfs1(y, x);
            sz[x] += sz[y];
            if (sz[y] > sz[son[x]]) son[x] = y;
        }
    }
}

int tot, rbin[N * 200], rbin_top;

int new_node() {
    if (tot + 1 < N * 200) return ++tot;
    else return rbin[rbin_top--];
}

void del_node(int x) {
    rbin[++rbin_top] = x;
}

void pushup(int p) {
    siz[p] = siz[ls[p]] + siz[rs[p]];
    val[p] = val[ls[p]] + val[rs[p]];
}

void change(int &p, int L, int R, int x, int y) {
    if (!p) p = new_node();
    if (L == R) {
        siz[p] = 1;
        if (y != -1) val[p] = y;
        return;
    }
    int mid = (L + R) >> 1;
    if (x <= mid) change(ls[p], L, mid, x, y);
    else change(rs[p], mid + 1, R, x, y);
    pushup(p);
}

void merge(int p, int &root, int L, int R) {
    if (!p) return;
    if (L == R) {
        if (early[color[L]] > L) {
            change(root, 1, m, early[color[L]], 0);
            change(root, 1, m, L, 1);
            early[color[L]] = L;
        } else if (early[color[L]] == 0) {
            change(root, 1, m, L, 1);
            early[color[L]] = L;
            stk[++tp] = color[L];
        } else {
            change(root, 1, m, L, -1);
        }
        del_node(p);
    }
    int mid = (L + R) >> 1;
    merge(ls[p], root, L, mid);
    merge(rs[p], root, mid + 1, R);
}

int query(int p, int L, int R, int bucket) {
    if (!bucket || !p) return 0;
    if (siz[p] <= bucket) return val[p];
    int mid = (L + R) >> 1, ret = 0;
    if (siz[ls[p]] < bucket) {
        ret += val[ls[p]];
        ret += query(rs[p], mid + 1, R, bucket - siz[ls[p]]);
    } else {
        ret += query(ls[p], L, mid, bucket);
    }
    return ret;
}

void clear() {
    while (tp > 0) early[stk[tp--]] = 0;
}

void solve(int x) {
    for (int i = head[x], y = e[i].to; i; i = e[i].nxt, y = e[i].to)
        if (y != son[x] && y != fa[x]) solve(y), clear(); //先解决轻儿子的答案
    if (son[x]) solve(son[x]);
    root[x] = root[son[x]];
    for (unsigned int i = 0; i < op[x].size(); i++) {
        int co = color[op[x][i]];
        if (early[co] == 0) {
            change(root[x], 1, m, op[x][i], 1);
            early[co] = op[x][i];
            stk[++tp] = co; 
        } else if (early[co] > op[x][i]) {
            change(root[x], 1, m, early[co], 0);
            change(root[x], 1, m, op[x][i], 1);
            early[co] = op[x][i];
        } else {
            change(root[x], 1, m, op[x][i], -1);
        }
    }
    for (int i = head[x], y = e[i].to; i; i = e[i].nxt, y = e[i].to) {
        if (y != fa[x] && y != son[x]) {
            merge(root[y], root[x], 1, m);
        }
    }
    ans[x] = query(root[x], 1, m, k[x]);
}

signed main() {
    scanf("%d", &n);
    for (int i = 1, x, y; i <= n - 1; i++)
        scanf("%d%d", &x, &y), add_edge(x, y), add_edge(y, x);
    for (int i = 1; i <= n; i++)
        scanf("%d", k + i);
    scanf("%d", &m);
    for (int i = 1, x, y; i <= m; i++)
        scanf("%d%d", &x, &y), disc[i] = color[i] = y, op[x].push_back(i);
    std::sort(disc + 1, disc + 1 + m);
    int QwQ = std::unique(disc + 1, disc + 1 + m) - (disc + 1);
    for (int i = 1; i <= m; i++)
        color[i] = std::lower_bound(disc + 1, disc + 1 + QwQ, color[i]) - disc;
    scanf("%d", &q);
    dfs1(1, 0);
    solve(1);
    for (int i = 1, x; i <= q; i++)
        scanf("%d", &x), printf("%d\n", ans[x]);
    return 0;
}