SPOJ QTREE6 Query on a tree VI 树链剖分

时间:2021-03-16 12:35:17

题意:

给出一棵含有\(n(1 \leq n \leq 10^5)\)个节点的树,每个顶点只有两种颜色:黑色和白色。
一开始所有的点都是黑色,下面有两种共\(m(1 \leq n \leq 10^5)\)次操作:

  • \(0 \, u\)表示查询\(u\)所在的连通块的大小,相邻两个点颜色相同则属于一个连通块。
  • \(0 \, u\)表示翻转\(u\)的颜色,即黑点变白点,白点变黑点。

分析:

参考CodeChef上的题解

首先将这棵树剖分成轻重链。
然后我们维护两个值:\(White(u)\)\(Black(u)\)
\(White(u)\)表示当\(u\)是白点时(这里我们不关心\(u\)真正的颜色),以\(u\)为根的子树中,\(u\)所在的连通块的大小。
同理,\(Black(u)\)表示\(u\)是黑点时(这里我们不关心\(u\)真正的颜色),以\(u\)为根的子树中,\(u\)所在的连通块的大小。

  • 对于查询操作\(0 \, \, u\)
    \(u\)往上走,走到深度最小的与\(u\)同色的节点\(v\),那么答案就是\(White(v)\)\(Black(v)\)

  • 对于修改操作\(1, \, \, u\)
    由于对称性,不妨假设\(u\)从白点变为黑点。
    \(u\)的父节点往上走,走到第一个黑点\(v_1\),设路径\(path_1\)\(fa(u) \to v\),然后将\(path_1\)上所有点的\(White\)减去\(White(u)\)
    同样地,从\(u\)的父节点往上走,走到第一个白点\(v_2\),设路径\(path_2\)\(fa(u) \to v_2\),然后将\(path_2\)上所有点的\(Black\)加上\(Black(u)\)
    这里修改操作是成段更新的,所以要用线段树维护一下。

接下来还要解决一个问题:如何快速找到上面说的深度最浅的同色点和遇到的第一个黑/白点
继续用线段树维护一个\(fir0\)\(fir1\),表示该区间从右往左遇到的第一个白点和黑点,对应到树上的链就是从下往上。
这样就解决了第二个问题,其实第一个问题的答案也可以间接得到。
要找深度最前的同色点,就是第一个异色点的那个子节点。
父节点只有一个,但子节点又如何确定呢?
注意到,我们查询时是在剖出来的链上一条一条往上“跳”的,所以在同一条链上一个点的子节点就在线段树中它位置的右边相邻的那个点。

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

const int maxn = 100000 + 10;
const int maxnode = maxn * 4;

struct Edge
{
    int v, nxt;
    Edge() {}
    Edge(int v, int nxt): v(v), nxt(nxt) {}
};

int ecnt, head[maxn];
Edge edges[maxn * 2];

void AddEdge(int u, int v) {
    edges[ecnt] = Edge(v, head[u]);
    head[u] = ecnt++;
}

int n;

int fa[maxn], sz[maxn], son[maxn], dep[maxn];

void dfs(int u) {
    sz[u] = 1; son[u] = 0;
    for(int i = head[u]; ~i; i = edges[i].nxt) {
        int v = edges[i].v;
        if(v == fa[u]) continue;
        fa[v] = u;
        dep[v] = dep[u] + 1;
        dfs(v);
        sz[u] += sz[v];
        if(sz[v] > sz[son[u]]) son[u] = v;
    }
}

int tot, top[maxn], id[maxn], pos[maxn];

void dfs2(int u, int tp) {
    top[u] = tp;
    id[u] = ++tot;
    pos[tot] = u;
    if(!son[u]) return;
    dfs2(son[u], tp);
    for(int i = head[u]; ~i; i = edges[i].nxt) {
        int v = edges[i].v;
        if(v == fa[u] || v == son[u]) continue;
        dfs2(v, v);
    }
}

int color[maxn], addv[2][maxnode];
int fir[2][maxnode];

void build(int o, int L, int R) {
    if(L == R) {
        fir[0][o] = 0; fir[1][o] = L;
        addv[0][o] = 1; addv[1][o] = sz[pos[L]];
        return;
    }
    int M = (L + R) / 2;
    build(o<<1, L, M);
    build(o<<1|1, M+1, R);

    fir[0][o] = 0;
    fir[1][o] = R;
}

