题目链接
题目描述
给定一个长度为 的数组
,你需要实现一个数据结构来支持以下两种操作:
- 单点修改:将下标为
的元素
修改为
。
- 区间非平凡异或和查询:查询区间
内所有连续子序列的异或和的异或和。 形式化地,查询
的值。
输入:
- 第一行包含两个整数
和
,分别表示数组的长度和操作的次数。
- 第二行包含
个整数,表示数组的初始元素。
- 接下来
行,每行描述一个操作。格式为
1 x k
或2 l r
。
输出:
- 对于每个查询操作,输出一行,表示对应的查询结果。
解题思路
这道题的难点在于理解和简化“区间非平凡异或和”的查询。
1. 简化查询表达式
我们首先分析在最终的异或和中,每个元素 (其中
) 贡献了多少次。
会被包含在任何一个以
为起点、以
为终点,且满足
的子序列中。
- 起点
的选择有
种(从
到
)。
- 终点
的选择有
种(从
到
)。
所以, 在所有子序列中总共出现了
次。
在异或运算中,一个数出现偶数次,结果为 ;出现奇数次,结果为它本身。因此,我们只需要关心出现次数是奇数还是偶数。
贡献到最终结果中,当且仅当
为奇数。
一个乘积为奇数,当且仅当它的所有因子都为奇数。所以,必须同时满足:
为奇数
为偶数
和
的奇偶性相同。
为奇数
为偶数
和
的奇偶性相同。
综合以上两点,我们得出结论: 被包含在最终的异或和中,当且仅当
的奇偶性与
和
的奇偶性都相同。
这引出了对查询区间的分析:
- 如果
和
的奇偶性不同:找不到任何一个
能同时满足与
和
奇偶性都相同。因此,没有任何元素会被计入,结果为
。
- 如果
和
的奇偶性相同:我们需要求出所有在
区间内,且与
(和
) 奇偶性相同的
对应的
的异或和。例如,如果
都是奇数,则查询结果是
。
2. 数据结构选择
问题转化为了:单点修改,以及对区间内所有奇数(或偶数)下标的元素求异或和。
我们可以将原数组按奇偶下标分成两个独立的数组,并对这两个数组分别建立树状数组(Fenwick Tree) 来维护区间异或和。
tree_odd
:维护所有奇数下标的元素。原数组中下标为的元素,对应到这个树状数组中的下标是
。
tree_even
:维护所有偶数下标的元素。原数组中下标为的元素,对应到这个树状数组中的下标是
。
操作实现:
- 单点修改
1 x k
:- 获取
的旧值
old_val
。 - 计算需要更新的值
delta = old_val \oplus k
。 - 如果
是奇数,则在
tree_odd
的对应位置(x+1)/2
更新delta
。 - 如果
是偶数,则在
tree_even
的对应位置x/2
更新delta
。 - 更新
。
- 获取
- 区间查询
2 l r
:- 检查
和
的奇偶性。如果不同,输出
。
- 如果都是奇数,在
tree_odd
中查询区间的异或和。
- 如果都是偶数,在
tree_even
中查询区间的异或和。
- 检查
代码
#include <iostream>
#include <vector>
using namespace std;
using LL = long long;
vector<LL> tr_odd, tr_even;
vector<LL> a;
int n_odd, n_even;
int lowbit(int x) {
return x & -x;
}
void add(vector<LL>& tr, int size, int x, LL k) {
for (int i = x; i <= size; i += lowbit(i)) {
tr[i] ^= k;
}
}
LL query(const vector<LL>& tr, int x) {
LL res = 0;
for (int i = x; i > 0; i -= lowbit(i)) {
res ^= tr[i];
}
return res;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(NULL);
int n, m;
cin >> n >> m;
n_odd = (n + 1) / 2;
n_even = n / 2;
a.resize(n + 1);
tr_odd.resize(n_odd + 1, 0);
tr_even.resize(n_even + 1, 0);
for (int i = 1; i <= n; i++) {
cin >> a[i];
if (i % 2 != 0) {
add(tr_odd, n_odd, (i + 1) / 2, a[i]);
} else {
add(tr_even, n_even, i / 2, a[i]);
}
}
while (m--) {
int op;
cin >> op;
if (op == 1) {
int x;
LL k;
cin >> x >> k;
LL delta = a[x] ^ k;
a[x] = k;
if (x % 2 != 0) {
add(tr_odd, n_odd, (x + 1) / 2, delta);
} else {
add(tr_even, n_even, x / 2, delta);
}
} else {
int l, r;
cin >> l >> r;
if ((l % 2) != (r % 2)) {
cout << 0 << "\n";
} else {
LL ans = 0;
if (l % 2 != 0) {
int l_idx = (l + 1) / 2;
int r_idx = (r + 1) / 2;
ans = query(tr_odd, r_idx) ^ query(tr_odd, l_idx - 1);
} else {
int l_idx = l / 2;
int r_idx = r / 2;
ans = query(tr_even, r_idx) ^ query(tr_even, l_idx - 1);
}
cout << ans << "\n";
}
}
}
return 0;
}
import java.util.Scanner;
public class Main {
static long[] tr_odd, tr_even, a;
static int n_odd, n_even;
static int lowbit(int x) {
return x & -x;
}
static void add(long[] tr, int size, int x, long k) {
for (int i = x; i <= size; i += lowbit(i)) {
tr[i] ^= k;
}
}
static long query(long[] tr, int x) {
long res = 0;
for (int i = x; i > 0; i -= lowbit(i)) {
res ^= tr[i];
}
return res;
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int m = sc.nextInt();
n_odd = (n + 1) / 2;
n_even = n / 2;
a = new long[n + 1];
tr_odd = new long[n_odd + 1];
tr_even = new long[n_even + 1];
for (int i = 1; i <= n; i++) {
a[i] = sc.nextLong();
if (i % 2 != 0) {
add(tr_odd, n_odd, (i + 1) / 2, a[i]);
} else {
add(tr_even, n_even, i / 2, a[i]);
}
}
while (m-- > 0) {
int op = sc.nextInt();
if (op == 1) {
int x = sc.nextInt();
long k = sc.nextLong();
long delta = a[x] ^ k;
a[x] = k;
if (x % 2 != 0) {
add(tr_odd, n_odd, (x + 1) / 2, delta);
} else {
add(tr_even, n_even, x / 2, delta);
}
} else {
int l = sc.nextInt();
int r = sc.nextInt();
if ((l % 2) != (r % 2)) {
System.out.println(0);
} else {
long ans = 0;
if (l % 2 != 0) {
int l_idx = (l + 1) / 2;
int r_idx = (r + 1) / 2;
ans = query(tr_odd, r_idx) ^ query(tr_odd, l_idx - 1);
} else {
int l_idx = l / 2;
int r_idx = r / 2;
ans = query(tr_even, r_idx) ^ query(tr_even, l_idx - 1);
}
System.out.println(ans);
}
}
}
}
}
import sys
def lowbit(x):
return x & -x
def add(tr, size, x, k):
while x <= size:
tr[x] ^= k
x += lowbit(x)
def query(tr, x):
res = 0
while x > 0:
res ^= tr[x]
x -= lowbit(x)
return res
# 读取所有输入
lines = sys.stdin.readlines()
line_idx = 0
n, m = map(int, lines[line_idx].split())
line_idx += 1
arr_input = list(map(int, lines[line_idx].split()))
line_idx += 1
n_odd = (n + 1) // 2
n_even = n // 2
a = [0] * (n + 1)
tr_odd = [0] * (n_odd + 1)
tr_even = [0] * (n_even + 1)
for i in range(1, n + 1):
a[i] = arr_input[i - 1]
if i % 2 != 0:
add(tr_odd, n_odd, (i + 1) // 2, a[i])
else:
add(tr_even, n_even, i // 2, a[i])
for _ in range(m):
parts = list(map(int, lines[line_idx].split()))
line_idx += 1
op = parts[0]
if op == 1:
x, k = parts[1], parts[2]
delta = a[x] ^ k
a[x] = k
if x % 2 != 0:
add(tr_odd, n_odd, (x + 1) // 2, delta)
else:
add(tr_even, n_even, x // 2, delta)
else:
l, r = parts[1], parts[2]
if (l % 2) != (r % 2):
print(0)
else:
ans = 0
if l % 2 != 0:
l_idx = (l + 1) // 2
r_idx = (r + 1) // 2
ans = query(tr_odd, r_idx) ^ query(tr_odd, l_idx - 1)
else:
l_idx = l // 2
r_idx = r // 2
ans = query(tr_even, r_idx) ^ query(tr_even, l_idx - 1)
print(ans)
算法及复杂度
- 算法:数学 + 树状数组
- 时间复杂度:初始化需要
。共有
次操作,每次单点修改或区间查询的时间复杂度都为
。因此,总时间复杂度为
。
- 空间复杂度:需要一个数组存储原数组,以及两个树状数组,总大小与
线性相关,所以空间复杂度为
。