理解线段树这一篇文章就够啦!

时间:2021-11-20 00:39:37

TODO:

前言

本文中,若无特殊说明,数列下标均从 \(1\) 开始

由于本人实力有限,线段树更高级的拓展暂不做考虑

引入

什么是线段树

线段树\(Segment\ Tree\))是一种二叉搜索树,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶子节点,由于每一个节点都表示一个区间(或者说是线段),所以也被认为是一颗区间树。

用途

线段树常用于动态维护区间信息

例题

P3374 【模板】树状数组 1 - 洛谷

题目简述:对数列进行单点修改以及区间求和

常规解法

单点修改的时间复杂度为 \(O(1)\)

区间求和的时间复杂度为 \(O(n)\)

\(m\) 次操作,则总时间复杂度为 \(O(n\times m)\)

点击查看代码
import java.io.*;

public class Main {
    static StreamTokenizer in = new StreamTokenizer(new BufferedReader(new InputStreamReader(System.in)));

    static int get() throws IOException {
        in.nextToken();
        return (int) in.nval;
    }

    public static void main(String[] args) throws IOException {
        PrintWriter out = new PrintWriter(System.out);
        int n = get(), m = get();
        int[] a = new int[n];
        for (int i = 0; i < n; ++i) a[i] = get();
        while (m-- != 0) {
            int command = get(), x = get(), y = get();
            if (command == 1) {
                a[x - 1] += y;
            } else {
                int sum = 0;
                for (int i = x - 1; i < y; i++) sum += a[i];
                out.println(sum);
            }
        }
        out.close();
    }
}

前缀和解法

区间求和通过前缀和优化,但单点修改的时候需要修改前缀和数组

单点修改的时间复杂度为 \(O(n)\)

区间求和的时间复杂度为 \(O(1)\)

\(m\) 次操作,则总时间复杂度为 \(O(n\times m)\)

点击查看代码
import java.io.*;

public class Main {
    static StreamTokenizer in = new StreamTokenizer(new BufferedReader(new InputStreamReader(System.in)));

    static int get() throws IOException {
        in.nextToken();
        return (int) in.nval;
    }

    public static void main(String[] args) throws IOException {
        PrintWriter out = new PrintWriter(System.out);
        int n = get(), m = get();
        int[] sum = new int[n + 1];
        for (int i = 1; i <= n; ++i) sum[i] = sum[i - 1] + get();
        while (m-- != 0) {
            int command = get(), x = get(), y = get();
            if (command == 1) {
                for (int i = x; i <= n; ++i) sum[i] += y;
            } else {
                System.out.println(sum[y] - sum[x - 1]);
            }
        }
        out.close();
    }
}

线段树解法

线段树的思想

线段树是一种基于分治思想的一种数据结构,它通过不断将区间拆分合并来实现区间改查

线段树的形态结构

我们规定:若当前节点的标号为 \(x\),则其左儿子标号为 \(2\times x\),右儿子标号为 \(2\times x+1\),叶子节点的管辖区间长度为 \(1\)

一颗管理数组长度为 \(7\) 的线段树基本结构如下,其中蓝色圆中数据代表节点标号,绿色矩形内数据代表该节点的管辖区间。

理解线段树这一篇文章就够啦!

线段树的存储

对于二叉树的存储,通常选择使用指针存储,但在算法竞赛中,常选择堆式存储(静态数组存储)。

选择堆式存储时,我们需要确定数组空间大小。

一颗管理数组长度为 \(n\) 的线段树的节点个数为 \(2\times n -1\)

证明如下:

设一颗线段树的度数为 \(0\) 的节点个数为 \(N_0\),度数为 \(1\) 的节点个数为 \(N_1\),度数为 \(2\) 的节点个数为 \(N_2\)

由线段树的定义,叶子节点的管辖区间长度为 \(1\),则叶子节点的个数为 \(n\),即 \(N_0=n\)

每个节点代表一个区间,如果一个区间能划分,则一定划分为 \(2\) 个区间,因此 \(N_1=0\)

二叉树的性质:\(N_2=N_0-1\)

因此,一颗管理数组长度为 \(n\) 的线段树的节点个数为 \(N=N_0+N_1+N_2=2\times n -1\)

那是否静态数组空间就只需要 \(2\times n-1\) 呢?

线段树的形态结构中的图表示并非如此。

因为有些节点是空的,所以最后一个节点标号一定与满二叉树相同。

前置结论:对于高度为 \(h\) 的满二叉树,最后一层有 \(2^{h-1}\) 个节点,总共有 \(2^h-1\) 个节点,则除最后一层外的节点总数有 \(2^{h}-1-2^{h-1}=2^{h-1}-1\),与最后一层节点个数对比,得:$ 除最后一层外的节点总数 = 最后一层的节点个数 -1 $

理解线段树这一篇文章就够啦!

线段树所需要的节点数量,分两种情况来讨论:

  • 如果 \(n\) 恰好是 \(2\)\(k\) 次幂,由于线段树最后一层的叶子节点存储的是数组元素本身,最后一层的节点数就是 \(n\),则前面所有层的节点数为 \(n-1\),那么总节点数为 \(2\times n -1\)

  • 如果 \(n\) 不是 \(2\)\(k\) 次幂,即 \(n=2^k+x\) 其中 \(x>0\),则需要新开辟一层来存储,等同于 \(2^{k+1}\) 的情况,则总结点个数为 \(4n-4x-1\),最大不超过 \(4n-5\)

理解线段树这一篇文章就够啦!

又由于我们让数据从下标为 \(1\) 开始存储,得出如下结论:

  • \(n\)\(2\) 的正整数幂时,所需空间大小为 \(2\times n\)
  • \(n\) 不是 \(2\) 的正整数幂时,所需空间大小为 \(4\times n-4\)

为了方便,我们通常选择开辟 \(4\times n\) 的空间

建树

以下 \(tree[i]\) 代表 \(i\) 号节点所存储的数据,\(data[i]\) 代表原数组数据

每个节点 \(p\) 的左右子节点的编号分别为 \(2p\)\(2p+1\)

要求得某一个节点的值,需要得到两个子节点的值,再将其合并,采用递归的形式建树,其中合并操作单独记为一个函数(修改操作时用)

递归的终止条件为达到叶子节点,即节点管辖区间长度为 \(1\),不能再划分了,此时 \(l=r\)

所需函数参数如下:

  1. 当前节点的编号,即 \(tree\) 数组中的索引 \(o\)
  2. 该节点所管辖区间的左边界 \(l\)
  3. 该节点所管辖区间的右边界 \(r\)
// 合并x和y两个节点的区间值,并赋给o节点
public void pushUp(int o, int x, int y) {
    tree[o] = tree[x] + tree[y];
}

/**
 * @param o     当前节点编号
 * @param l     当前节点管辖区间的左边界
 * @param r     当前节点管辖区间的右边界
 * @param data  原数组数据
 */
public void build(int o, int l, int r, int[] data) {
    // 到达叶子节点(管辖区间长度为1)
    if (l == r) {
        tree[o] = data[l];
        return;
    }
    // mid为中间值,用于划分区间
    // x 为左儿子编号
    // y 为右儿子编号
    int mid = l + r >> 1, x = o << 1, y = x | 1;
    // 构建左子树,区间为[l,mid]
    build(x, l, mid, data);
    // 构建右子树,区间为[mid+1,r]
    build(y, mid + 1, r, data);
    // 合并两个子区间的数据
    pushUp(o, x, y);
}

单点修改

修改元素时,需要先找到待修改的最底层的数据(叶子节点),修改后再逐步上传数据

单点修改的基本步骤如下:

  1. 若待修改元素位于 \([l,mid]\) 区间,则递归修改左子树部分

    若待修改元素位于 \([mid+1,r]\) 区间,则递归修改右子树部分

  2. 合并两个子区间的数据

所需函数参数如下:

  1. 待修改元素位置 \(index\)
  2. 修改后元素(或增量)的数据 \(val\)
  3. 当前节点的编号,即 \(tree\) 数组中的索引 \(o\)
  4. 该节点所管辖区间的左边界 \(l\)
  5. 该节点所管辖区间的右边界 \(r\)
/**
 * @param index 待修改元素位置
 * @param val   修改后元素(或增量)的数据
 * @param o     当前节点编号
 * @param l     当前节点管辖区间的左边界
 * @param r     当前节点管辖区间的右边界
 */
public void updateOne(final int index, final int val, int o, int l, int r) {
    // 到达叶子节点(管辖区间长度为1)
    if (l == r) {
        tree[o] += val;
        return;
    }
    final int mid = l + r >> 1, x = o << 1, y = x | 1;
    // 修改元素在左边区间
    if (index <= mid) updateOne(index, val, x, l, mid);
    // 修改元素在右边区间
    else updateOne(index, val, y, mid + 1, r);
    // 合并两个子区间的数据
    pushUp(o, x, y);
}

单点查询

与单点修改相同,只是不需要进行子区间数据合并了(因为没有变)

/**
 * @param index 待查找元素位置
 * @param o     当前节点编号
 * @param l     当前节点管辖区间的左边界
 * @param r     当前节点管辖区间的右边界
 * @return index位置处的值
 */
public int queryOne(final int index, int o, int l, int r) {
    if (l == r) return tree[o];
    final int mid = l + r >> 1, x = o << 1, y = x | 1;
    // 如果查询元素在左边区间
    if (index <= mid) return queryOne(index, x, l, mid);
    // 否则在右边区间
    return queryOne(index, y, mid + 1, r);
}

区间查询

求某一个区间的值,对于线段树就是分解线段树区间,直至该区间在查询区间内部,此时该区间的值已经获得,不需要再分解了

区间查询的分解区间步骤如下:

  • 如果左子树包含查询区间,即 \(queryLeft\le mid\),则查询左子树
  • 如果右子树包含查询区间,即 \(queryRight>mid\),则查询右子树

所需函数参数如下:

  1. 待查找区间左边界 \(left\)
  2. 待查找区间左边界 \(right\)
  3. 当前节点的编号,即 \(tree\) 数组中的索引 \(o\)
  4. 该节点所管辖区间的左边界 \(l\)
  5. 该节点所管辖区间的右边界 \(r\)
/**
 * @param left  待查找区间左边界
 * @param right 待查找区间右边界
 * @param o     当前节点编号
 * @param l     当前节点管辖区间的左边界
 * @param r     当前节点管辖区间的右边界
 * @return 区间[left, right]的值
 */
