题目链接

单点修改与区间非平凡异或和

题目描述

给定一个长度为 的数组 ,你需要实现一个数据结构来支持以下两种操作:

  1. 单点修改:将下标为 的元素 修改为
  2. 区间非平凡异或和查询:查询区间 内所有连续子序列的异或和的异或和。 形式化地,查询 的值。

输入:

  • 第一行包含两个整数 ,分别表示数组的长度和操作的次数。
  • 第二行包含 个整数,表示数组的初始元素。
  • 接下来 行,每行描述一个操作。格式为 1 x k2 l r

输出:

  • 对于每个查询操作,输出一行,表示对应的查询结果。

解题思路

这道题的难点在于理解和简化“区间非平凡异或和”的查询。

1. 简化查询表达式

我们首先分析在最终的异或和中,每个元素 (其中 ) 贡献了多少次。

会被包含在任何一个以 为起点、以 为终点,且满足 的子序列中。

  • 起点 的选择有 种(从 )。
  • 终点 的选择有 种(从 )。

所以, 在所有子序列中总共出现了 次。

在异或运算中,一个数出现偶数次,结果为 ;出现奇数次,结果为它本身。因此,我们只需要关心出现次数是奇数还是偶数。

贡献到最终结果中,当且仅当 为奇数。

一个乘积为奇数,当且仅当它的所有因子都为奇数。所以,必须同时满足:

  1. 为奇数 为偶数 的奇偶性相同。
  2. 为奇数 为偶数 的奇偶性相同。

综合以上两点,我们得出结论: 被包含在最终的异或和中,当且仅当 的奇偶性与 的奇偶性都相同。

这引出了对查询区间的分析:

  • 如果 的奇偶性不同:找不到任何一个 能同时满足与 奇偶性都相同。因此,没有任何元素会被计入,结果为
  • 如果 的奇偶性相同:我们需要求出所有在 区间内,且与 (和 ) 奇偶性相同的 对应的 的异或和。例如,如果 都是奇数,则查询结果是

2. 数据结构选择

问题转化为了:单点修改,以及对区间内所有奇数(或偶数)下标的元素求异或和。

我们可以将原数组按奇偶下标分成两个独立的数组,并对这两个数组分别建立树状数组(Fenwick Tree) 来维护区间异或和。

  • tree_odd:维护所有奇数下标的元素。原数组中下标为 的元素,对应到这个树状数组中的下标是
  • tree_even:维护所有偶数下标的元素。原数组中下标为 的元素,对应到这个树状数组中的下标是

操作实现

  • 单点修改 1 x k:
    1. 获取 的旧值 old_val
    2. 计算需要更新的值 delta = old_val \oplus k
    3. 如果 是奇数,则在 tree_odd 的对应位置 (x+1)/2 更新 delta
    4. 如果 是偶数,则在 tree_even 的对应位置 x/2 更新 delta
    5. 更新
  • 区间查询 2 l r:
    1. 检查 的奇偶性。如果不同,输出
    2. 如果都是奇数,在 tree_odd 中查询区间 的异或和。
    3. 如果都是偶数,在 tree_even 中查询区间 的异或和。

代码

#include <iostream>
#include <vector>

using namespace std;

using LL = long long;

vector<LL> tr_odd, tr_even;
vector<LL> a;
int n_odd, n_even;

int lowbit(int x) {
    return x & -x;
}

void add(vector<LL>& tr, int size, int x, LL k) {
    for (int i = x; i <= size; i += lowbit(i)) {
        tr[i] ^= k;
    }
}

LL query(const vector<LL>& tr, int x) {
    LL res = 0;
    for (int i = x; i > 0; i -= lowbit(i)) {
        res ^= tr[i];
    }
    return res;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    int n, m;
    cin >> n >> m;

    n_odd = (n + 1) / 2;
    n_even = n / 2;

    a.resize(n + 1);
    tr_odd.resize(n_odd + 1, 0);
    tr_even.resize(n_even + 1, 0);

    for (int i = 1; i <= n; i++) {
        cin >> a[i];
        if (i % 2 != 0) {
            add(tr_odd, n_odd, (i + 1) / 2, a[i]);
        } else {
            add(tr_even, n_even, i / 2, a[i]);
        }
    }

    while (m--) {
        int op;
        cin >> op;
        if (op == 1) {
            int x;
            LL k;
            cin >> x >> k;
            LL delta = a[x] ^ k;
            a[x] = k;
            if (x % 2 != 0) {
                add(tr_odd, n_odd, (x + 1) / 2, delta);
            } else {
                add(tr_even, n_even, x / 2, delta);
            }
        } else {
            int l, r;
            cin >> l >> r;
            if ((l % 2) != (r % 2)) {
                cout << 0 << "\n";
            } else {
                LL ans = 0;
                if (l % 2 != 0) {
                    int l_idx = (l + 1) / 2;
                    int r_idx = (r + 1) / 2;
                    ans = query(tr_odd, r_idx) ^ query(tr_odd, l_idx - 1);
                } else {
                    int l_idx = l / 2;
                    int r_idx = r / 2;
                    ans = query(tr_even, r_idx) ^ query(tr_even, l_idx - 1);
                }
                cout << ans << "\n";
            }
        }
    }

    return 0;
}
import java.util.Scanner;

