题目链接

PEEK69 区间根号与区间求和

题目描述

给定一个长度为 的数组 ,你需要构建一个数据结构来支持以下两种操作共 次:

  1. 区间根号:给定区间 ,将该区间中的所有元素 修改为其向下取整的平方根,即
  2. 区间和查询:给定区间 ,输出该区间中所有元素的和,即

数据范围:,

解题思路

本题是典型的区间修改与区间查询问题,通常可以采用线段树来解决。标准的区间求和是线段树的常规操作,但“区间开方”这个修改操作比较特殊。

一个数开方后的值与它自身的大小有关,这意味着我们不能像区间加法那样,用一个统一的“懒标记”来表示对整个区间的修改。一个直接的想法是,每次修改都遍历到线段树的叶子节点进行单点更新,但这会导致单次修改的复杂度退化到 ,无法接受。

我们需要寻找“区间开方”操作的特殊性质来进行优化。观察一下一个数在连续开方下的变化:

  • 一个大数,如 ,在经过几次开方后,数值会迅速减小:
  • 一个关键的临界点是,当一个数变成 之后,再对它进行开方操作,其值将不再改变(, )。

这个性质是优化的突破口。如果一个区间内的所有数都已经变成了 ,那么对这个区间的“开方”操作就没有任何意义了,我们可以直接跳过它。

基于此,我们可以设计一个特殊的线段树:

  1. 节点信息:线段树的每个节点除了维护其对应区间的和 (sum) 之外,还额外维护该区间的最大值 (max_val)
  2. 建树与查询
    • 建树时,自底向上维护好每个节点的 summax_val
    • 区间和查询是线段树的标准操作,复杂度为
  3. 区间开方修改
    • 当我们递归修改一个节点所代表的区间时,首先检查该节点的 max_val
    • 剪枝优化:如果 node.max_val <= 1,说明这个区间内所有的数都已经是 了。此时,开方操作不会改变任何值,我们就可以直接返回,不必再递归深入其子树。这是整个算法能够通过的关键。
    • 递归修改:如果 node.max_val > 1,说明区间内至少还有一个数需要被修改。
      • 如果当前节点是叶子节点,则直接修改该点的值,并更新其 summax_val
      • 如果当前节点是内部节点,则递归地对其左右子节点进行修改。修改完子节点后,再用子节点的 summax_val 来更新当前节点的 summax_val

复杂度分析: 每个数最多只会被有效修改(即值发生改变)几次(对于 大约是 6-7 次)就会变成 。因此,所有数被修改的总次数是有限的,大约是 ,其中 是一个很小的常数。每次修改一个叶子节点需要 的时间。因此,所有修改操作的总时间复杂度是近似于 的。 次查询的总复杂度是 。因此,总时间复杂度为 ,可以高效地解决本题。

代码

#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>

using namespace std;

const int MAXN = 100005;
long long a[MAXN];
long long sum[4 * MAXN];
long long max_val[4 * MAXN];

void push_up(int node) {
    sum[node] = sum[2 * node] + sum[2 * node + 1];
    max_val[node] = max(max_val[2 * node], max_val[2 * node + 1]);
}

void build(int node, int start, int end) {
    if (start == end) {
        sum[node] = a[start];
        max_val[node] = a[start];
        return;
    }
    int mid = (start + end) / 2;
    build(2 * node, start, mid);
    build(2 * node + 1, mid + 1, end);
    push_up(node);
}

void update(int node, int start, int end, int l, int r) {
    if (max_val[node] <= 1) { // 关键剪枝
        return;
    }
    if (start == end) {
        sum[node] = sqrt(sum[node]);
        max_val[node] = sum[node];
        return;
    }
    int mid = (start + end) / 2;
    if (l <= mid) {
        update(2 * node, start, mid, l, r);
    }
    if (r > mid) {
        update(2 * node + 1, mid + 1, end, l, r);
    }
    push_up(node);
}

long long query(int node, int start, int end, int l, int r) {
    if (r < start || end < l) {
        return 0;
    }
    if (l <= start && end <= r) {
        return sum[node];
    }
    int mid = (start + end) / 2;
    long long p1 = query(2 * node, start, mid, l, r);
    long long p2 = query(2 * node + 1, mid + 1, end, l, r);
    return p1 + p2;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n, m;
    cin >> n >> m;
    for (int i = 1; i <= n; ++i) {
        cin >> a[i];
    }

    build(1, 1, n);

    for (int i = 0; i < m; ++i) {
        int type, l, r;
        cin >> type >> l >> r;
        if (type == 1) {
            update(1, 1, n, l, r);
        } else {
            cout << query(1, 1, n, l, r) << "\n";
        }
    }

    return 0;
}
import java.io.*;
import java.util.StringTokenizer;
import java.lang.Math;

public class Main {
    static long[] a;
    static long[] sum;
    static long[] maxVal;