public int queryRange(final int left, final int right, int o, int l, int r) {
    // 如果线段树区间在查询区间内部,这一区间已经为答案了,不需要再分解了
    if (left <= l && r <= right) return tree[o];
    final int mid = l + r >> 1, x = o << 1, y = x | 1;
    int ans = 0;
    // 如果左子树包含查询区间
    if (left <= mid) ans += queryRange(left, right, x, l, mid);
    // 如果右子树包含查询区间
    if (right > mid) ans += queryRange(left, right, y, mid + 1, r);
    return ans;
}

复杂度分析

空间复杂度为 \(O(4n)=O(n)\)

单点修改、单点查询、区间查询操作的时间复杂度均为 \(O(\log n)\)

建树的时间复杂度为树的节点个数 \(O(2\times n- 1)=O(n)\)

Code

点击查看代码
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;

public class Main {
    static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    static int _num, _sign, _char;

    static int read() throws IOException {
        for (_sign = 1, _char = br.read(); _char > '9' || _char < '0'; _char = br.read()) if (_char == '-') _sign = -1;
        for (_num = 0; '0' <= _char && _char <= '9'; _char = br.read()) _num = _num * 10 + _char - '0';
        return _num * _sign;
    }

    public static void main(String[] args) throws IOException {
        PrintWriter out = new PrintWriter(System.out);
        int n = read(), m = read();
        int[] a = new int[n + 1];
        for (int i = 1; i <= n; ++i) a[i] = read();
        SegmentTree seg = new SegmentTree(a, n);
        while (m-- != 0) {
            int command = read();
            if (command == 1) {
                int x = read(), k = read();
                seg.updateOne(x, k, 1, 1, n);
            } else {
                int x = read(), y = read();
                out.println(seg.queryRange(x, y, 1, 1, n));
            }
        }
        out.close();
    }
}

class SegmentTree {
    int[] tree;
    int n;

    private void pushUp(int o, int x, int y) {
        tree[o] = tree[x] + tree[y];
    }

    public void build(int o, int l, int r, int[] val) {
        if (l == r) {
            tree[o] = val[l];
            return;
        }
        int mid = l + r >> 1, x = o << 1, y = x | 1;
        build(x, l, mid, val);
        build(y, mid + 1, r, val);
        pushUp(o, x, y);
    }

    public SegmentTree(int _n) {
        n = _n;
        tree = new int[n << 2];
    }

    // 请保证数组数据下标从1开始
    public SegmentTree(int[] val, int _n) {
        // assert(val.length >= _n);
        this(_n);
        build(1, 1, n, val);
    }

    public void updateOne(final int index, final int val, int o, int l, int r) {
        if (l == r) {
            tree[o] += val;
            return;
        }
        final int mid = l + r >> 1, x = o << 1, y = x | 1;
        if (index <= mid) updateOne(index, val, x, l, mid);
        else updateOne(index, val, y, mid + 1, r);
        pushUp(o, x, y);
    }

    public int queryRange(final int left, final int right, int o, int l, int r) {
        if (left <= l && r <= right) return tree[o];
        final int mid = l + r >> 1, x = o << 1, y = x | 1;
        int ans = 0;
        if (left <= mid) ans += queryRange(left, right, x, l, mid);
        if (right > mid) ans += queryRange(left, right, y, mid + 1, r);
        return ans;
    }

    public int queryOne(final int index, int o, int l, int r) {
        if (l == r) return tree[o];
        final int mid = l + r >> 1, x = o << 1, y = x | 1;
        if (index <= mid) return queryOne(index, x, l, mid);
        return queryOne(index, y, mid + 1, r);
    }
}

进阶

区间修改+单点查询

P3368 【模板】树状数组 2 - 洛谷

树状数组相同,可以使用差分的方式,将区间修改变为两次单点修改,本文对于该方法暂不做讨论

区间修改的分解区间步骤与区间查询类似:

  • 如果左子树包含修改区间,即 \(queryLeft\le mid\),则修改左子树
  • 如果右子树包含修改区间,即 \(queryRight>mid\),则修改右子树

在上述操作之后,合并两个子区间的数据

所需函数参数如下:

  1. 待修改区间左边界 \(left\)
  2. 待修改区间左边界 \(right\)
  3. 修改后元素(或增量)的数据 \(val\)
  4. 当前节点的编号,即 \(tree\) 数组中的索引 \(o\)
  5. 该节点所管辖区间的左边界 \(l\)
  6. 该节点所管辖区间的右边界 \(r\)
/**
 * @param left  待修改区间左边界
 * @param right 待修改区间右边界
 * @param val   修改后元素(或增量)的数据
 * @param o     当前节点编号
 * @param l     当前节点管辖区间的左边界
 * @param r     当前节点管辖区间的右边界
 */
public void updateRange(final int left, final int right, final int val, int o, int l, int r) {
    // 到达叶子节点(管辖区间长度为1)
    if (l == r) {
        tree[o] += val;
        return;
    }
    final int mid = l + r >> 1, x = o << 1, y = x | 1;
    // 如果查询区间全在左边
    if (left <= mid) updateRange(left, right, val, x, l, mid);
    // 如果查询区间全在右边
    if (right > mid) updateRange(left, right, val, y, mid + 1, r);
    // 合并两个子区间的数据
    pushUp(o, x, y);
}

可以发现区间修改时的时间复杂度很高,因为需要对 \([left,right]\) 区间内的每一个叶子都修改,时间复杂度与修改路径上的节点个数有关,最坏时间复杂度为 \(O(2\times n-1)=O(n)\)

懒惰标记

为了降低区间修改的时间复杂度,让区间修改的形式与区间查询的形式相同(即直接修改区间,不修改单个的值),每个节点上多携带懒惰标记这个信息

原理:不用的话我就不修改,只在用的时候(查询)修改

标记:本区间已经被更新过了,但是子区间却没有被更新过,被更新的信息是什么。

当路过这个节点时,加上这个标记的值

public void updateRange(final int left, final int right, final int val, int o, int l, int r) {
    if (left <= l && r <= right) {
        tree[o] += val * (r - l + 1);
        // 携带上val这个信息,表明该子树均未修改
        tag[o] += val;
        return;
    }
    final int mid = l + r >> 1, x = o << 1, y = x | 1;
    if (left <= mid) updateRange(left, right, val, x, l, mid);
    if (right > mid) updateRange(left, right, val, y, mid + 1, r);
    pushUp(o, x, y);
}
public int queryOne(final int index, int o, int l, int r) {
    if (l == r) return tree[o];
    final int mid = l + r >> 1, x = o << 1, y = x | 1;
    // 加上路径上的懒惰标记值
    if (index <= mid) return tag[o] + queryOne(index, x, l, mid);
    return tag[o] + queryOne(index, y, mid + 1, r);
}

Code

点击查看代码
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;

public class Main {
    static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    static int _num, _sign, _char;

    static int read() throws IOException {
        for (_sign = 1, _char = br.read(); _char > '9' || _char < '0'; _char = br.read()) if (_char == '-') _sign = -1;
        for (_num = 0; '0' <= _char && _char <= '9'; _char = br.read()) _num = _num * 10 + _char - '0';
        return _num * _sign;
    }

    public static void main(String[] args) throws IOException {
        PrintWriter out = new PrintWriter(System.out);
        int n = read(), m = read();
        int[] a = new int[n + 1];
        for (int i = 1; i <= n; ++i) a[i] = read();
        SegmentTree seg = new SegmentTree(a, n);
        while (m-- != 0) {
            int command = read();
            if (command == 1) {
                int x = read(), y = read(), k = read();
                seg.updateRange(x, y, k, 1, 1, n);
            } else {
                int x = read();
                out.println(seg.queryOne(x, 1, 1, n));
            }
        }
        out.close();
    }
}

class SegmentTree {
    int[] tree, tag;
    int n;

    private void pushUp(int o, int x, int y) {
        tree[o] = tree[x] + tree[y];
    }

    public void build(int o, int l, int r, int[] val) {
        if (l == r) {
            tree[o] = val[l];
            return;
        }
        int mid = l + r >> 1, x = o << 1, y = x | 1;
        build(x, l, mid, val);
        build(y, mid + 1, r, val);
        pushUp(o, x, y);
    }

    public SegmentTree(int _n) {
        n = _n;
        tree = new int[n << 2];
        tag = new int[n << 2];
    }

    // 请保证数组数据下标从1开始
    public SegmentTree(int[] val, int _n) {
        // assert(val.length >= _n);
        this(_n);
        build(1, 1, n, val);
    }

    public void updateRange(final int left, final int right, final int val, int o, int l, int r) {
        if (left <= l && r <= right) {
            tree[o] += val * (r - l + 1);
            tag[o] += val;
            return;
        }
        final int mid = l + r >> 1, x = o << 1, y = x | 1;
        if (left <= mid) updateRange(left, right, val, x, l, mid);
        if (right > mid) updateRange(left, right, val, y, mid + 1, r);
        pushUp(o, x, y);
    }

    public int queryOne(final int index, int o, int l, int r) {
        if (l == r) return tree[o];
        final int mid = l + r >> 1, x = o << 1, y = x | 1;
        if (index <= mid) return tag[o] + queryOne(index, x, l, mid);
        return tag[o] + queryOne(index, y, mid + 1, r);
    }
}

区间修改+区间查询

P3372 【模板】线段树 1 - 洛谷

当进行区间修改及区间查询或多种复杂操作时,可能会觉得直接套用例题中的区间查询和进阶中的区间修改就行。

但事实上不是这样的,一个错误代码如下:

private void pushUp(int o, int x, int y) {
    tree[o] = tree[x] + tree[y];
}
public void updateRange(final int left, final int right, final int val, int o, int l, int r) {
    if (left <= l && r <= right) {
        tree[o] += val * (r - l + 1);
        tag[o] += val;
        return;
    }
    final int mid = l + r >> 1, x = o << 1, y = x | 1;
    if (left <= mid) updateRange(left, right, val, x, l, mid);
    if (right > mid) updateRange(left, right, val, y, mid + 1, r);
    pushUp(o, x, y);
}
public int queryRange(final int left, final int right, int o, int l, int r) {
    if (left <= l && r <= right) return tree[o];
    final int mid = l + r >> 1, x = o << 1, y = x | 1;
    // 算上 该区间 的tag标记
    int ans = tag[o] * (Math.min(r, right) - Math.max(l, left) + 1);
    if (left <= mid) ans += queryRange(left, right, x, l, mid);
    if (right > mid) ans += queryRange(left, right, y, mid + 1, r);
    return ans;
}

\(\{0,0,0,0\}\) 的数组举例,初始线段树图示如下:

理解线段树这一篇文章就够啦!

\(1\sim4\) 区间 \(+1\),得:

理解线段树这一篇文章就够啦!

