题目链接

区间取反与区间数一

题目描述

给定一个长度为 的 01 字符串,你需要实现一个数据结构来支持以下两种操作:

  1. 区间取反:将区间 内的所有字符进行取反操作(0 变为 11 变为 0)。
  2. 区间数一:查询区间 内值为 1 的字符的个数。

输入:

  • 第一行包含两个整数 ,分别表示 01 串的长度和操作的次数。
  • 第二行是一个长度为 的 01 字符串。
  • 接下来 行,每行描述一个操作。格式为 1 l r2 l r

输出:

  • 对于每个查询操作,输出一行,表示对应区间中 1 的个数。

解题思路

本题是典型的数据结构问题,涉及区间修改和区间查询,非常适合使用线段树来解决。

线段树节点设计

对于线段树的每一个节点,我们需要存储以下信息来高效地处理操作:

  1. sum:表示该节点所代表的区间内 1 的个数。
  2. lazy:懒惰标记。lazy = 1 表示该区间需要进行取反操作,lazy = 0 表示无待处理操作。

操作实现

  1. 建树 (build)

    • 从根节点开始,递归地将区间一分为二,直到叶子节点。
    • 叶子节点 [i, i]sum 值就是初始字符串中第 个字符的值 (01)。
    • 非叶子节点的 sum 值是其左右子节点的 sum 值之和。
  2. 懒惰标记下推 (pushdown)

    • 当我们需要访问一个节点的子节点时(在修改或查询操作中),首先检查该节点是否有懒惰标记。
    • 如果 lazy = 1,说明这个区间需要取反。我们将这个标记传递给它的两个子节点(子节点的 lazy 标记也进行异或 1 操作)。
    • 同时,更新子节点的 sum 值。如果一个长度为 len 的区间被取反,那么新的 1 的个数就等于 len - (旧的 1 的个数)
    • 下推完成后,清除当前节点的懒惰标记 (lazy = 0)。
  3. 区间修改 (modify)

    • 当修改区间 [L, R] 时,如果当前节点代表的区间 [l, r] 完全被 [L, R] 覆盖,我们直接更新当前节点的 sum 值,并给它打上懒惰标记,然后返回。
    • 否则,先下推懒惰标记,然后根据 [L, R] 与当前区间中点的关系,递归地到左子树或右子树进行修改。
    • 修改完子节点后,用子节点的 sum 值更新当前节点的 sum 值(pushup)。
  4. 区间查询 (query)

    • 与修改操作类似。如果查询区间 [L, R] 完全覆盖当前节点区间,直接返回当前节点的 sum
    • 否则,先下推懒惰标记,然后递归地到左右子树查询,并将结果汇总返回。

通过这种方式,每次修改和查询操作的时间复杂度都可以控制在

代码

#include <iostream>
#include <vector>
#include <string>
#include <algorithm>

using namespace std;

struct Node {
    int sum;
    int lazy;
};

vector<int> a;
vector<Node> tr;
int n, m;

void pushup(int u) {
    tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}

void pushdown(int u, int l, int r) {
    if (tr[u].lazy) {
        int mid = l + r >> 1;
        // 更新左子节点
        tr[u << 1].lazy ^= 1;
        tr[u << 1].sum = (mid - l + 1) - tr[u << 1].sum;
        // 更新右子节点
        tr[u << 1 | 1].lazy ^= 1;
        tr[u << 1 | 1].sum = (r - mid) - tr[u << 1 | 1].sum;
        // 清除当前懒惰标记
        tr[u].lazy = 0;
    }
}

void build(int u, int l, int r) {
    if (l == r) {
        tr[u] = {a[l], 0};
        return;
    }
    int mid = l + r >> 1;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
    pushup(u);
}

void modify(int u, int l, int r, int ql, int qr) {
    if (ql <= l && r <= qr) {
        tr[u].lazy ^= 1;
        tr[u].sum = (r - l + 1) - tr[u].sum;
        return;
    }
    pushdown(u, l, r);
    int mid = l + r >> 1;
    if (ql <= mid) modify(u << 1, l, mid, ql, qr);
    if (qr > mid) modify(u << 1 | 1, mid + 1, r, ql, qr);
    pushup(u);
}

int query(int u, int l, int r, int ql, int qr) {
    if (ql <= l && r <= qr) {
        return tr[u].sum;
    }
    pushdown(u, l, r);
    int mid = l + r >> 1;
    int res = 0;
    if (ql <= mid) res += query(u << 1, l, mid, ql, qr);
    if (qr > mid) res += query(u << 1 | 1, mid + 1, r, ql, qr);
    return res;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    cin >> n >> m;
    a.resize(n + 1);
    tr.resize(n * 4 + 1);

    string s;
    cin >> s;
    for (int i = 0; i < n; i++) {
        a[i + 1] = s[i] - '0';
    }

    build(1, 1, n);

    while (m--) {
        int op, l, r;
        cin >> op >> l >> r;
        if (op == 1) {
            modify(1, 1, n, l, r);
        } else {
            cout << query(1, 1, n, l, r) << "\n";
        }
    }

    return 0;
}
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.IOException;
import java.util.StringTokenizer;

