BZOJ2631——tree

时间:2022-03-09 12:04:17

1、题目大意:BZOJ1798 的 LCT 版本——在一棵可以断边、连边的树上,支持路径加、路径乘,以及查询路径上点权之和(结果模 51061)。

2、分析:把 BZOJ1798 中维护序列的线段树换成 LCT 里的 splay 即可:每个 splay 节点维护乘法、加法两个懒标记(先乘后加),以及子树和与子树大小。

#include <stack>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
using namespace std;
// Shorthand for the 64-bit integer type used for values, tags and sums.
// A typedef (rather than a macro) is scoped and type-checked.
typedef long long LL;
namespace LinkCutTree{
    // Modulus applied to every stored value, tag and aggregate.
    const long long MOD = 51061;
    // Capacity of the node pool (main uses indices 1..n).
    const int MAXN = 100010;

    // One vertex of the represented forest, and at the same time one node of the
    // splay tree that stores the preferred path containing it.
    struct Node{
        Node *ch[2], *fa;   // splay children; fa is the splay parent or the path-parent
        long long sum, num; // sum: total over this splay subtree; num: this vertex's value
        long long size;     // number of vertices in this splay subtree (kept mod MOD)
        bool rev;           // pending "swap children" tag for this subtree
        long long mul, plu; // pending tag x -> x * mul + plu, not yet pushed to children

        inline int which();
        inline void reverse();
        inline void apply(long long m, long long a);
        inline void pd();
        inline void maintain();

        Node();
    };

    Node* MakeNull();

    // `null` is the shared sentinel meaning "empty subtree".  It is initialised first so
    // every pool node constructed after it (same declaration, so in order) links to it.
    Node *null = MakeNull(), tree[MAXN], *pos[MAXN];

    // A fresh vertex has value 1 and no links.
    Node::Node(){
        num = sum = 1;
        rev = false;
        ch[0] = ch[1] = fa = null;  // still NULL while the sentinel itself is being built
        mul = 1;
        plu = 0;
        size = 1;
    }

    // Build the empty-subtree sentinel: zero value/sum/size, identity tags, no links.
    // Allocated once and deliberately kept for the whole program run.
    Node* MakeNull(){
        Node *s = new Node;
        s->num = s->sum = 0;
        s->size = 0;
        s->rev = false;
        s->mul = 1;
        s->plu = 0;
        s->ch[0] = s->ch[1] = s->fa = NULL;
        return s;
    }

    // -1 if this node is the root of its splay tree, otherwise 0/1 for left/right child.
    inline int Node::which(){
        if(fa == null || (this != fa->ch[0] && this != fa->ch[1])) return -1;
        return this == fa->ch[1];
    }

    // Lazily reverse this subtree.  The sentinel is skipped so it never carries a tag
    // (a `this` null-pointer test would not catch it, and may be folded away anyway).
    inline void Node::reverse(){
        if(this != null) rev = !rev;
    }

    // Lazily apply x -> x * m + a (mod MOD) to every vertex value in this subtree:
    // this node's own value and sum are updated now, the children via the stored tag.
    inline void Node::apply(long long m, long long a){
        if(this == null) return;  // the sentinel must stay an empty subtree
        mul = mul * m % MOD;
        plu = (plu * m + a) % MOD;
        num = (num * m + a) % MOD;
        sum = (sum * m + a * size) % MOD;
    }

    // Push this node's pending tags down to its children.
    inline void Node::pd(){
        if(rev){
            std::swap(ch[0], ch[1]);
            ch[0]->reverse();
            ch[1]->reverse();
            rev = false;
        }
        if(mul != 1 || plu != 0){
            ch[0]->apply(mul, plu);
            ch[1]->apply(mul, plu);
            mul = 1;
            plu = 0;
        }
    }

    // Recompute sum and size from the children (children must already be up to date).
    inline void Node::maintain(){
        sum = (num + ch[0]->sum + ch[1]->sum) % MOD;
        size = (1 + ch[0]->size + ch[1]->size) % MOD;
    }

    // Rotate o above its splay parent, preserving in-order sequence and aggregates.
    inline void rotate(Node *o){
        Node *p = o->fa;
        int l = o->which(), r = l ^ 1;
        int pw = p->which();
        o->fa = p->fa;
        if(pw != -1) p->fa->ch[pw] = o;
        p->ch[l] = o->ch[r];
        if(o->ch[r] != null) o->ch[r]->fa = p;  // never write links into the sentinel
        o->ch[r] = p;
        p->fa = o;
        p->maintain();
        o->maintain();
    }