\(1\sim2\) 区间 \(+1\) 得:

理解线段树这一篇文章就够啦!

此时发现不对了,\(1\) 号节点的数据被修改了,结果不正确

原因如下:

  • 第一次修改时,对 \(1\sim4\) 的区间 \(+1\) 并没有传到子节点,子节点的值没有发生改变

  • 第二次修改时,对 \(1\sim2\) 的区间 \(+1\)后,调用 \(pushUp\) 数据上传,\(1\) 号节点的数据就不正确了

下面有两种方式解决该问题

标记永久化

标记永久化:在修改时修改路径上被影响的节点,在询问时累加路径上的标记

区间修改:将路径上的影响计算到线段树的 \(data\)

区间查询:累加查询路径上的 \(tag\)(有效区间内的)

public void updateRange(final int left, final int right, final int val, int o, int l, int r) {
    if (left <= l && r <= right) {
        tree[o] += val * (r - l + 1);
        tag[o] += val;
        return;
    }
    // 将后续修改的影响计算到当前节点中
    tree[o] += val * (Math.min(r, right) - Math.max(l, left) + 1);
    final int mid = l + r >> 1, x = o << 1, y = x | 1;
    if (left <= mid) updateRange(left, right, val, x, l, mid);
    if (right > mid) updateRange(left, right, val, y, mid + 1, r);
    // 不用pushUp操作
}
public long queryRange(final int left, final int right, int o, int l, int r) {
    if (left <= l && r <= right) return tree[o];
    final int mid = l + r >> 1, x = o << 1, y = x | 1;
    // 累加路径上的tag标记(有效区间内的)
    long ans = tag[o] * (Math.min(r, right) - Math.max(l, left) + 1);
    if (left <= mid) ans += queryRange(left, right, x, l, mid);
    if (right > mid) ans += queryRange(left, right, y, mid + 1, r);
    return ans;
}

Code

点击查看代码
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;

public class Main {
    static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    static int _num, _sign, _char;

    static int read() throws IOException {
        for (_sign = 1, _char = br.read(); _char > '9' || _char < '0'; _char = br.read()) if (_char == '-') _sign = -1;
        for (_num = 0; '0' <= _char && _char <= '9'; _char = br.read()) _num = _num * 10 + _char - '0';
        return _num * _sign;
    }

    public static void main(String[] args) throws IOException {
        PrintWriter out = new PrintWriter(System.out);
        int n = read(), m = read();
        int[] a = new int[n + 1];
        for (int i = 1; i <= n; ++i) a[i] = read();
        SegmentTree seg = new SegmentTree(a, n);
        while (m-- != 0) {
            int command = read();
            if (command == 1) {
                int x = read(), y = read(), k = read();
                seg.updateRange(x, y, k, 1, 1, n);
            } else {
                int x = read(), y = read();
                out.println(seg.queryRange(x, y, 1, 1, n));
            }
        }
        out.close();
    }
}

class SegmentTree {
    long[] tree, tag;
    int n;

    public void build(int o, int l, int r, int[] val) {
        if (l == r) {
            tree[o] = val[l];
            return;
        }
        int mid = l + r >> 1, x = o << 1, y = x | 1;
        build(x, l, mid, val);
        build(y, mid + 1, r, val);
        tree[o] = tree[x] + tree[y];
    }

    public SegmentTree(int _n) {
        n = _n;
        tree = new long[n << 2];
        tag = new long[n << 2];
    }

    // 请保证数组数据下标从1开始
    public SegmentTree(int[] val, int _n) {
        // assert(val.length >= _n);
        this(_n);
        build(1, 1, n, val);
    }

    public void updateRange(final int left, final int right, final int val, int o, int l, int r) {
        if (left <= l && r <= right) {
            tree[o] += val * (r - l + 1);
            tag[o] += val;
            return;
        }
        // 将后续修改的影响计算到当前节点中
        tree[o] += val * (Math.min(r, right) - Math.max(l, left) + 1);
        final int mid = l + r >> 1, x = o << 1, y = x | 1;
        if (left <= mid) updateRange(left, right, val, x, l, mid);
        if (right > mid) updateRange(left, right, val, y, mid + 1, r);
        // 不用pushUp操作
    }

    public long queryRange(final int left, final int right, int o, int l, int r) {
        if (left <= l && r <= right) return tree[o];
        final int mid = l + r >> 1, x = o << 1, y = x | 1;
        // 累加路径上的tag标记(有效区间内的)
        long ans = tag[o] * (Math.min(r, right) - Math.max(l, left) + 1);
        if (left <= mid) ans += queryRange(left, right, x, l, mid);
        if (right > mid) ans += queryRange(left, right, y, mid + 1, r);
        return ans;
    }
}

标记下传

标记下传:把一个节点的懒惰标记传给它的左右儿子,再把该节点的懒惰标记删去

当执行到某一节点时,先下传当前节点的标记,再查询或更新,最后 \(pushUp\) 的就是正确结果

// 标记下传(若下方标记与区间边界无关,则不需要l,r参数)
private void pushDown(int o, int x, int y, int l, int r) {
    // 空标记直接退出
    if (tag[o] == 0) return;
    int mid = l + r >> 1;
    // 下传给左节点
    tag[x] += tag[o];
    tree[x] += tag[o] * (mid - l + 1);
    // 下传给右节点
    tag[y] += tag[o];
    tree[y] += tag[o] * (r - (mid + 1) + 1);
    // 清空当前节点标记
    tag[o] = 0;
}
public void updateRange(final int left, final int right, final int val, int o, int l, int r) {
    if (left <= l && r <= right) {
        tree[o] += val * (r - l + 1);
        tag[o] += val;
        return;
    }
    final int mid = l + r >> 1, x = o << 1, y = x | 1;
    // 将后续修改的影响计算到当前节点中
    // 下放标记
    pushDown(o, x, y, l, r);
    if (left <= mid) updateRange(left, right, val, x, l, mid);
    if (right > mid) updateRange(left, right, val, y, mid + 1, r);
    // 上传子区间数据
    pushUp(o, x, y);
}
public long queryRange(final int left, final int right, int o, int l, int r) {
    if (left <= l && r <= right) return tree[o];
    final int mid = l + r >> 1, x = o << 1, y = x | 1;
    // 将后续修改的影响计算到当前节点中
    // 下放标记
    pushDown(o, x, y, l, r);
    long ans = 0;
    if (left <= mid) ans += queryRange(left, right, x, l, mid);
    if (right > mid) ans += queryRange(left, right, y, mid + 1, r);
    // 上传子区间数据
    pushUp(o, x, y);
    return ans;
}

Code

点击查看代码
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;

public class Main {
    static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    static int _num, _sign, _char;

    static int read() throws IOException {
        for (_sign = 1, _char = br.read(); _char > '9' || _char < '0'; _char = br.read()) if (_char == '-') _sign = -1;
        for (_num = 0; '0' <= _char && _char <= '9'; _char = br.read()) _num = _num * 10 + _char - '0';
        return _num * _sign;
    }

    public static void main(String[] args) throws IOException {
        PrintWriter out = new PrintWriter(System.out);
        int n = read(), m = read();
        int[] a = new int[n + 1];
        for (int i = 1; i <= n; ++i) a[i] = read();
        SegmentTree seg = new SegmentTree(a, n);
        while (m-- != 0) {
            int command = read();
            if (command == 1) {
                int x = read(), y = read(), k = read();
                seg.updateRange(x, y, k, 1, 1, n);
            } else {
                int x = read(), y = read();
                out.println(seg.queryRange(x, y, 1, 1, n));
            }
        }
        out.close();
    }
}

class SegmentTree {
    long[] tree, tag;
    int n;

    // 标记下传
    private void pushDown(int o, int x, int y, int l, int r) {
        // 空标记直接退出
        if (tag[o] == 0) return;
        int mid = l + r >> 1;
        // 下传给左节点
        tag[x] += tag[o];
        tree[x] += tag[o] * (mid - l + 1);
        // 下传给右节点
        tag[y] += tag[o];
        tree[y] += tag[o] * (r - (mid + 1) + 1);
        // 清空当前节点标记
        tag[o] = 0;
    }

    private void pushUp(int o, int x, int y) {
        tree[o] = tree[x] + tree[y];
    }

    public void build(int o, int l, int r, int[] val) {
        if (l == r) {
            tree[o] = val[l];
            return;
        }
        int mid = l + r >> 1, x = o << 1, y = x | 1;
        build(x, l, mid, val);
        build(y, mid + 1, r, val);
        pushUp(o, x, y);
    }

    public SegmentTree(int _n) {
        n = _n;
        tree = new long[n << 2];
        tag = new long[n << 2];
    }

    // 请保证数组数据下标从1开始
    public SegmentTree(int[] val, int _n) {
        // assert(val.length >= _n);
        this(_n);
        build(1, 1, n, val);
    }

    public void updateRange(final int left, final int right, final int val, int o, int l, int r) {
        if (left <= l && r <= right) {
            tree[o] += val * (r - l + 1);
            tag[o] += val;
            return;
        }
        // 将后续修改的影响计算到当前节点中
        final int mid = l + r >> 1, x = o << 1, y = x | 1;
        // 下放标记
        pushDown(o, x, y, l, r);
        if (left <= mid) updateRange(left, right, val, x, l, mid);
        if (right > mid) updateRange(left, right, val, y, mid + 1, r);
        // 上传子区间数据
        pushUp(o, x, y);
    }

    public long queryRange(final int left, final int right, int o, int l, int r) {
        if (left <= l && r <= right) return tree[o];
        final int mid = l + r >> 1, x = o << 1, y = x | 1;
        // 下放标记
        pushDown(o, x, y, l, r);
        long ans = 0;
        if (left <= mid) ans += queryRange(left, right, x, l, mid);
        if (right > mid) ans += queryRange(left, right, y, mid + 1, r);
        // 上传子区间数据
        pushUp(o, x, y);
        return ans;
    }
}

空间优化

下面给出P3372 【模板】线段树 1 - 洛谷空间优化的线段树类代码

\(2n\) 空间

对于上述的线段树,要用到 \(4n\) 的空间,但只有 \(2n-1\) 个空间有作用,能不能只建立 \(2n\) 个空间?

深度优先搜索 \(DFS\) 是树的一种遍历方式,而 \(DFS\) 序是深度优先搜索中的节点访问次序,记为 \(DFN\),选择按照 \(DFN\) 的方式存储线段树节点

若某一个节点的编号为 \(p\),则其左儿子节点编号为 \(p+1\),则其右儿子节点编号为 \(p+左子树节点个数+1\)(因为是先遍历左子树嘛)

