题目链接

查询满足区间的记录

题目描述

给定一批订单记录,每条记录包含订单号、入店时间和离店时间。

再给定一个时间点 A,需要在这批记录中找出所有满足 入店时间 <= A <= 离店时间 的记录。

要求:

  1. 单次查询时间复杂度控制在 O(logN)

  2. 输出符合条件的订单号,并按升序排列。

  3. 如果找不到,输出字符串 "null"。

解题思路

这是一个典型的区间查询问题,要求在指定的时间点 A 找出所有包含该点的区间。O(logN) 的时间复杂度要求提示我们需要使用排序 + 二分查找的策略,而不是简单的线性扫描。

1. 核心思想

一个简单的线性扫描方法是遍历所有记录,检查每个记录的区间 [入店时间, 离店时间] 是否包含 A。这种方法的时间复杂度为 O(N),不满足题目要求。

为了优化查询,我们可以先对数据进行预处理。一个有效的策略是按区间的某个端点进行排序。我们选择按 入店时间 对所有记录进行升序排序。

2. 算法步骤

  1. 预处理

    • 将所有 N 条订单记录读入一个结构体或类的数组中。

    • 对这个数组按照 入店时间 (start_time) 进行升序排序。这一步的时间复杂度是 O(N log N),在查询之前完成。

  2. 查询

    a. 二分查找确定候选范围

    • 对于给定的查询时间 A,我们需要找到所有满足 入店时间 <= A 的记录。

    • 由于数组已经按 入店时间 排序,我们可以使用二分查找(具体来说是 upper_bound 的思想)来找到第一个 入店时间 > A 的记录的位置。

    • 从数组的起始位置到这个位置之前的所有记录,都是满足 入店时间 <= A 的候选记录。这个查找过程的复杂度是 O(logN)

    b. 线性筛选

    • 我们遍历上一步找到的候选记录范围。

    • 对于每一条候选记录,我们再检查它是否满足第二个条件:离店时间 >= A

    • 将同时满足这两个条件的记录的订单号存入一个结果列表中。

    c. 排序并输出

    • 由于题目要求按订单号升序输出,我们需要对结果列表进行排序。

    • 如果结果列表为空,输出 "null";否则,按格式输出排序后的订单号。

3. 复杂度分析

  • 预处理O(N log N) 用于排序。

  • 查询

    • 二分查找:O(logN)

    • 线性筛选:O(k),其中 k 是满足 入店时间 <= A 的记录数。在最坏情况下(例如 A 非常大),k 可能接近 N

    • 结果排序:O(K log K),其中 K 是最终符合条件的记录数。

  • 总查询复杂度O(logN + k + K log K)。虽然在最坏情况下,k 可能导致查询退化为 O(N),但这种“先二分、后线性”的策略是解决此类问题的标准优化思路之一。要做到严格的 O(logN + K),需要使用如区间树等更高级的数据结构。

代码

#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

struct Record {
    int id;
    int start_time;
    int end_time;
};

bool compareRecords(const Record& a, const Record& b) {
    return a.start_time < b.start_time;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n;
    cin >> n;

    long long query_time;
    cin >> query_time;

    vector<Record> records(n);
    for (int i = 0; i < n; ++i) {
        cin >> records[i].id >> records[i].start_time >> records[i].end_time;
    }

    sort(records.begin(), records.end(), compareRecords);

    vector<int> result_ids;
    
    // 遍历所有 start_time <= query_time 的记录
    // 这里可以使用 std::upper_bound 优化查找,但后续仍需遍历
    // 为了代码清晰,直接遍历
    for (const auto& rec : records) {
        if (rec.start_time <= query_time) {
            if (rec.end_time >= query_time) {
                result_ids.push_back(rec.id);
            }
        } else {
            // 因为已经按 start_time 排序,后续的记录 start_time 只会更大
            break;
        }
    }

    if (result_ids.empty()) {
        cout << "null" << endl;
    } else {
        sort(result_ids.begin(), result_ids.end());
        for (int id : result_ids) {
            cout << id << endl;
        }
    }

    return 0;
}
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Scanner;

class Record {
    int id;
    long startTime;
    long endTime;

    public Record(int id, long startTime, long endTime) {
        this.id = id;
        this.startTime = startTime;
        this.endTime = endTime;
    }
}

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        long queryTime = sc.nextLong();

        List<Record> records = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            records.add(new Record(sc.nextInt(), sc.nextLong(), sc.nextLong()));
        }

        // 按 start_time 排序
        records.sort(Comparator.comparingLong(r -> r.startTime));

        List<Integer> resultIds = new ArrayList<>();
        
        // 遍历所有 start_time <= queryTime 的记录
        for (Record rec : records) {
            if (rec.startTime <= queryTime) {
                if (rec.endTime >= queryTime) {
                    resultIds.add(rec.id);
                }
            } else {
                // 后续记录的 startTime 只会更大,无需继续遍历
                break;
            }
        }

        if (resultIds.isEmpty()) {
            System.out.println("null");
        } else {
            Collections.sort(resultIds);
            for (int id : resultIds) {
                System.out.println(id);
            }
        }
        
        sc.close();
    }
}
import sys

def solve():
    try:
        n_str = sys.stdin.readline()
        if not n_str: return
        n = int(n_str)
        
        query_time_str = sys.stdin.readline()
        if not query_time_str: return
        query_time = int(query_time_str)
        
        records = []
        for _ in range(n):
            record_data = list(map(int, sys.stdin.readline().split()))
            records.append({
                "id": record_data[0],
                "start": record_data[1],
                "end": record_data[2]
            })
            
        # 按 start_time 排序
        records.sort(key=lambda r: r["start"])
        
        result_ids = []
        
        # 遍历所有 start <= query_time 的记录
        for rec in records:
            if rec["start"] <= query_time:
                if rec["end"] >= query_time:
                    result_ids.append(rec["id"])
            else:
                # 后续记录的 start 只会更大
                break
                
        if not result_ids:
            print("null")
        else:
            result_ids.sort()
            for rec_id in result_ids:
                print(rec_id)

    except (IOError, ValueError):
        return

solve()

算法及复杂度

  • 算法:排序 + 二分查找 + 线性筛选

  • 时间复杂度: 预处理 O(N log N)。查询 O(logN + k + K log K),其中 kstart_time <= A 的记录数,K 是最终结果数。最坏情况下查询为 O(N)

  • 空间复杂度: O(N),用于存储所有记录。