题目链接

【模板】整数域二分

题目描述

给定一个长度为 的数组 ,需要处理 次查询。

每次查询给定一个区间 ,你需要输出数组 中值大于等于 且小于等于 的元素个数。

解题思路

本题是典型的利用二分查找处理区间查询的问题。如果对每次查询都遍历整个数组,时间复杂度为 ,无法通过本题。

一个高效的解决方案是先对数组进行预处理。我们可以先将数组 进行升序排序。排序后,数组具有单调性,这为使用二分查找创造了条件。

排序后,问题“统计值在 区间内的元素个数”可以转化为一个减法问题: count(a_i in [x, y]) = count(a_i <= y) - count(a_i < x)

这两个子问题都可以通过二分查找高效解决:

  1. count(a_i <= y): 我们需要找到所有小于或等于 的元素。在排序数组中,这等价于找到第一个大于 的元素的位置。这个位置的索引(从0开始)就是小于或等于 的元素的总数。这个操作通常被称为 upper_bound

  2. count(a_i < x): 我们需要找到所有小于 的元素。在排序数组中,这等价于找到第一个大于或等于 的元素的位置。这个位置的索引就是小于 的元素的总数。这个操作通常被称为 lower_bound

因此,最终的算法步骤如下:

  1. 读入数组 并对其进行升序排序。
  2. 对于每一次查询 : a. 使用二分查找(upper_bound)找到第一个大于 的元素的位置 。 b. 使用二分查找(lower_bound)找到第一个大于或等于 的元素的位置 。 c. 查询结果即为

该算法的瓶颈在于初始排序,后续的每次查询都非常快。

代码

#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n, m;
    cin >> n >> m;
    vector<int> a(n);
    for (int i = 0; i < n; ++i) {
        cin >> a[i];
    }

    sort(a.begin(), a.end());

    for (int i = 0; i < m; ++i) {
        int x, y;
        cin >> x >> y;
        auto it_x = lower_bound(a.begin(), a.end(), x);
        auto it_y = upper_bound(a.begin(), a.end(), y);
        cout << distance(it_x, it_y) << endl;
    }

    return 0;
}
import java.util.Arrays;
import java.util.Scanner;

public class Main {
    // 寻找第一个大于等于 target 的元素的索引
    private static int lower_bound(int[] arr, int target) {
        int left = 0, right = arr.length;
        while (left < right) {
            int mid = left + (right - left) / 2;
            if (arr[mid] >= target) {
                right = mid;
            } else {
                left = mid + 1;
            }
        }
        return left;
    }

    // 寻找第一个大于 target 的元素的索引
    private static int upper_bound(int[] arr, int target) {
        int left = 0, right = arr.length;
        while (left < right) {
            int mid = left + (right - left) / 2;
            if (arr[mid] > target) {
                right = mid;
            } else {
                left = mid + 1;
            }
        }
        return left;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int m = sc.nextInt();
        int[] a = new int[n];
        for (int i = 0; i < n; i++) {
            a[i] = sc.nextInt();
        }

        Arrays.sort(a);

        for (int i = 0; i < m; i++) {
            int x = sc.nextInt();
            int y = sc.nextInt();
            int posX = lower_bound(a, x);
            int posY = upper_bound(a, y);
            System.out.println(posY - posX);
        }
    }
}
import bisect

def main():
    n, m = map(int, input().split())
    a = list(map(int, input().split()))

    a.sort()

    for _ in range(m):
        x, y = map(int, input().split())
        
        # bisect_left 找到第一个 >= x 的位置
        pos_x = bisect.bisect_left(a, x)
        
        # bisect_right 找到第一个 > y 的位置
        pos_y = bisect.bisect_right(a, y)
        
        print(pos_y - pos_x)

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:排序 + 二分查找。
  • 时间复杂度。其中 是排序数组的时间开销, 是处理 次查询的总时间开销,每次查询的二分查找需要
  • 空间复杂度,用于存储输入的数组