那左子树节点个数该怎么求呢?

在线段树的存储中提到

一颗管理数组长度为 \(n\) 的线段树的节点个数为 \(2\times n -1\)

若当前节点管理区间为 \([l,r]\),设 \(mid=\lfloor\dfrac{l+r}{2}\rfloor\),则左子树管理区间为 \([l,mid]\),左子树管理区间长度为 \(mid-l+1\),所以左子树节点个数为 \(2\times(\lfloor\dfrac{l+r}{2}\rfloor-l+1)-1=2\times\lfloor\dfrac{r-l+2}{2}\rfloor-1\)

因此,右儿子节点编号为 \(p+2\times\lfloor\dfrac{r-l+2}{2}\rfloor\)

\(2\) 向下取整代表按位右移,乘 \(2\) 代表按位左移。

因此,\(2\times\lfloor\dfrac{r-l+2}{2}\rfloor\) 可用 \(r-l+2\) 并将二进制最低位置为 \(0\) 表示

(r - l + 2) & ~1

点击查看代码
class SegmentTree {
    long[] tree, tag;
    int n;

    private void pushUp(int o, int x, int y) {
        tree[o] = tree[x] + tree[y];
    }

    public void build(int o, int l, int r, int[] val) {
        if (l == r) {
            tree[o] = val[l];
            return;
        }
        final int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        build(x, l, mid, val);
        build(y, mid + 1, r, val);
        pushUp(o, x, y);
    }

    public SegmentTree(int _n) {
        n = _n;
        tree = new long[n << 1];
        tag = new long[n << 1];
    }

    // 请保证数组数据下标从1开始
    public SegmentTree(int[] val, int _n) {
        // assert(val.length >= _n);
        this(_n);
        build(1, 1, n, val);
    }

    private void pushDown(int o, int x, int y, int l, int r) {
        if (tag[o] == 0) return;
        int mid = l + r >> 1;
        tag[x] += tag[o];
        tree[x] += tag[o] * (mid - l + 1);
        tag[y] += tag[o];
        tree[y] += tag[o] * (r - mid);
        tag[o] = 0;
    }

    public void updateRange(final int left, final int right, final int val, int o, int l, int r) {
        if (left <= l && r <= right) {
            tree[o] += val * (r - l + 1);
            tag[o] += val;
            return;
        }
        final int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        pushDown(o, x, y, l, r);
        if (left <= mid) updateRange(left, right, val, x, l, mid);
        if (right > mid) updateRange(left, right, val, y, mid + 1, r);
        pushUp(o, x, y);
    }

    public long queryRange(final int left, final int right, int o, int l, int r) {
        if (left <= l && r <= right) return tree[o];
        final int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        pushDown(o, x, y, l, r);
        long ans = 0;
        if (left <= mid) ans += queryRange(left, right, x, l, mid);
        if (right > mid) ans += queryRange(left, right, y, mid + 1, r);
        pushUp(o, x, y);
        return ans;
    }
}

动态开点

所有节点只在使用时才\(new\)申请内存,\(Cpp\) 通过指针、\(Java\) 通过引用的方式动态开点

如果动态开点进行 \(build\) 建树操作,就会将所有节点创建出来,就和上面 \(2n\) 空间一样了

因此,动态开点一般是不建树的

数的加法是不影响初值的,我们将初值取出,只对全为 \(0\) 的线段树进行区间加法和查询

在总的查询的时候加上原数组的值即可

最坏的情况就是所有操作都走不同的路径、且走到叶子节点,则一次操作会增加 \(\log n\) 个节点

\(n\) 为查询区间长度,\(m\) 为询问操作次数,则空间复杂度为:\(O(min(2n-1,\ m\log n))\)

点击查看代码
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;

public class Main {
    static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    static int _num, _sign, _char;

    static int read() throws IOException {
        for (_sign = 1, _char = br.read(); _char > '9' || _char < '0'; _char = br.read()) if (_char == '-') _sign = -1;
        for (_num = 0; '0' <= _char && _char <= '9'; _char = br.read()) _num = _num * 10 + _char - '0';
        return _num * _sign;
    }

    public static void main(String[] args) throws IOException {
        PrintWriter out = new PrintWriter(System.out);
        int n = read(), m = read();
        long[] a = new long[n + 1];
        for (int i = 1; i <= n; ++i) a[i] = a[i - 1] + read();
        SegmentTree seg = new SegmentTree(n);
        while (m-- != 0) {
            int command = read();
            if (command == 1) {
                int x = read(), y = read(), k = read();
                seg.modify(x, y, k);
            } else {
                int x = read(), y = read();
                long ans = a[y] - a[x - 1];
                ans += seg.query(x, y).sum;
                out.println(ans);
            }
        }
        out.close();
    }
}

class SegmentTree {
    class node {
        // 设置节点默认空白初始值, 用于答案查询及创建节点
        long sum = 0, add = 0;
        node lChild, rChild;

        // val 加到data和tag上, 用于区间修改终止和标记下传
        // 按照需要选择是否需要左右边界
        void apply(int l, int r, final long val) {
            sum += (r - l + 1) * val;
            add += val;
        }

        // 创建儿子节点
        public void addNode() {
            if (lChild == null) lChild = new node();
            if (rChild == null) rChild = new node();
        }
    }

    // 标记下传, 将cur节点的标记下传至两个子树中
    // 按照需要选择是否需要左右边界
    void pushDown(node cur, int l, int r) {
        if (cur.add != 0) {
            int mid = l + r >> 1;
            cur.lChild.apply(l, mid, cur.add);
            cur.rChild.apply(mid + 1, r, cur.add);
            cur.add = 0;
        }
    }

    // son的data数据加到cur上, 用于pushUp上传数据 和 查询时合并答案
    // 不用理会标记数值(前提是node有默认初始值,且代表空标记)
    void unite(node cur, final node son) {
        cur.sum += son.sum;
    }

    // 子区间数据上传
    void pushUp(node cur) {
        cur.sum = 0;
        unite(cur, cur.lChild);
        unite(cur, cur.rChild);
    }

    private void modify(node cur, int l, int r, final int left, final int right, final long val) {
        if (left <= l && r <= right) {
            cur.apply(l, r, val);
            return;
        }
        cur.addNode();
        pushDown(cur, l, r);
        int mid = l + r >> 1;
        if (left <= mid) modify(cur.lChild, l, mid, left, right, val);
        if (right > mid) modify(cur.rChild, mid + 1, r, left, right, val);
        pushUp(cur);
    }

    private void query(node cur, int l, int r, final int left, final int right, node res) {
        if (left <= l && r <= right) {
            unite(res, cur);
            return;
        }
        cur.addNode();
        pushDown(cur, l, r);
        int mid = l + r >> 1;
        if (left <= mid) query(cur.lChild, l, mid, left, right, res);
        if (right > mid) query(cur.rChild, mid + 1, r, left, right, res);
        pushUp(cur);
    }

    int n;
    node root;

    // 请保证node处进行了默认初始化
    public SegmentTree(int _n) {
        // assert(n > 0);
        n = _n;
        root = new node();
    }

    // 区间修改
    public void modify(int left, int right, final int val) {
        // assert (1 <= left && left <= right && right <= n);
        modify(root, 1, n, left, right, val);
    }
    
    // 区间查询
    public node query(int left, int right) {
        // assert (1 <= left && left <= right && right <= n);
        node res = new node();
        query(root, 1, n, left, right, res);
        return res;
    }
}

要点总结

