线段树概念

线段树(segment Tree) 是一种基于分治思想的二叉树结构,用于区间上进行信息统计用的。

数据结构定义

  1. 线段树每个节点都代表一个区间。
  2. 线段树具有唯一根节点,代表整个区间。
  3. 线段树每个叶子节点代表一个长度为1的元区间。
  4. 对于每个内部节点,它的左子节点是[l, mid], 右子节点是[mid + 1, r], mid = (l + r ) / 2 ,向下取整。

数据结构操作

线段树的建立

关键在于递归的自下向上进行数据传递
当前节点的数据 = 左子树+ 右子树。
当然由于数据的不同, 数据的加法的方式也不同,
例如:

  1. 如果维护的区间最大值 : 当前的值= max(左子树, 右子树)
  2. 如果是维护区间和: 当前值 = 左子树 + 右子树。

拿这个题目为例子

struct SegmentTree {
   
    int l, r;
    long long int sum, add;
    #define l(x) t[x].l
    #define r(x) t[x].r
    #define sum(x) t[x].sum
    #define add(x) t[x].add
} t [SIZE * 4];

void build(int p, int l, int r) {
   
    l(p) = l, r(p) = r;
    if (l == r) {
   sum(p) = arr[l]; return;}
    int mid = (l + r) / 2;
    build(p * 2, l, mid);
    build(p * 2 + 1, mid + 1, r);
    sum(p) = sum(p * 2) + sum(p * 2 + 1);
    return;
}

线段树的修改

线段树的单点修改

就是先找到这个点,然后自底向上的进行修改
例如,问题是求区间最大和,然后修改某个节点的值。

void change(int p, int x, int v) {
   
    if (t[p].l == t[p].r) {
   t[p].data = v; return;}
    int mid = (t[p].l + t[p].r) / 2;
    if (x <= mid) change(p * 2, x, v);
    else change(p * 2 + 1, x, v);
    t[p].data = max(t[p * 2].data, t[p * 2 + 1].data);
    return;
}

线段数的延迟标记 + 区间修改

如果每次对区间进行修改需要落实到区间的每个元素上,那么时间复杂度会达到O(n),这是比较低效的行为。

那么针对这个问题,如果更新了节点p代表的区间[pl,pr],并将子树p中的所有节点都进行了更新,但是之后的查询过程中,根本没有用到这些区间的值,那么更新子树的行为就是徒劳的行为。

所以我们通过增加一个标识,标识该节点曾经被修改,但是其子节点没有被更新。

如果在后续的指令中,需要从节点p向下递归,我们再检查p是否具有标记,若有标记,就更新p的两个子节点,同时标记p的两个子节点,清楚p的标记。

也就是说,除了在修改指令中直接划分成的O(logN)个节点外,对任意节点的修改都延迟到“在后续操作中递归进入它的父节点时候”再执行。

例如这道模板题的延迟更新如下:

void spread(int p) {
   
    if (add(p)) {
   
        sum(p * 2) += add(p) * (r(p * 2) - l(p * 2) + 1);
        sum(p * 2 + 1) += add(p) * (r(p * 2 + 1) - l(p * 2 + 1) + 1);
        add(p * 2) += add(p);
        add(p * 2 + 1) += add(p);
        add(p) = 0;
    }
}

区间更改如下:

void change(int p, int l, int r, int d) {
   
    if (l <= l(p) && r >= r(p)) {
   
        sum(p) += (long long)d * (r(p) - l(p) + 1);
        add(p) += d;
        return;
    }
    spread(p);
    int mid = (l(p) + r(p)) / 2;
    if (l <= mid) change(p * 2, l, r, d);
    if (mid < r) change(p * 2 + 1, l, r, d);
    sum(p) = sum(p * 2) + sum(p * 2 + 1);
    return;
}

查找

  1. 若[l,r]覆盖当前节点代表的区间就立刻返回回溯。
  2. 若左子节点与[l,r]有交集,则递归左子树
  3. 若右子节点与[l,r]有交集,则递归右子树
long long ask(int p, int l, int r) {
   
    if (l <= l(p) && r >= r(p)) {
   
        return sum(p);
    }
    spread(p);
    int mid = (l(p) + r(p)) / 2;
    long long val = 0;
    if (l <= mid) val += ask(p * 2, l, r);
    if (mid < r) val += ask(p * 2 + 1, l, r);
    return val;
}

完整代码

#include <iostream>

using namespace std;

#define SIZE 100005

long long int arr[SIZE] = {
   0};

struct SegmentTree {
   
    int l, r;
    long long int sum, add;
    #define l(x) t[x].l
    #define r(x) t[x].r
    #define sum(x) t[x].sum
    #define add(x) t[x].add
} t [SIZE * 4];

void build(int p, int l, int r) {
   
    l(p) = l, r(p) = r;
    if (l == r) {
   sum(p) = arr[l]; return;}
    int mid = (l + r) / 2;
    build(p * 2, l, mid);
    build(p * 2 + 1, mid + 1, r);
    sum(p) = sum(p * 2) + sum(p * 2 + 1);
    return;
}

void spread(int p) {
   
    if (add(p)) {
   
        sum(p * 2) += add(p) * (r(p * 2) - l(p * 2) + 1);
        sum(p * 2 + 1) += add(p) * (r(p * 2 + 1) - l(p * 2 + 1) + 1);
        add(p * 2) += add(p);
        add(p * 2 + 1) += add(p);
        add(p) = 0;
    }
}

void change(int p, int l, int r, int d) {
   
    if (l <= l(p) && r >= r(p)) {
   
        sum(p) += (long long)d * (r(p) - l(p) + 1);
        add(p) += d;
        return;
    }
    spread(p);
    int mid = (l(p) + r(p)) / 2;
    if (l <= mid) change(p * 2, l, r, d);
    if (mid < r) change(p * 2 + 1, l, r, d);
    sum(p) = sum(p * 2) + sum(p * 2 + 1);
    return;
}

long long ask(int p, int l, int r) {
   
    if (l <= l(p) && r >= r(p)) {
   
        return sum(p);
    }
    spread(p);
    int mid = (l(p) + r(p)) / 2;
    long long val = 0;
    if (l <= mid) val += ask(p * 2, l, r);
    if (mid < r) val += ask(p * 2 + 1, l, r);
    return val;
}

int main() {
   
    int N, M;
    cin >> N >> M;
    for (int i = 1; i <= N; i++) {
   
        cin >> arr[i];
    }
    build(1, 1, N);
    int op, x, y, d;
    for (int i = 0; i < M; i++) {
   
        cin >> op >> x >> y;
        switch(op) {
   
            case 1:{
   cin >> d;change(1, x, y, d);break;}
            case 2:{
   cout << ask(1, x, y) << endl;break;}
        }
    }
    return 0;
}