    // Splay o to the root of its splay tree, pushing all pending tags on the path first.
    inline void splay(Node *o){
        static std::stack<Node*> st;
        if(!o) return;
        Node *p = o;
        st.push(p);
        while(p->which() != -1){
            p = p->fa;
            st.push(p);
        }
        while(!st.empty()){
            st.top()->pd();
            st.pop();
        }
        while(o->which() != -1){
            p = o->fa;
            if(p->which() != -1){
                if(p->which() != o->which()) rotate(o);  // zig-zag
                else rotate(p);                          // zig-zig
            }
            rotate(o);
        }
    }

    // Make the path from the represented root down to o a single preferred path.
    inline void Access(Node *o){
        for(Node *y = null; o != null; y = o, o = o->fa){
            splay(o);
            o->ch[1] = y;
            o->maintain();
        }
    }

    // Make o the root of its represented tree by reversing the root..o path.
    inline void MovetoRoot(Node *o){
        Access(o);
        splay(o);
        o->reverse();
    }

    // Return the root of the represented tree containing o.  Tags are pushed down at
    // every step so a pending reversal cannot send the walk to the wrong path end;
    // the found root is then splayed to keep the amortised bounds.
    inline Node* FindRoot(Node *o){
        Access(o);
        splay(o);
        for(;;){
            o->pd();
            if(o->ch[0] == null) break;
            o = o->ch[0];
        }
        splay(o);
        return o;
    }

    // Add the edge (x, y).  Assumes x and y are in different trees; not checked.
    inline void Link(Node *x, Node *y){
        MovetoRoot(x);
        x->fa = y;
    }

    // Remove the edge (x, y).  Assumes the edge exists; not checked.
    inline void Cut(Node *x, Node *y){
        MovetoRoot(x);
        Access(y);
        splay(y);
        // With x as root and x, y adjacent, x alone forms y's left splay subtree.
        x->fa = null;
        y->ch[0] = null;
        y->maintain();
    }
}
int main(){
    using namespace LinkCutTree;
    null -> mul = 1;
    null -> size = 0;
    null -> plu = 0;
    null -> sum = 0;
    null -> num = 0;
    null -> ch[0] = null -> ch[1] = null -> fa = NULL;
    int n, q;
    scanf("%d%d", &n, &q);
    for(int i = 1; i <= n; i ++) pos[i] = &tree[i];
    for(int i = 1; i < n; i ++){
        int u, v;
        scanf("%d%d", &u, &v);
        Link(pos[u], pos[v]);
    }
    char op[5];
    int x1, y1, x2, y2, c;
    while(q --){
        scanf("%s", op);
        if(op[0] == '+'){
            scanf("%d%d%d", &x1, &y1, &c);
            MovetoRoot(pos[x1]);
            Access(pos[y1]);
            splay(pos[y1]);
            pos[y1] -> num += c;
            pos[y1] -> num %= 51061;
            pos[y1] -> sum += pos[y1] -> size * c;
            pos[y1] -> sum %= 51061;
            pos[y1] -> plu += c;
            pos[y1] -> plu %= 51061;
        }
        else if(op[0] == '-'){
            scanf("%d%d%d%d", &x1, &y1, &x2, &y2);
            Cut(pos[x1], pos[y1]);
            Link(pos[x2], pos[y2]);
        }
        else if(op[0] == '*'){
            scanf("%d%d%d", &x1, &y1, &c);
            MovetoRoot(pos[x1]);
            Access(pos[y1]);
            splay(pos[y1]);
            pos[y1] -> num *= c;
            pos[y1] -> num %= 51061;
            pos[y1] -> sum *= c;
            pos[y1] -> sum %= 51061;
            pos[y1] -> mul *= c;
            pos[y1] -> mul %= 51061;
            pos[y1] -> plu *= c;
            pos[y1] -> plu %= 51061;
        }
        else{
            scanf("%d%d", &x1, &y1);
            MovetoRoot(pos[x1]);
            Access(pos[y1]);
            splay(pos[y1]);
            pos[y1] -> sum %= 51061;
            printf("%lld\n", pos[y1] -> sum);
        }
    }
    return 0;
}