如果觉得我讲的不是很明白,可以看参考资料中提到的文章或者 \(Bilibili\)

  • 树状数组不同,线段树只需要两个区间信息可以合并,即维护的信息只需要满足结合律(如加法、乘法、异或、最大公约数等)
    结合律:\((x\circ y)\circ z=x\circ(y\circ z)\),其中 \(\circ\) 是一个二元运算符。

  • 带懒惰标记的线段树修改和查询操作时间复杂度均为 \(O(\log n)\),建树时间复杂度为 \(O(n)\)

  • 对标记的两种操作各有各的优点

    标记下传的实用性更广

    标记永久化思想还可用于可持久化数据结构

  • 线段树对于非强制在线的问题可以通过离散化缩小数据范围来减少空间

    而对于强制在线的问题就只能通过动态开点来减少空间了(应该

    当然,离散化也可以和动态开点搭配

    强制在线:不提前给出所有涉及询问和修改的区间范围,不能进行离散化

  • 线段树是一种工具,许多问题可以借助这个工具解决,就如同滑动窗口可以借助双端队列解决一样

线段树封装类

数组有效数据下标均从 \(1\) 开始

基础的不带懒标记的线段树暂时不提供了,因为对于这类问题,树状数组大多可以解决

如果您有更好的封装类能提供给我,我将感激不尽

懒标记线段树

C++

点击查看代码
template <typename T>
class SegmentTree {
public:
    struct node {
        // 设置叶子节点默认初始值, 用于不传数组的建树以及空标记
        T data = ...;
        T tag = ...;
        // val 加到data和tag上, 用于区间修改终止和标记下传
        // 按照需要选择是否需要左右边界
        void apply(..., const T &val) {
            ...
            // sum += (r - l + 1) * val;
            // add += val;
        }
        // 建树时传入数组的初始化
        void init(const T &val) {
            ...
            // sum = val;
        }
    };
    // 标记下传, 将o节点的标记下传至两个子树x,y中
    // 按照需要选择是否需要左右边界
    void pushDown(int o, int x, int y) {
        ...
        // if (tree[o].add != 0) {
        //     int mid = l + r >> 1;
        //     // 下传标记至左子树
        //     tree[x].apply(l, mid, tree[o].add);
        //     // 下传标记至右子树
        //     tree[y].apply(mid + 1, r, tree[o].add);
        //     // 清空当前节点标记
        //     tree[o].add = 0;
        // }
    }
    // son的data数据加到o上, 用于pushUp上传数据 和 查询时合并答案
    // 不用理会标记数值(前提是node有默认初始值,且代表空标记)
    void unite(node& o, const node& son) {
        ...
        // o.sum += son.sum;
    }
    // 子区间数据上传
    void pushUp(int o, int x, int y) {
        // 清空当前节点data
        ...
        // tree[o].sum = 0;
        unite(tree[o], tree[x]);
        unite(tree[o], tree[y]);
    }
    void build(int o, int l, int r) {
        if (l == r) return;
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        build(x, l, mid);
        build(y, mid + 1, r);
        pushUp(o, x, y);
    }
    void build(int o, int l, int r, const std::vector<T> &val) {
        if (l == r) {
            tree[o].init(val[l]);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        build(x, l, mid, val);
        build(y, mid + 1, r, val);
        pushUp(o, x, y);
    }
    void modify(int o, int l, int r, const int &left, const int &right, const T &val) {
        if (left <= l && r <= right) {
            tree[o].apply(..., val);
            // tree[o].apply(l, r, val);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        pushDown(o, x, y);
        // pushDown(o, x, y, l, r);
        if (left <= mid) modify(x, l, mid, left, right, val);
        if (right > mid) modify(y, mid + 1, r, left, right, val);
        pushUp(o, x, y);
    }
    void query(int o, int l, int r, const int& left, const int& right, node& res) {
        if (left <= l && r <= right) {
            unite(res, tree[o]);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        pushDown(o, x, y, l, r);
        if (left <= mid) query(x, l, mid, left, right, res);
        if (right > mid) query(y, mid + 1, r, left, right, res);
        pushUp(o, x, y);
    }
    int n;
    std::vector<node> tree;
    // 不传入数组的默认建树, 请保证node处进行了默认初始化
    SegmentTree(int _n) : n(_n) {
        assert(n > 0);
        tree.resize(n << 1);
        build(1, 1, n);
    }
    // 传入数组的建树, 请保证数组有效数据下标从1开始
    SegmentTree(std::vector<T>& val, int _n) : n(_n) {
        assert((int)val.size() >= _n);
        tree.resize(n << 1);
        build(1, 1, n, val);
    }
    // 单点修改
    void modify(const int& index, const T& val) {
        assert(1 <= index && index <= n);
        modify(1, 1, n, index, index, val);
    }
    // 区间修改
    void modify(const int& left, const int& right, const T& val) {
        assert(1 <= left && left <= right && right <= n);
        modify(1, 1, n, left, right, val);
    }
    // 单点查询
    node query(const int& index) {
        assert(1 <= index && index <= n);
        node res{};
        query(1, 1, n, index, index, res);
        return res;
    }
    // 区间查询
    node query(const int& left, const int& right) {
        assert(1 <= left && left <= right && right <= n);
        node res{};
        query(1, 1, n, left, right, res);
        return res;
    }
};

Java

由于本人对于 \(Java\) 泛型还不熟悉,以后学明白了再修改(挖坑,希望会填

点击查看代码
class SegmentTree {
    class node {
        int data, tag;

        public node(int _data, int _tag) {
            data = _data;
            tag = _tag;
        }

        // 建树时传入数组的初始化(标记置空)
        public node(int _data) {this(_data, 0);}

        // 设置叶子节点默认初始值, 用于不传数组的建树、置空标记以及查询答案的初始化
        public node() {this(0, 0);}

        // val 加到data和tag上, 用于区间修改终止和标记下传
        // 按照需要选择是否需要左右边界
        void apply(...,final int val) {
            ...
            // sum += (r - l + 1) * val;
            // add += val;
        }
    }

    // 标记下传, 将o节点的标记下传至两个子树x,y中
    // 按照需要选择是否需要左右边界
    private void pushDown(int o, int x, int y) {
        ...
        // if (tree[o].add != 0) {
        //     int mid = l + r >> 1;
        //     // 下传标记至左子树
        //     tree[x].apply(l, mid, tree[o].add);
        //     // 下传标记至右子树
        //     tree[y].apply(mid + 1, r, tree[o].add);
        //     // 清空当前节点标记
        //     tree[o].add = 0;
        // }
    }

    // son的data数据加到o上, 用于pushUp上传数据 和 查询时合并答案
    // 不用理会标记数值(前提是node有默认初始值,且代表空标记)
    private void unite(node o, final node son) {
        ...
        // o.sum += son.sum;
    }

    // 子区间数据上传
    private void pushUp(int o, int x, int y) {
        // 清空当前节点data
        ...
        // tree[o].sum = 0;
        unite(tree[o], tree[x]);
        unite(tree[o], tree[y]);
    }

    public void build(int o, int l, int r) {
        if (l == r) {
            tree[o] = new node();
            return;
        }
        tree[o] = new node();
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        build(x, l, mid);
        build(y, mid + 1, r);
        pushUp(o, x, y);
    }

    public void build(int o, int l, int r, final int[] val) {
        if (l == r) {
            tree[o] = new node(val[l]);
            return;
        }
        tree[o] = new node();
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        build(x, l, mid, val);
        build(y, mid + 1, r, val);
        pushUp(o, x, y);
    }

    private void modify(int o, int l, int r, final int left, final int right, final int val) {
        if (left <= l && r <= right) {
            tree[o].apply(...,val);
            // tree[o].apply(l, r, val);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        pushDown(o, x, y);
        // pushDown(o, x, y, l, r);
        if (left <= mid) modify(x, l, mid, left, right, val);
        if (right > mid) modify(y, mid + 1, r, left, right, val);
        pushUp(o, x, y);
    }

    private void query(int o, int l, int r, final int left, final int right, node res) {
        if (left <= l && r <= right) {
            unite(res, tree[o]);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        pushDown(o, x, y);
        // pushDown(o, x, y, l, r);
        if (left <= mid) query(x, l, mid, left, right, res);
        if (right > mid) query(y, mid + 1, r, left, right, res);
        pushUp(o, x, y);
    }

    private int n;
    private node[] tree;
    
    // 不传入数组的默认建树, 请保证node处进行了默认初始化
    public SegmentTree(int _n) {
        // assert(n > 0);
        n = _n;
        tree = new node[n << 1];
        build(1, 1, n);
    }

    // 传入数组的建树, 请保证数组有效数据下标从1开始
    public SegmentTree(final int[] val, int _n) {
        // assert ((int) val.length >= _n);
        n = _n;
        tree = new node[n << 1];
        build(1, 1, n, val);
    }

    // 单点修改
    public void modify(int index, final int val) {
        // assert (1 <= index && index <= n);
        modify(1, 1, n, index, index, val);
    }

    // 区间修改
    public void modify(int left, int right, final int val) {
        // assert (1 <= left && left <= right && right <= n);
        modify(1, 1, n, left, right, val);
    }

    // 单点查询
    public node query(int index) {
        // assert (1 <= index && index <= n);
        node res = new node();
        query(1, 1, n, index, index, res);
        return res;
    }

    // 区间查询
    public node query(int left, int right) {
        // assert (1 <= left && left <= right && right <= n);
        node res = new node();
        query(1, 1, n, left, right, res);
        return res;
    }
}

动态开点的懒标记线段树

Java

点击查看代码
class SegmentTree {
    class node {
        // 设置节点默认空白初始值, 用于答案查询及创建节点
        int data = 0, tag = 0;
        node lChild, rChild;

        // val 加到data和tag上, 用于区间修改终止和标记下传
        // 按照需要选择是否需要左右边界
        private void apply(...,final int val) {
            ...
            // sum += (r - l + 1) * val;
            // add += val;
        }

        // 创建子节点
        private void addNode() {
            if (lChild == null) lChild = new node();
            if (rChild == null) rChild = new node();
        }
    }

    // 标记下传, 将cur节点的标记下传至两个子树中
    // 按照需要选择是否需要左右边界
    private void pushDown(node cur, ...) {
        ...
        // if (cur.add != 0) {
        //     int mid = l + r >> 1;
        //     // 下传标记至左子树
        //     cur.lChild.apply(l, mid, cur.add);
        //     // 下传标记至右子树
        //     cur.rChild.apply(mid + 1, r, cur.add);
        //     // 清空当前节点标记
        //     cur.add = 0;
        // }
    }

    // son的data数据加到cur上, 用于pushUp上传数据 和 查询时合并答案
    // 不用理会标记数值(前提是node有默认初始值,且代表空标记)
    private void unite(node cur, final node son) {
        ...
        // cur.sum += son.sum;
    }

    // 子区间数据上传
    private void pushUp(node cur) {
        // 清空当前节点data
        ...
        // cur.sum = 0;
        unite(cur, cur.lChild);
        unite(cur, cur.rChild);
    }

    private void modify(node cur, int l, int r, final int left, final int right, final int val) {
        if (left <= l && r <= right) {
            cur.apply(...,val);
            // cur.apply(l, r, val);
            return;
        }
        cur.addNode();
        pushDown(cur);
        // pushDown(cur, l, r);
        int mid = l + r >> 1;
        if (left <= mid) modify(cur.lChild, l, mid, left, right, val);
        if (right > mid) modify(cur.rChild, mid + 1, r, left, right, val);
        pushUp(cur);
    }

    private void query(node cur, int l, int r, final int left, final int right, node res) {
        if (left <= l && r <= right) {
            unite(res, cur);
            return;
        }
        cur.addNode();
        pushDown(cur);
        // pushDown(cur, l, r);
        int mid = l + r >> 1;
        if (left <= mid) query(cur.lChild, l, mid, left, right, res);
        if (right > mid) query(cur.rChild, mid + 1, r, left, right, res);
        pushUp(cur);
    }

    private int n;
    private node root;

    // 请保证node处进行了默认初始化
    public SegmentTree(int _n) {
        // assert(n > 0);
        n = _n;
        root = new node();
    }

    // 单点修改
    public void modify(int index, final int val) {
        // assert (1 <= index && index <= n);
        modify(root, 1, n, index, index, val);
    }

    // 区间修改
    public void modify(int left, int right, final int val) {
        // assert (1 <= left && left <= right && right <= n);
        modify(root, 1, n, left, right, val);
    }

    // 单点查询
    public node query(int index) {
        // assert (1 <= index && index <= n);
        node res = new node();
        query(root, 1, n, index, index, res);
        return res;
    }

    // 区间查询
    public node query(int left, int right) {
        // assert (1 <= left && left <= right && right <= n);
        node res = new node();
        query(root, 1, n, left, right, res);
        return res;
    }
}

C++

箭头操作符\(=\)解引用\(+\)点操作符

p->data(*p).data相同

点击查看代码
template <typename T>
class SegmentTree {
private:
    struct node {
        // 设置节点默认空白初始值, 用于答案查询及创建节点
        T data = ..., tag = ...;
        node* lChild = nullptr;
        node* rChild = nullptr;
        // val 加到data和tag上, 用于区间修改终止和标记下传
        // 按照需要选择是否需要左右边界
        void apply(..., const T& val) {
            ...
            // sum += (r - l + 1) * val;
            // add += val;
        }
        // 创建儿子节点
        void addNode() {
            if (!lChild) lChild = new node();
            if (!rChild) rChild = new node();
        }
    };
    // 标记下传, 将cur节点的标记下传至两个子树中
    // 按照需要选择是否需要左右边界
    void pushDown(node* cur, ...) {
        ...
        // if (cur->add != 0) {
        //     int mid = l + r >> 1;
        //     cur->lChild->apply(l, mid, cur->add);
        //     cur->rChild->apply(mid + 1, r, cur->add);
        //     cur->add = 0;
        // }
    }

    // son的data数据加到cur上, 用于pushUp上传数据 和 查询时合并答案
    // 不用理会标记数值(前提是node有默认初始值,且代表空标记)
    void unite(node* cur, const node* son) {
        ...
        // cur->sum += son->sum;
    }

    // 子区间数据上传
    void pushUp(node* cur) {
        // 清空当前节点data
        ...
        // cur->sum = 0;
        unite(cur, cur->lChild);
        unite(cur, cur->rChild);
    }
    void modify(node* cur, int l, int r, const int& left, const int& right, const T& val) {
        if (left <= l && r <= right) {
            cur->apply(..., val);
            // cur->apply(l, r, val);
            return;
        }
        cur->addNode();
        pushDown(cur);
        // pushDown(cur, l, r);
        int mid = l + r >> 1;
        if (left <= mid) modify(cur->lChild, l, mid, left, right, val);
        if (right > mid) modify(cur->rChild, mid + 1, r, left, right, val);
        pushUp(cur);
    }

    void query(node* cur, int l, int r, const int& left, const int& right, node* res) {
        if (left <= l && r <= right) {
            unite(res, cur);
            return;
        }
        cur->addNode();
        pushDown(cur);
        // pushDown(cur, l, r);
        int mid = l + r >> 1;
        if (left <= mid) query(cur->lChild, l, mid, left, right, res);
        if (right > mid) query(cur->rChild, mid + 1, r, left, right, res);
        pushUp(cur);
    }

    int n;
    node* root;

public:
    // 请保证node处进行了默认初始化
    SegmentTree(const int& _n) : n(_n), root(new node()) {
        assert(n > 0);
    }
    // 单点修改
    void modify(const int& index, const T& val) {
        assert(1 <= index && index <= n);
        modify(root, 1, n, index, index, val);
    }
    // 区间修改
    void modify(const int& left, const int& right, const T& val) {
        assert(1 <= left && left <= right && right <= n);
        modify(root, 1, n, left, right, val);
    }
    // 单点查询
    node* query(const int& index) {
        assert(1 <= index && index <= n);
        node* res = new node();
        query(root, 1, n, index, index, res);
        return res;
    }
    // 区间查询
    node* query(const int& left, const int& right) {
        assert(1 <= left && left <= right && right <= n);
        node* res = new node();
        query(root, 1, n, left, right, res);
        return res;
    }
};

题目

P3373 线段树2 - 洛谷

题目链接

题意简述:有三个操作

  1. 对区间 \([l,r]\) 每个数乘上 \(k\)
  2. 对区间 \([l,r]\) 每个数加上 \(k\)
  3. 查询区间 \([l,r]\) 每个数的和

两种修改操作对应两种懒惰标记, 优先下传乘法标记,将乘法标记对加法标记的影响算在加法标记中

Java Code

点击查看代码
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;

public class Main {
    static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    static int _num, _sign, _char;

    static int read() throws IOException {
        for (_sign = 1, _char = br.read(); _char > '9' || _char < '0'; _char = br.read()) if (_char == '-') _sign = -1;
        for (_num = 0; '0' <= _char && _char <= '9'; _char = br.read()) _num = _num * 10 + _char - '0';
        return _num * _sign;
    }

    public static void main(String[] args) throws IOException {
        PrintWriter out = new PrintWriter(System.out);
        int n = read(), m = read(), mod = read();
        long[] a = new long[n + 1];
        for (int i = 1; i <= n; ++i) a[i] = read();
        SegmentTree seg = new SegmentTree(a, n);
        while (m-- != 0) {
            int command = read(), x = read(), y = read();
            if (command == 1) seg.modifyMul(x, y, read());
            else if (command == 2) seg.modifyAdd(x, y, read());
            else out.println(seg.query(x, y).sum);
        }
        out.close();
    }
}

class SegmentTree {
    static final int mod = 571373;

    class node {
        long sum, add, mul;

        public node(long _sum, long _add, long _mul) {
            sum = _sum % mod;
            add = _add % mod;
            mul = _mul % mod;
        }

        public node(long _sum) {this(_sum, 0, 1);}

        public node() {this(0, 0, 1);}

        // val 加到data和tag上, 用于区间修改终止和标记下传
        // 按照需要选择是否需要左右边界
        void applyAdd(int l, int r, final long val) {
            sum = (sum + (r - l + 1) * val) % mod;
            add = (add + val) % mod;
        }

        void applyMul(final long val) {
            sum = sum * val % mod;
            mul = mul * val % mod;
            add = add * val % mod;
        }
    }

    void pushDown(int o, int x, int y, int l, int r) {
        if (tree[o].mul != 1) {
            tree[x].applyMul(tree[o].mul);
            tree[y].applyMul(tree[o].mul);
            tree[o].mul = 1;
        }
        if (tree[o].add != 0) {
            int mid = l + r >> 1;
            tree[x].applyAdd(l, mid, tree[o].add);
            tree[y].applyAdd(mid + 1, r, tree[o].add);
            tree[o].add = 0;
        }
    }

    void unite(node o, final node son) {
        o.sum = (o.sum + son.sum) % mod;
    }

    void pushUp(int o, int x, int y) {
        tree[o].sum = 0;
        unite(tree[o], tree[x]);
        unite(tree[o], tree[y]);
    }

    public void build(int o, int l, int r, final long[] val) {
        if (l == r) {
            tree[o] = new node(val[l]);
            return;
        }
        tree[o] = new node();
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        build(x, l, mid, val);
        build(y, mid + 1, r, val);
        pushUp(o, x, y);
    }

    private void modifyAdd(int o, int l, int r, final int left, final int right, final long val) {
        if (left <= l && r <= right) {
            tree[o].applyAdd(l, r, val);
            // tree[o].apply(l, r, val);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        pushDown(o, x, y, l, r);
        if (left <= mid) modifyAdd(x, l, mid, left, right, val);
        if (right > mid) modifyAdd(y, mid + 1, r, left, right, val);
        pushUp(o, x, y);
    }

    private void modifyMul(int o, int l, int r, final int left, final int right, final long val) {
        if (left <= l && r <= right) {
            tree[o].applyMul(val);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        pushDown(o, x, y, l, r);
        if (left <= mid) modifyMul(x, l, mid, left, right, val);
        if (right > mid) modifyMul(y, mid + 1, r, left, right, val);
        pushUp(o, x, y);
    }

    private void query(int o, int l, int r, final int left, final int right, node res) {
        if (left <= l && r <= right) {
            unite(res, tree[o]);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        pushDown(o, x, y, l, r);
        if (left <= mid) query(x, l, mid, left, right, res);
        if (right > mid) query(y, mid + 1, r, left, right, res);
        pushUp(o, x, y);
    }

    int n;
    node[] tree;

    public SegmentTree(final long[] val, int _n) {
        // assert ((int) val.length >= _n);
        n = _n;
        tree = new node[n << 1];
        build(1, 1, n, val);
    }

    // 区间修改1
    void modifyMul(final int left, final int right, final long val) {
        assert (1 <= left && left <= right && right <= n);
        modifyMul(1, 1, n, left, right, val);
    }

    // 区间修改2
    void modifyAdd(int left, int right, final long val) {
        // assert (1 <= left && left <= right && right <= n);
        modifyAdd(1, 1, n, left, right, val);
    }

    // 区间查询
    public node query(int left, int right) {
        // assert (1 <= left && left <= right && right <= n);
        node res = new node();
        query(1, 1, n, left, right, res);
        return res;
    }
}

C++20 Code

用到了自动取模类,详见逆元一文

点击查看代码
#include <bits/stdc++.h>
using namespace std;

template <typename T>
class SegmentTree {
public:
    struct node {
        // 设置叶子节点默认初始值, 用于不传数组的建树以及空标记
        T sum = 0;
        T add = 0;
        T mul = 1;
        // val 加到data和tag上, 用于区间修改终止和标记下传
        // 按照需要选择是否需要左右边界
        void applyAdd(int l, int r, const T& val) {
            sum += T(r - l + 1) * val;
            add += val;
        }
        void applyMultiply(const T& val) {
            sum *= val;
            mul *= val;
            add *= val;
        }
        // 建树时传入数组的初始化
        void init(const T& val) {
            sum = val;
        }
    };
    // 标记下传, 将o节点的标记下传至两个子树x,y中
    // 按照需要选择是否需要左右边界
    void pushDown(int o, int x, int y, int l, int r) {
        if (tree[o].mul != 1) {
            tree[x].applyMultiply(tree[o].mul);
            tree[y].applyMultiply(tree[o].mul);
            tree[o].mul = 1;
        }
        if (tree[o].add != 0) {
            int mid = l + r >> 1;
            tree[x].applyAdd(l, mid, tree[o].add);
            tree[y].applyAdd(mid + 1, r, tree[o].add);
            tree[o].add = 0;
        }
    }
    // son的data数据加到o上, 用于pushUp上传数据 和 查询时合并答案
    // 不用理会标记数值(前提是node有默认初始值,且代表空标记)
    void unite(node& o, const node& son) {
        o.sum += son.sum;
    }
    // 子区间数据上传
    void pushUp(int o, int x, int y) {
        tree[o].sum = 0;
        unite(tree[o], tree[x]);
        unite(tree[o], tree[y]);
    }
    void build(int o, int l, int r, const std::vector<T>& val) {
        if (l == r) {
            tree[o].init(val[l]);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        build(x, l, mid, val);
        build(y, mid + 1, r, val);
        pushUp(o, x, y);
    }
    void modifyAdd(int o, int l, int r, const int& left, const int& right, const T& val) {
        if (left <= l && r <= right) {
            tree[o].applyAdd(l, r, val);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        pushDown(o, x, y, l, r);
        if (left <= mid) modifyAdd(x, l, mid, left, right, val);
        if (right > mid) modifyAdd(y, mid + 1, r, left, right, val);
        pushUp(o, x, y);
    }
    void modifyMultiply(int o, int l, int r, const int& left, const int& right, const T& val) {
        if (left <= l && r <= right) {
            tree[o].applyMultiply(val);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        pushDown(o, x, y, l, r);
        if (left <= mid) modifyMultiply(x, l, mid, left, right, val);
        if (right > mid) modifyMultiply(y, mid + 1, r, left, right, val);
        pushUp(o, x, y);
    }
    void query(int o, int l, int r, const int& left, const int& right, node& res) {
        if (left <= l && r <= right) {
            unite(res, tree[o]);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        pushDown(o, x, y, l, r);
        if (left <= mid) query(x, l, mid, left, right, res);
        if (right > mid) query(y, mid + 1, r, left, right, res);
        pushUp(o, x, y);
    }
    int n;
    std::vector<node> tree;
    // 传入数组的建树, 请保证数组有效数据下标从1开始
    SegmentTree(std::vector<T>& val, int _n) : n(_n) {
        assert((int)val.size() >= _n);
        tree.resize(n << 1);
        build(1, 1, n, val);
    }
    // 区间修改1
    void modifyMultiply(const int& left, const int& right, const T& val) {
        assert(1 <= left && left <= right && right <= n);
        modifyMultiply(1, 1, n, left, right, val);
    }
    // 区间修改2
    void modifyAdd(const int& left, const int& right, const T& val) {
        assert(1 <= left && left <= right && right <= n);
        modifyAdd(1, 1, n, left, right, val);
    }
    // 区间查询
    node query(const int& left, const int& right) {
        assert(1 <= left && left <= right && right <= n);
        node res{};
        query(1, 1, n, left, right, res);
        return res;
    }
};

template <int MOD>
struct modint {
    int val;
    static int norm(const int& x) { return x < 0 ? x + MOD : x; }
    static constexpr int get_mod() { return MOD; }
    modint inv() const {
        assert(val);
        int a = val, b = MOD, u = 1, v = 0, t;
        while (b > 0) t = a / b, swap(a -= t * b, b), swap(u -= t * v, v);
        assert(b == 1);
        return modint(u);
    }
    modint() : val(0) {}
    modint(const int& m) : val(norm(m)) {}
    modint(const long long& m) : val(norm(m % MOD)) {}
    modint operator-() const { return modint(norm(-val)); }
    bool operator==(const modint& o) { return val == o.val; }
    bool operator<(const modint& o) { return val < o.val; }
    modint& operator+=(const modint& o) { return val = (1ll * val + o.val) % MOD, *this; }
    modint& operator-=(const modint& o) { return val = norm(1ll * val - o.val), *this; }
    modint& operator*=(const modint& o) { return val = static_cast<int>(1ll * val * o.val % MOD), *this; }
    modint& operator/=(const modint& o) { return *this *= o.inv(); }
    modint& operator^=(const modint& o) { return val ^= o.val, *this; }
    modint& operator>>=(const modint& o) { return val >>= o.val, *this; }
    modint& operator<<=(const modint& o) { return val <<= o.val, *this; }
    modint operator-(const modint& o) const { return modint(*this) -= o; }
    modint operator+(const modint& o) const { return modint(*this) += o; }
    modint operator*(const modint& o) const { return modint(*this) *= o; }
    modint operator/(const modint& o) const { return modint(*this) /= o; }
    modint operator^(const modint& o) const { return modint(*this) ^= o; }
    modint operator>>(const modint& o) const { return modint(*this) >>= o; }
    modint operator<<(const modint& o) const { return modint(*this) <<= o; }
    friend std::istream& operator>>(std::istream& is, modint& a) {
        long long v;
        return is >> v, a.val = norm(v % MOD), is;
    }
    friend std::ostream& operator<<(std::ostream& os, const modint& a) { return os << a.val; }
    friend std::string tostring(const modint& a) { return std::to_string(a.val); }
    friend modint qpow(const modint& a, const int& b) {
        assert(b >= 0);
        modint x = a, res = 1;
        for (int p = b; p; x *= x, p >>= 1)
            if (p & 1) res *= x;
        return res;
    }
};

constexpr int mod = 571373;
using Mint = modint<mod>;

signed main() {
    std::ios_base::sync_with_stdio(false), std::cin.tie(nullptr), std::cout.tie(nullptr);
    int n, m, command, x, y;
    cin >> n >> m >> x;
    vector<Mint> a(n + 1);
    for (int i = 1; i <= n; ++i) cin >> a[i];
    SegmentTree<Mint> seg(a, n);
    Mint k;
    while (m--) {
        cin >> command >> x >> y;
        if (command == 1) {
            cin >> k;
            seg.modifyMultiply(x, y, k);
        } else if (command == 2) {
            cin >> k;
            seg.modifyAdd(x, y, k);
        } else {
            cout << seg.query(x, y).sum << endl;
        }
    }
    return 0;
}

315. 计算右侧小于当前元素的个数 - 力扣

题目链接

题意简述:求逆序对

求逆序对可以通过归并排序,也可以通过树状数组\(+\)离散化

树状数组可以做这道题,那线段树也一定可以

对于线段树可以选择动态开点,也可以选择离散化

这题数据范围很小,正常做好像也能过

Java 离散化

点击查看代码
class Solution {
    int n, max;
    // 去重
    int adjacentRemove(int[] nums) {
        int slow = 0;
        for (int fast = 1; fast < n; ++fast) {
            if (nums[slow] != nums[fast]) {
                nums[++slow] = nums[fast];
            }
        }
        return slow + 1;
    }
    //离散化
    void lis(int[] a) {
        int[] temp = new int[n];
        System.arraycopy(a, 0, temp, 0, n);
        Arrays.sort(temp);
        max = adjacentRemove(temp);
        for (int i = 0; i < n; ++i) {
            // 映射到 [1, max-1] 的区间上
            a[i] = Arrays.binarySearch(temp,0, max, a[i]) + 1;
        }
    }
    public List<Integer> countSmaller(int[] nums) {
        n = nums.length;
        lis(nums);
        SegmentTree seg = new SegmentTree(max);
        List<Integer> ans = new ArrayList<Integer>(n);
        // 找右侧 有多少个 比 当前数 小 的数
        for (int i = n - 1; i >= 0; --i) {
            seg.modify(nums[i], 1);
            if (nums[i] - 1 == 0) ans.add(0);
            else ans.add(seg.query(1, nums[i] - 1).sum);
        }
        Collections.reverse(ans);
        return ans;
    }
}

class SegmentTree {
    class node {
        int sum, add;

        public node(int _sum, int _add) {
            sum = _sum;
            add = _add;
        }

        public node(int _sum) {this(_sum, 0);}

        public node() {this(0, 0);}

        void apply(int l,int r,final int val) {
             sum += (r - l + 1) * val;
             add += val;
        }
    }

    void pushDown(int o, int x, int y,int l,int r) {
         if (tree[o].add != 0) {
             int mid = l + r >> 1;
             tree[x].apply(l, mid, tree[o].add);
             tree[y].apply(mid + 1, r, tree[o].add);
             tree[o].add = 0;
         }
    }

    void unite(node o, final node son) {
         o.sum += son.sum;
    }

    void pushUp(int o, int x, int y) {
         tree[o].sum = 0;
        unite(tree[o], tree[x]);
        unite(tree[o], tree[y]);
    }

    public void build(int o, int l, int r) {
        if (l == r) {
            tree[o] = new node();
            return;
        }
        tree[o] = new node();
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        build(x, l, mid);
        build(y, mid + 1, r);
        pushUp(o, x, y);
    }

    public void build(int o, int l, int r, final int[] val) {
        if (l == r) {
            tree[o] = new node(val[l]);
            return;
        }
        tree[o] = new node();
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        build(x, l, mid, val);
        build(y, mid + 1, r, val);
        pushUp(o, x, y);
    }

    private void modify(int o, int l, int r, final int left, final int right, final int val) {
        if (left <= l && r <= right) {
             tree[o].apply(l, r, val);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
         pushDown(o, x, y, l, r);
        if (left <= mid) modify(x, l, mid, left, right, val);
        if (right > mid) modify(y, mid + 1, r, left, right, val);
        pushUp(o, x, y);
    }

    private void query(int o, int l, int r, final int left, final int right, node res) {
        if (left <= l && r <= right) {
            unite(res, tree[o]);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
         pushDown(o, x, y, l, r);
        if (left <= mid) query(x, l, mid, left, right, res);
        if (right > mid) query(y, mid + 1, r, left, right, res);
        pushUp(o, x, y);
    }

    int n;
    node[] tree;

    // 不传入数组的默认建树, 请保证node处进行了默认初始化
    public SegmentTree(int _n) {
        // assert(n > 0);
        n = _n;
        tree = new node[n << 1];
        build(1, 1, n);
    }

    // 单点修改
    void modify(int index, final int val) {
        // assert (1 <= index && index <= n);
        modify(1, 1, n, index, index, val);
    }

    // 区间查询
    public node query(int left, int right) {
        // assert (1 <= left && left <= right && right <= n);
        node res = new node();
        query(1, 1, n, left, right, res);
        return res;
    }
}

Java 动态开点

点击查看代码
class Solution {
    public List<Integer> countSmaller(int[] nums) {
        final int max = (int) 1e4 + 1;
        int n = nums.length;
        // 有负数, 整体都加上 max, 保证数都大于 0
        SegmentTree seg = new SegmentTree(2 * max);
        List<Integer> ans = new ArrayList<Integer>(n);
        // 找右侧 有多少个 比 当前数 小 的数
        for (int i = n - 1; i >= 0; --i) {
            int val = nums[i] + max;
            seg.modify(val, 1);
            if (val - 1 == 0) ans.add(0);
            else ans.add(seg.query(1, val - 1).sum);
        }
        Collections.reverse(ans);
        return ans;
    }
}

class SegmentTree {
    class node {
        int sum = 0, add = 0;
        node lChild, rChild;

        void apply(int l, int r, final int val) {
            sum += (r - l + 1) * val;
            add += val;
        }

        public void addNode() {
            if (lChild == null) lChild = new node();
            if (rChild == null) rChild = new node();
        }
    }

    void pushDown(node cur, int l, int r) {
        if (cur.add != 0) {
            int mid = l + r >> 1;
            cur.lChild.apply(l, mid, cur.add);
            cur.rChild.apply(mid + 1, r, cur.add);
            cur.add = 0;
        }
    }

    void unite(node cur, final node son) {
        cur.sum += son.sum;
    }

    void pushUp(node cur) {
        cur.sum = 0;
        unite(cur, cur.lChild);
        unite(cur, cur.rChild);
    }

    private void modify(node cur, int l, int r, final int left, final int right, final int val) {
        if (left <= l && r <= right) {
            cur.apply(l, r, val);
            return;
        }
        cur.addNode();
        pushDown(cur, l, r);
        int mid = l + r >> 1;
        if (left <= mid) modify(cur.lChild, l, mid, left, right, val);
        if (right > mid) modify(cur.rChild, mid + 1, r, left, right, val);
        pushUp(cur);
    }

    private void query(node cur, int l, int r, final int left, final int right, node res) {
        if (left <= l && r <= right) {
            unite(res, cur);
            return;
        }
        cur.addNode();
        pushDown(cur, l, r);
        int mid = l + r >> 1;
        if (left <= mid) query(cur.lChild, l, mid, left, right, res);
        if (right > mid) query(cur.rChild, mid + 1, r, left, right, res);
        pushUp(cur);
    }

    int n;
    node root;

    // 请保证node处进行了默认初始化
    public SegmentTree(int _n) {
        // assert(n > 0);
        n = _n;
        root = new node();
    }

    // 单点修改
    public void modify(int index, final int val) {
        // assert (1 <= index && index <= n);
        modify(root, 1, n, index, index, val);
    }

    // 区间查询
    public node query(int left, int right) {
        // assert (1 <= left && left <= right && right <= n);
        node res = new node();
        query(root, 1, n, left, right, res);
        return res;
    }
}

307. 区域和检索 - 数组可修改 - 力扣

题目链接

题意简述:有两个操作

  1. 单点赋值
  2. 区间和查询

单点赋值可以改为单点查询+单点修改(查询这个值再减去这个值)

当然也可以多开一个 \(nums\) 数组维护单点值

单点修改加区间查询,懒标记都不需要

再看数据范围,\(1 <= nums.length <= 3 \times 10^4\),离散化、动态开点也不需要

点击查看代码
class NumArray {
    int n;
    SegmentTree seg;
    int[] nums;
    public NumArray(int[] _nums) {
        n = _nums.length;
        seg = new SegmentTree(_nums, n);
        nums = _nums;
    }
    
    public void update(int index, int val) {
        seg.updateOne(index + 1, val - nums[index], 1, 1, n);
        nums[index] = val;
    }
    
    public int sumRange(int left, int right) {
        return seg.queryRange(left + 1, right + 1, 1, 1, n);
    }
}

class SegmentTree {
    int[] tree;
    int n;

    private void pushUp(int o, int x, int y) {
        tree[o] = tree[x] + tree[y];
    }

    public void build(int o, int l, int r, int[] val) {
        if (l == r) {
            tree[o] = val[l - 1];
            return;
        }
        int mid = l + r >> 1, x = o << 1, y = x | 1;
        build(x, l, mid, val);
        build(y, mid + 1, r, val);
        pushUp(o, x, y);
    }

    public SegmentTree(int _n) {
        n = _n;
        tree = new int[n << 2];
    }

    public SegmentTree(int[] val, int _n) {
        this(_n);
        build(1, 1, n, val);
    }

    public void updateOne(final int index, final int val, int o, int l, int r) {
        if (l == r) {
            tree[o] += val;
            return;
        }
        final int mid = l + r >> 1, x = o << 1, y = x | 1;
        if (index <= mid) updateOne(index, val, x, l, mid);
        else updateOne(index, val, y, mid + 1, r);
        pushUp(o, x, y);
    }

    public int queryRange(final int left, final int right, int o, int l, int r) {
        if (left <= l && r <= right) return tree[o];
        final int mid = l + r >> 1, x = o << 1, y = x | 1;
        int ans = 0;
        if (left <= mid) ans += queryRange(left, right, x, l, mid);
        if (right > mid) ans += queryRange(left, right, y, mid + 1, r);
        return ans;
    }
}

699. 掉落的方块 - 力扣

题目链接

题意简述:

俄罗斯方块,从上向下降落正方形,有交集则垫高,擦边的不算有交集,求最高的高度

每个右边界都\(-1\),解决擦边的问题

先查找 \([left,right]\) 区间内的最大值,再将 \([left,right]\) 赋值为该最大值+方块尺寸

题目变为 区间赋值 和 区间最大值查询(注意是区间赋值)

Java 动态开点

点击查看代码
class Solution {
    public List<Integer> fallingSquares(int[][] positions) {
        List<Integer> ans = new ArrayList<Integer>(positions.length);
        final int max = ((int) 1e8) + ((int) 1e6);
        SegmentTree seg = new SegmentTree(max);
        for (int[] v : positions) {
            int left = v[0], len = v[1], right = left + len - 1;
            int currentMax = seg.query(left, right).max;
            seg.modify(left, right, currentMax + len);
            ans.add(seg.query(1, max).max);
        }
        return ans;
    }
}

class SegmentTree {
    class node {
        int max = 0, assign = 0;
        node lChild, rChild;

        private void apply(final int val) {
            max = val;
            assign = val;
        }

        private void addNode() {
            if (lChild == null) lChild = new node();
            if (rChild == null) rChild = new node();
        }
    }

    private void pushDown(node cur) {
        if (cur.assign != 0) {
            cur.lChild.apply(cur.assign);
            cur.rChild.apply(cur.assign);
            cur.assign = 0;
        }
    }

    private void unite(node cur, final node son) {
        cur.max = Math.max(cur.max, son.max);
    }

    private void pushUp(node cur) {
        cur.max = 0;
        unite(cur, cur.lChild);
        unite(cur, cur.rChild);
    }

    private void modify(node cur, int l, int r, final int left, final int right, final int val) {
        if (left <= l && r <= right) {
            cur.apply(val);
            return;
        }
        cur.addNode();
        pushDown(cur);
        int mid = l + r >> 1;
        if (left <= mid) modify(cur.lChild, l, mid, left, right, val);
        if (right > mid) modify(cur.rChild, mid + 1, r, left, right, val);
        pushUp(cur);
    }

    private void query(node cur, int l, int r, final int left, final int right, node res) {
        if (left <= l && r <= right) {
            unite(res, cur);
            return;
        }
        cur.addNode();
        pushDown(cur);
        int mid = l + r >> 1;
        if (left <= mid) query(cur.lChild, l, mid, left, right, res);
        if (right > mid) query(cur.rChild, mid + 1, r, left, right, res);
        pushUp(cur);
    }

    private int n;
    private node root;

    // 请保证node处进行了默认初始化
    public SegmentTree(int _n) {
        // assert(n > 0);
        n = _n;
        root = new node();
    }

    // 区间修改
    public void modify(int left, int right, final int val) {
        // assert (1 <= left && left <= right && right <= n);
        modify(root, 1, n, left, right, val);
    }

    // 区间查询
    public node query(int left, int right) {
        // assert (1 <= left && left <= right && right <= n);
        node res = new node();
        query(root, 1, n, left, right, res);
        return res;
    }
}

Java 离散化

点击查看代码
class Solution {
    int[] rank;
    int cnt;

    int adjacentRemove(int[] a, int n) {
        int slow = 0;
        for (int fast = slow + 1; fast < n; ++fast) {
            if (a[slow] != a[fast] && ++slow != fast) {
                a[slow] = a[fast];
            }
        }
        return slow + 1;
    }

    void discrete(int[][] positions, int n) {
        rank = new int[n << 1];
        for (int i = 0; i < n; ++i) {
            rank[i << 1] = positions[i][0];
            rank[i << 1 | 1] = positions[i][0] + positions[i][1] - 1;
        }
        Arrays.sort(rank);
        cnt = adjacentRemove(rank, n << 1);
    }

    // 查找映射(大于0)
    int find(int val) {
        return Arrays.binarySearch(rank, 0, cnt, val) + 1;
    }

    public List<Integer> fallingSquares(int[][] positions) {
        int n = positions.length;
        List<Integer> ans = new ArrayList<Integer>(n);
        discrete(positions, n);
        SegmentTree seg = new SegmentTree(cnt);
        for (int[] v : positions) {
            int left = find(v[0]), len = v[1], right = find(v[0] + len - 1);
            int currentMax = seg.query(left, right).max;
            seg.modify(left, right, currentMax + len);
            ans.add(seg.query(1, cnt).max);
        }
        return ans;
    }
}

class SegmentTree {
    class node {
        int max, assign;

        public node(int _max, int _assign) {
            max = _max;
            assign = _assign;
        }

        public node(int _max) {this(_max, 0);}

        public node() {this(0, 0);}

        void apply(final int val) {
            max = val;
            assign = val;
        }
    }

    private void pushDown(int o, int x, int y) {
        if (tree[o].assign != 0) {
            tree[x].apply(tree[o].assign);
            tree[y].apply(tree[o].assign);
            tree[o].assign = 0;
        }
    }

    private void unite(node o, final node son) {
        o.max = Math.max(o.max, son.max);
    }

    private void pushUp(int o, int x, int y) {
        tree[o].assign = 0;
        unite(tree[o], tree[x]);
        unite(tree[o], tree[y]);
    }

    public void build(int o, int l, int r) {
        if (l == r) {
            tree[o] = new node();
            return;
        }
        tree[o] = new node();
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        build(x, l, mid);
        build(y, mid + 1, r);
        pushUp(o, x, y);
    }

    public void build(int o, int l, int r, final int[] val) {
        if (l == r) {
            tree[o] = new node(val[l]);
            return;
        }
        tree[o] = new node();
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        build(x, l, mid, val);
        build(y, mid + 1, r, val);
        pushUp(o, x, y);
    }

    private void modify(int o, int l, int r, final int left, final int right, final int val) {
        if (left <= l && r <= right) {
            tree[o].apply(val);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        pushDown(o, x, y);
        if (left <= mid) modify(x, l, mid, left, right, val);
        if (right > mid) modify(y, mid + 1, r, left, right, val);
        pushUp(o, x, y);
    }

    private void query(int o, int l, int r, final int left, final int right, node res) {
        if (left <= l && r <= right) {
            unite(res, tree[o]);
            return;
        }
        int mid = l + r >> 1, x = o + 1, y = o + ((r - l + 2) & ~1);
        pushDown(o, x, y);
        if (left <= mid) query(x, l, mid, left, right, res);
        if (right > mid) query(y, mid + 1, r, left, right, res);
        pushUp(o, x, y);
    }

    private int n;
    private node[] tree;

    // 不传入数组的默认建树, 请保证node处进行了默认初始化
    public SegmentTree(int _n) {
        // assert(n > 0);
        n = _n;
        tree = new node[n << 1];
        build(1, 1, n);
    }

    // 区间修改
    public void modify(int left, int right, final int val) {
        // assert (1 <= left && left <= right && right <= n);
        modify(1, 1, n, left, right, val);
    }

    // 区间查询
    public node query(int left, int right) {
        // assert (1 <= left && left <= right && right <= n);
        node res = new node();
        query(1, 1, n, left, right, res);
        return res;
    }
}

参考资料

线段树详解与实现 - 知乎

线段树详解 (原理,实现与应用) - AC_King

线段树从入门到急停 - yukiyama

一维线段树的2n空间实现

线段树节点个数的递推公式与通项公式 - Hoxily

关于线段树的数组到底是开2N还是4N - 知乎