题目链接
题目描述
给定一个长度为 的数组
,你需要构建一个数据结构来支持以下两种操作共
次:
- 区间根号:给定区间
,将该区间中的所有元素
修改为其向下取整的平方根,即
。
- 区间和查询:给定区间
,输出该区间中所有元素的和,即
。
数据范围:,
。
解题思路
本题是典型的区间修改与区间查询问题,通常可以采用线段树来解决。标准的区间求和是线段树的常规操作,但“区间开方”这个修改操作比较特殊。
一个数开方后的值与它自身的大小有关,这意味着我们不能像区间加法那样,用一个统一的“懒标记”来表示对整个区间的修改。一个直接的想法是,每次修改都遍历到线段树的叶子节点进行单点更新,但这会导致单次修改的复杂度退化到 ,无法接受。
我们需要寻找“区间开方”操作的特殊性质来进行优化。观察一下一个数在连续开方下的变化:
- 一个大数,如
,在经过几次开方后,数值会迅速减小:
- 一个关键的临界点是,当一个数变成
或
之后,再对它进行开方操作,其值将不再改变(
,
)。
这个性质是优化的突破口。如果一个区间内的所有数都已经变成了 或
,那么对这个区间的“开方”操作就没有任何意义了,我们可以直接跳过它。
基于此,我们可以设计一个特殊的线段树:
- 节点信息:线段树的每个节点除了维护其对应区间的和 (sum) 之外,还额外维护该区间的最大值 (max_val)。
- 建树与查询:
- 建树时,自底向上维护好每个节点的
sum
和max_val
。 - 区间和查询是线段树的标准操作,复杂度为
。
- 建树时,自底向上维护好每个节点的
- 区间开方修改:
- 当我们递归修改一个节点所代表的区间时,首先检查该节点的
max_val
。 - 剪枝优化:如果
node.max_val <= 1
,说明这个区间内所有的数都已经是或
了。此时,开方操作不会改变任何值,我们就可以直接返回,不必再递归深入其子树。这是整个算法能够通过的关键。
- 递归修改:如果
node.max_val > 1
,说明区间内至少还有一个数需要被修改。- 如果当前节点是叶子节点,则直接修改该点的值,并更新其
sum
和max_val
。 - 如果当前节点是内部节点,则递归地对其左右子节点进行修改。修改完子节点后,再用子节点的
sum
和max_val
来更新当前节点的sum
和max_val
。
- 如果当前节点是叶子节点,则直接修改该点的值,并更新其
- 当我们递归修改一个节点所代表的区间时,首先检查该节点的
复杂度分析:
每个数最多只会被有效修改(即值发生改变)几次(对于 大约是 6-7 次)就会变成
。因此,所有数被修改的总次数是有限的,大约是
,其中
是一个很小的常数。每次修改一个叶子节点需要
的时间。因此,所有修改操作的总时间复杂度是近似于
的。
次查询的总复杂度是
。因此,总时间复杂度为
,可以高效地解决本题。
代码
#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>
using namespace std;
const int MAXN = 100005;
long long a[MAXN];
long long sum[4 * MAXN];
long long max_val[4 * MAXN];
void push_up(int node) {
sum[node] = sum[2 * node] + sum[2 * node + 1];
max_val[node] = max(max_val[2 * node], max_val[2 * node + 1]);
}
void build(int node, int start, int end) {
if (start == end) {
sum[node] = a[start];
max_val[node] = a[start];
return;
}
int mid = (start + end) / 2;
build(2 * node, start, mid);
build(2 * node + 1, mid + 1, end);
push_up(node);
}
void update(int node, int start, int end, int l, int r) {
if (max_val[node] <= 1) { // 关键剪枝
return;
}
if (start == end) {
sum[node] = sqrt(sum[node]);
max_val[node] = sum[node];
return;
}
int mid = (start + end) / 2;
if (l <= mid) {
update(2 * node, start, mid, l, r);
}
if (r > mid) {
update(2 * node + 1, mid + 1, end, l, r);
}
push_up(node);
}
long long query(int node, int start, int end, int l, int r) {
if (r < start || end < l) {
return 0;
}
if (l <= start && end <= r) {
return sum[node];
}
int mid = (start + end) / 2;
long long p1 = query(2 * node, start, mid, l, r);
long long p2 = query(2 * node + 1, mid + 1, end, l, r);
return p1 + p2;
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n, m;
cin >> n >> m;
for (int i = 1; i <= n; ++i) {
cin >> a[i];
}
build(1, 1, n);
for (int i = 0; i < m; ++i) {
int type, l, r;
cin >> type >> l >> r;
if (type == 1) {
update(1, 1, n, l, r);
} else {
cout << query(1, 1, n, l, r) << "\n";
}
}
return 0;
}
import java.io.*;
import java.util.StringTokenizer;
import java.lang.Math;
public class Main {
static long[] a;
static long[] sum;
static long[] maxVal;
static void pushUp(int node) {
sum[node] = sum[2 * node] + sum[2 * node + 1];
maxVal[node] = Math.max(maxVal[2 * node], maxVal[2 * node + 1]);
}
static void build(int node, int start, int end) {
if (start == end) {
sum[node] = a[start];
maxVal[node] = a[start];
return;
}
int mid = (start + end) / 2;
build(2 * node, start, mid);
build(2 * node + 1, mid + 1, end);
pushUp(node);
}
static void update(int node, int start, int end, int l, int r) {
if (maxVal[node] <= 1) { // 关键剪枝
return;
}
if (start == end) {
sum[node] = (long) Math.sqrt(sum[node]);
maxVal[node] = sum[node];
return;
}
int mid = (start + end) / 2;
if (l <= mid) {
update(2 * node, start, mid, l, r);
}
if (r > mid) {
update(2 * node + 1, mid + 1, end, l, r);
}
pushUp(node);
}
static long query(int node, int start, int end, int l, int r) {
if (r < start || end < l) {
return 0;
}
if (l <= start && end <= r) {
return sum[node];
}
int mid = (start + end) / 2;
long p1 = query(2 * node, start, mid, l, r);
long p2 = query(2 * node + 1, mid + 1, end, l, r);
return p1 + p2;
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
int n = Integer.parseInt(st.nextToken());
int m = Integer.parseInt(st.nextToken());
a = new long[n + 1];
sum = new long[4 * (n + 1)];
maxVal = new long[4 * (n + 1)];
st = new StringTokenizer(br.readLine());
for (int i = 1; i <= n; i++) {
a[i] = Long.parseLong(st.nextToken());
}
build(1, 1, n);
PrintWriter out = new PrintWriter(new BufferedWriter(new OutputStreamWriter(System.out)));
for (int i = 0; i < m; i++) {
st = new StringTokenizer(br.readLine());
int type = Integer.parseInt(st.nextToken());
int l = Integer.parseInt(st.nextToken());
int r = Integer.parseInt(st.nextToken());
if (type == 1) {
update(1, 1, n, l, r);
} else {
out.println(query(1, 1, n, l, r));
}
}
out.flush();
}
}
import sys
import math
# 增加递归深度以适应线段树操作
sys.setrecursionlimit(200000)
def push_up(node):
sum_tree[node] = sum_tree[2 * node] + sum_tree[2 * node + 1]
max_tree[node] = max(max_tree[2 * node], max_tree[2 * node + 1])
def build(node, start, end):
if start == end:
sum_tree[node] = arr[start]
max_tree[node] = arr[start]
return
mid = (start + end) // 2
build(2 * node, start, mid)
build(2 * node + 1, mid + 1, end)
push_up(node)
def update(node, start, end, l, r):
# 关键剪枝:如果区间最大值已经小于等于1,无需再进行开方操作
if max_tree[node] <= 1:
return
if start == end:
# 遵照用户要求,不使用 math.isqrt
val = int(math.sqrt(sum_tree[node]))
sum_tree[node] = val
max_tree[node] = val
return
mid = (start + end) // 2
if l <= mid:
update(2 * node, start, mid, l, r)
if r > mid:
update(2 * node + 1, mid + 1, end, l, r)
push_up(node)
def query(node, start, end, l, r):
if r < start or end < l:
return 0
if l <= start and end <= r:
return sum_tree[node]
mid = (start + end) // 2
p1 = query(2 * node, start, mid, l, r)
p2 = query(2 * node + 1, mid + 1, end, l, r)
return p1 + p2
def solve():
global arr, sum_tree, max_tree
# 使用 readline 以优化I/O
lines = sys.stdin.readlines()
n, m = map(int, lines[0].split())
arr = [0] + list(map(int, lines[1].split()))
sum_tree = [0] * (4 * (n + 1))
max_tree = [0] * (4 * (n + 1))
build(1, 1, n)
results = []
for line in lines[2:]:
parts = list(map(int, line.split()))
op, l, r = parts[0], parts[1], parts[2]
if op == 1:
update(1, 1, n, l, r)
else:
results.append(str(query(1, 1, n, l, r)))
print("\n".join(results))
solve()
算法及复杂度
- 算法:带剪枝优化的线段树
- 时间复杂度:建树为
。对于修改操作,由于每个元素最多被有效修改(值变小)的次数是一个非常小的常数(约 6-7 次),所有修改操作的均摊总时间复杂度近似为
。查询操作为
。因此,总时间复杂度为
。
- 空间复杂度:
,用于存储线段树。