public class Main {
    static long[] tr_odd, tr_even, a;
    static int n_odd, n_even;

    static int lowbit(int x) {
        return x & -x;
    }

    static void add(long[] tr, int size, int x, long k) {
        for (int i = x; i <= size; i += lowbit(i)) {
            tr[i] ^= k;
        }
    }

    static long query(long[] tr, int x) {
        long res = 0;
        for (int i = x; i > 0; i -= lowbit(i)) {
            res ^= tr[i];
        }
        return res;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int m = sc.nextInt();

        n_odd = (n + 1) / 2;
        n_even = n / 2;
        
        a = new long[n + 1];
        tr_odd = new long[n_odd + 1];
        tr_even = new long[n_even + 1];

        for (int i = 1; i <= n; i++) {
            a[i] = sc.nextLong();
            if (i % 2 != 0) {
                add(tr_odd, n_odd, (i + 1) / 2, a[i]);
            } else {
                add(tr_even, n_even, i / 2, a[i]);
            }
        }

        while (m-- > 0) {
            int op = sc.nextInt();
            if (op == 1) {
                int x = sc.nextInt();
                long k = sc.nextLong();
                long delta = a[x] ^ k;
                a[x] = k;
                if (x % 2 != 0) {
                    add(tr_odd, n_odd, (x + 1) / 2, delta);
                } else {
                    add(tr_even, n_even, x / 2, delta);
                }
            } else {
                int l = sc.nextInt();
                int r = sc.nextInt();
                if ((l % 2) != (r % 2)) {
                    System.out.println(0);
                } else {
                    long ans = 0;
                    if (l % 2 != 0) {
                        int l_idx = (l + 1) / 2;
                        int r_idx = (r + 1) / 2;
                        ans = query(tr_odd, r_idx) ^ query(tr_odd, l_idx - 1);
                    } else {
                        int l_idx = l / 2;
                        int r_idx = r / 2;
                        ans = query(tr_even, r_idx) ^ query(tr_even, l_idx - 1);
                    }
                    System.out.println(ans);
                }
            }
        }
    }
}
import sys

def lowbit(x):
    return x & -x

def add(tr, size, x, k):
    while x <= size:
        tr[x] ^= k
        x += lowbit(x)

def query(tr, x):
    res = 0
    while x > 0:
        res ^= tr[x]
        x -= lowbit(x)
    return res

# 读取所有输入
lines = sys.stdin.readlines()
line_idx = 0

n, m = map(int, lines[line_idx].split())
line_idx += 1
arr_input = list(map(int, lines[line_idx].split()))
line_idx += 1

n_odd = (n + 1) // 2
n_even = n // 2

a = [0] * (n + 1)
tr_odd = [0] * (n_odd + 1)
tr_even = [0] * (n_even + 1)

for i in range(1, n + 1):
    a[i] = arr_input[i - 1]
    if i % 2 != 0:
        add(tr_odd, n_odd, (i + 1) // 2, a[i])
    else:
        add(tr_even, n_even, i // 2, a[i])

for _ in range(m):
    parts = list(map(int, lines[line_idx].split()))
    line_idx += 1
    op = parts[0]

    if op == 1:
        x, k = parts[1], parts[2]
        delta = a[x] ^ k
        a[x] = k
        if x % 2 != 0:
            add(tr_odd, n_odd, (x + 1) // 2, delta)
        else:
            add(tr_even, n_even, x // 2, delta)
    else:
        l, r = parts[1], parts[2]
        if (l % 2) != (r % 2):
            print(0)
        else:
            ans = 0
            if l % 2 != 0:
                l_idx = (l + 1) // 2
                r_idx = (r + 1) // 2
                ans = query(tr_odd, r_idx) ^ query(tr_odd, l_idx - 1)
            else:
                l_idx = l // 2
                r_idx = r // 2
                ans = query(tr_even, r_idx) ^ query(tr_even, l_idx - 1)
            print(ans)

算法及复杂度

  • 算法:数学 + 树状数组
  • 时间复杂度:初始化需要 。共有 次操作,每次单点修改或区间查询的时间复杂度都为 。因此,总时间复杂度为
  • 空间复杂度:需要一个数组存储原数组,以及两个树状数组,总大小与 线性相关,所以空间复杂度为