public class Main {
    static int n, m;
    static int[] a;
    static int[] sum;
    static int[] lazy;

    static void pushup(int u) {
        sum[u] = sum[u << 1] + sum[u << 1 | 1];
    }

    static void pushdown(int u, int l, int r) {
        if (lazy[u] == 1) {
            int mid = l + r >> 1;
            // 更新左子节点
            lazy[u << 1] ^= 1;
            sum[u << 1] = (mid - l + 1) - sum[u << 1];
            // 更新右子节点
            lazy[u << 1 | 1] ^= 1;
            sum[u << 1 | 1] = (r - mid) - sum[u << 1 | 1];
            // 清除当前懒惰标记
            lazy[u] = 0;
        }
    }

    static void build(int u, int l, int r) {
        if (l == r) {
            sum[u] = a[l];
            return;
        }
        int mid = l + r >> 1;
        build(u << 1, l, mid);
        build(u << 1 | 1, mid + 1, r);
        pushup(u);
    }

    static void modify(int u, int l, int r, int ql, int qr) {
        if (ql <= l && r <= qr) {
            lazy[u] ^= 1;
            sum[u] = (r - l + 1) - sum[u];
            return;
        }
        pushdown(u, l, r);
        int mid = l + r >> 1;
        if (ql <= mid) modify(u << 1, l, mid, ql, qr);
        if (qr > mid) modify(u << 1 | 1, mid + 1, r, ql, qr);
        pushup(u);
    }

    static int query(int u, int l, int r, int ql, int qr) {
        if (ql <= l && r <= qr) {
            return sum[u];
        }
        pushdown(u, l, r);
        int mid = l + r >> 1;
        int res = 0;
        if (ql <= mid) res += query(u << 1, l, mid, ql, qr);
        if (qr > mid) res += query(u << 1 | 1, mid + 1, r, ql, qr);
        return res;
    }

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        n = Integer.parseInt(st.nextToken());
        m = Integer.parseInt(st.nextToken());

        a = new int[n + 1];
        sum = new int[n * 4 + 1];
        lazy = new int[n * 4 + 1];

        String s = br.readLine();
        for (int i = 0; i < n; i++) {
            a[i + 1] = s.charAt(i) - '0';
        }

        build(1, 1, n);

        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < m; i++) {
            st = new StringTokenizer(br.readLine());
            int op = Integer.parseInt(st.nextToken());
            int l = Integer.parseInt(st.nextToken());
            int r = Integer.parseInt(st.nextToken());
            if (op == 1) {
                modify(1, 1, n, l, r);
            } else {
                sb.append(query(1, 1, n, l, r)).append("\n");
            }
        }
        System.out.print(sb.toString());
    }
}
import sys

def pushup(u):
    tr[u][0] = tr[u << 1][0] + tr[u << 1 | 1][0]

def pushdown(u, l, r):
    if tr[u][1]:
        mid = (l + r) >> 1
        # 更新左子节点
        tr[u << 1][1] ^= 1
        tr[u << 1][0] = (mid - l + 1) - tr[u << 1][0]
        # 更新右子节点
        tr[u << 1 | 1][1] ^= 1
        tr[u << 1 | 1][0] = (r - mid) - tr[u << 1 | 1][0]
        # 清除当前懒惰标记
        tr[u][1] = 0

def build(u, l, r):
    if l == r:
        tr[u] = [a[l], 0]
        return
    mid = (l + r) >> 1
    build(u << 1, l, mid)
    build(u << 1 | 1, mid + 1, r)
    pushup(u)

def modify(u, l, r, ql, qr):
    if ql <= l and r <= qr:
        tr[u][1] ^= 1
        tr[u][0] = (r - l + 1) - tr[u][0]
        return
    pushdown(u, l, r)
    mid = (l + r) >> 1
    if ql <= mid:
        modify(u << 1, l, mid, ql, qr)
    if qr > mid:
        modify(u << 1 | 1, mid + 1, r, ql, qr)
    pushup(u)

def query(u, l, r, ql, qr):
    if ql <= l and r <= qr:
        return tr[u][0]
    pushdown(u, l, r)
    mid = (l + r) >> 1
    res = 0
    if ql <= mid:
        res += query(u << 1, l, mid, ql, qr)
    if qr > mid:
        res += query(u << 1 | 1, mid + 1, r, ql, qr)
    return res

# 读取输入
input = sys.stdin.readline
n, m = map(int, input().split())
s = input().strip()
a = [0] * (n + 1)
for i in range(n):
    a[i + 1] = int(s[i])

# tr[i][0] for sum, tr[i][1] for lazy tag
tr = [[0, 0] for _ in range(n * 4 + 1)]

build(1, 1, n)

for _ in range(m):
    op, l, r = map(int, input().split())
    if op == 1:
        modify(1, 1, n, l, r)
    else:
        print(query(1, 1, n, l, r))

算法及复杂度

  • 算法:线段树 (Segment Tree)
  • 时间复杂度:建树的时间复杂度为 。共有 次操作,每次区间修改和区间查询的时间复杂度均为 。因此,总时间复杂度为
  • 空间复杂度:需要一个数组存储线段树,大小约为 ,所以空间复杂度为