    static void pushUp(int node) {
        sum[node] = sum[2 * node] + sum[2 * node + 1];
        maxVal[node] = Math.max(maxVal[2 * node], maxVal[2 * node + 1]);
    }

    static void build(int node, int start, int end) {
        if (start == end) {
            sum[node] = a[start];
            maxVal[node] = a[start];
            return;
        }
        int mid = (start + end) / 2;
        build(2 * node, start, mid);
        build(2 * node + 1, mid + 1, end);
        pushUp(node);
    }

    static void update(int node, int start, int end, int l, int r) {
        if (maxVal[node] <= 1) { // 关键剪枝
            return;
        }
        if (start == end) {
            sum[node] = (long) Math.sqrt(sum[node]);
            maxVal[node] = sum[node];
            return;
        }
        int mid = (start + end) / 2;
        if (l <= mid) {
            update(2 * node, start, mid, l, r);
        }
        if (r > mid) {
            update(2 * node + 1, mid + 1, end, l, r);
        }
        pushUp(node);
    }

    static long query(int node, int start, int end, int l, int r) {
        if (r < start || end < l) {
            return 0;
        }
        if (l <= start && end <= r) {
            return sum[node];
        }
        int mid = (start + end) / 2;
        long p1 = query(2 * node, start, mid, l, r);
        long p2 = query(2 * node + 1, mid + 1, end, l, r);
        return p1 + p2;
    }

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        int n = Integer.parseInt(st.nextToken());
        int m = Integer.parseInt(st.nextToken());

        a = new long[n + 1];
        sum = new long[4 * (n + 1)];
        maxVal = new long[4 * (n + 1)];
        
        st = new StringTokenizer(br.readLine());
        for (int i = 1; i <= n; i++) {
            a[i] = Long.parseLong(st.nextToken());
        }

        build(1, 1, n);
        
        PrintWriter out = new PrintWriter(new BufferedWriter(new OutputStreamWriter(System.out)));
        for (int i = 0; i < m; i++) {
            st = new StringTokenizer(br.readLine());
            int type = Integer.parseInt(st.nextToken());
            int l = Integer.parseInt(st.nextToken());
            int r = Integer.parseInt(st.nextToken());
            if (type == 1) {
                update(1, 1, n, l, r);
            } else {
                out.println(query(1, 1, n, l, r));
            }
        }
        out.flush();
    }
}
import sys
import math

# 增加递归深度以适应线段树操作
sys.setrecursionlimit(200000)

def push_up(node):
    sum_tree[node] = sum_tree[2 * node] + sum_tree[2 * node + 1]
    max_tree[node] = max(max_tree[2 * node], max_tree[2 * node + 1])

def build(node, start, end):
    if start == end:
        sum_tree[node] = arr[start]
        max_tree[node] = arr[start]
        return
    mid = (start + end) // 2
    build(2 * node, start, mid)
    build(2 * node + 1, mid + 1, end)
    push_up(node)

def update(node, start, end, l, r):
    # 关键剪枝:如果区间最大值已经小于等于1,无需再进行开方操作
    if max_tree[node] <= 1:
        return
    if start == end:
        # 遵照用户要求,不使用 math.isqrt
        val = int(math.sqrt(sum_tree[node]))
        sum_tree[node] = val
        max_tree[node] = val
        return
    
    mid = (start + end) // 2
    if l <= mid:
        update(2 * node, start, mid, l, r)
    if r > mid:
        update(2 * node + 1, mid + 1, end, l, r)
    push_up(node)

def query(node, start, end, l, r):
    if r < start or end < l:
        return 0
    if l <= start and end <= r:
        return sum_tree[node]
    
    mid = (start + end) // 2
    p1 = query(2 * node, start, mid, l, r)
    p2 = query(2 * node + 1, mid + 1, end, l, r)
    return p1 + p2

def solve():
    global arr, sum_tree, max_tree
    
    # 使用 readline 以优化I/O
    lines = sys.stdin.readlines()
    n, m = map(int, lines[0].split())
    arr = [0] + list(map(int, lines[1].split()))
    
    sum_tree = [0] * (4 * (n + 1))
    max_tree = [0] * (4 * (n + 1))
    
    build(1, 1, n)
    
    results = []
    for line in lines[2:]:
        parts = list(map(int, line.split()))
        op, l, r = parts[0], parts[1], parts[2]
        if op == 1:
            update(1, 1, n, l, r)
        else:
            results.append(str(query(1, 1, n, l, r)))
            
    print("\n".join(results))

solve()

算法及复杂度

  • 算法:带剪枝优化的线段树
  • 时间复杂度:建树为 。对于修改操作,由于每个元素最多被有效修改(值变小)的次数是一个非常小的常数(约 6-7 次),所有修改操作的均摊总时间复杂度近似为 。查询操作为 。因此,总时间复杂度为
  • 空间复杂度,用于存储线段树。