void pushdown(int o) {
    for(int i = 0; i < 2; i++) {
        int& t = addv[i][o];
        if(!t) continue;
        addv[i][o<<1] += t;
        addv[i][o<<1|1] += t;
        t = 0;
    }
}

void update(int o, int L, int R, int qL, int qR, int col, int v) {
    if(qL <= L && R <= qR) {
        addv[col][o] += v;
        return;
    }
    pushdown(o);
    int M = (L + R) / 2;
    if(qL <= M) update(o<<1, L, M, qL, qR, col, v);
    if(qR > M) update(o<<1|1, M+1, R, qL, qR, col, v);
}

void UPDATE(int u, int v, int col, int val) {
    while(top[u] != top[v]) {
        update(1, 1, n, id[top[u]], id[u], col, val);
        u = fa[top[u]];
    }
    update(1, 1, n, id[v], id[u], col, val);
}

int querysize(int o, int L, int R, int p, int col) {
    if(L == R) return addv[col][o];
    pushdown(o);
    int M = (L + R) / 2;
    if(p <= M) return querysize(o<<1, L, M, p, col);
    else return querysize(o<<1|1, M+1, R, p, col);
}

int queryfir(int o, int L, int R, int qL, int qR, int col) {
    if(qL <= L && R <= qR) return fir[col][o];
    int ans = 0;
    int M = (L + R) / 2;
    if(qR > M) ans = queryfir(o<<1|1, M+1, R, qL, qR, col);
    if(ans) return ans;
    if(qL <= M) ans = queryfir(o<<1, L, M, qL, qR, col);
    return ans;
}

int QueryFir(int u, int col) {
    int ans = 0;
    int t = top[u];
    while(t != 1) {
        ans = queryfir(1, 1, n, id[t], id[u], col);
        if(ans) return ans;
        u = fa[t]; t = top[u];
    }
    return queryfir(1, 1, n, 1, id[u], col);

}

int QuerySuf(int u, int col) {
    int ans = id[u];
    while(top[u] != 1) {
        int t = queryfir(1, 1, n, id[top[u]], id[u], col ^ 1);
        if(t) return t == id[u] ? ans : t + 1;
        ans = id[top[u]];
        u = fa[top[u]];
    }
    int t = queryfir(1, 1, n, 1, id[u], col ^ 1);
    if(!t) return 1;
    return t == id[u] ? ans : t + 1;
}

void change(int o, int L, int R, int p) {
    if(L == R) {
        int u = pos[L];
        int& c = color[u];
        c ^= 1;
        fir[c][o] = L;
        fir[c ^ 1][o] = 0;
        return;
    }
    int M = (L + R) / 2, lenr = R - M;
    if(p <= M) change(o<<1, L, M, p);
    else change(o<<1|1, M+1, R, p);
    fir[0][o] = fir[0][o<<1|1] ? fir[0][o<<1|1] : fir[0][o<<1];
    fir[1][o] = fir[1][o<<1|1] ? fir[1][o<<1|1] : fir[1][o<<1];
}

int main()
{
    scanf("%d", &n);
    for(int i = 1; i <= n; i++) color[i] = 1;

    ecnt = 0;
    memset(head, -1, sizeof(head));
    for(int i = 1; i < n; i++) {
        int u, v; scanf("%d%d", &u, &v);
        AddEdge(u, v);
        AddEdge(v, u);
    }

    dfs(1);
    tot = 0;
    dfs2(1, 1);

    build(1, 1, n);
    int _; scanf("%d", &_);
    while(_--) {
        int op, u; scanf("%d%d", &op, &u);
        if(op == 0) {
            int v = pos[QuerySuf(u, color[u])];
            int ans = querysize(1, 1, n, id[v], color[v]);
            printf("%d\n", ans);
        } else {
            if(u != 1) {
                int v = pos[QueryFir(fa[u], color[u] ^ 1)];
                if(!v) v = 1;
                int sub = querysize(1, 1, n, id[u], color[u]);
                UPDATE(fa[u], v, color[u], -sub);

                v = pos[QueryFir(fa[u], color[u])];
                if(!v) v = 1;
                int add = querysize(1, 1, n, id[u], color[u] ^ 1);
                UPDATE(fa[u], v, color[u] ^ 1, add);
            }

            change(1, 1, n, id[u]);
        }
    }

    return 0;
}