题目链接
题目描述
给定一批订单记录,每条记录包含订单号、入店时间和离店时间。
再给定一个时间点 A
,需要在这批记录中找出所有满足 入店时间 <= A <= 离店时间
的记录。
要求:
-
单次查询时间复杂度控制在
O(logN)
。 -
输出符合条件的订单号,并按升序排列。
-
如果找不到,输出字符串 "null"。
解题思路
这是一个典型的区间查询问题,要求在指定的时间点 A
找出所有包含该点的区间。O(logN)
的时间复杂度要求提示我们需要使用排序 + 二分查找的策略,而不是简单的线性扫描。
1. 核心思想
一个简单的线性扫描方法是遍历所有记录,检查每个记录的区间 [入店时间, 离店时间]
是否包含 A
。这种方法的时间复杂度为 O(N)
,不满足题目要求。
为了优化查询,我们可以先对数据进行预处理。一个有效的策略是按区间的某个端点进行排序。我们选择按 入店时间
对所有记录进行升序排序。
2. 算法步骤
-
预处理:
-
将所有
N
条订单记录读入一个结构体或类的数组中。 -
对这个数组按照
入店时间
(start_time) 进行升序排序。这一步的时间复杂度是O(N log N)
,在查询之前完成。
-
-
查询:
a. 二分查找确定候选范围:
-
对于给定的查询时间
A
,我们需要找到所有满足入店时间 <= A
的记录。 -
由于数组已经按
入店时间
排序,我们可以使用二分查找(具体来说是upper_bound
的思想)来找到第一个入店时间 > A
的记录的位置。 -
从数组的起始位置到这个位置之前的所有记录,都是满足
入店时间 <= A
的候选记录。这个查找过程的复杂度是O(logN)
。
b. 线性筛选:
-
我们遍历上一步找到的候选记录范围。
-
对于每一条候选记录,我们再检查它是否满足第二个条件:
离店时间 >= A
。 -
将同时满足这两个条件的记录的订单号存入一个结果列表中。
c. 排序并输出:
-
由于题目要求按订单号升序输出,我们需要对结果列表进行排序。
-
如果结果列表为空,输出 "null";否则,按格式输出排序后的订单号。
-
3. 复杂度分析
-
预处理:
O(N log N)
用于排序。 -
查询:
-
二分查找:
O(logN)
。 -
线性筛选:
O(k)
,其中k
是满足入店时间 <= A
的记录数。在最坏情况下(例如A
非常大),k
可能接近N
。 -
结果排序:
O(K log K)
,其中K
是最终符合条件的记录数。
-
-
总查询复杂度:
O(logN + k + K log K)
。虽然在最坏情况下,k
可能导致查询退化为O(N)
,但这种“先二分、后线性”的策略是解决此类问题的标准优化思路之一。要做到严格的O(logN + K)
,需要使用如区间树等更高级的数据结构。
代码
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
struct Record {
int id;
int start_time;
int end_time;
};
bool compareRecords(const Record& a, const Record& b) {
return a.start_time < b.start_time;
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n;
cin >> n;
long long query_time;
cin >> query_time;
vector<Record> records(n);
for (int i = 0; i < n; ++i) {
cin >> records[i].id >> records[i].start_time >> records[i].end_time;
}
sort(records.begin(), records.end(), compareRecords);
vector<int> result_ids;
// 遍历所有 start_time <= query_time 的记录
// 这里可以使用 std::upper_bound 优化查找,但后续仍需遍历
// 为了代码清晰,直接遍历
for (const auto& rec : records) {
if (rec.start_time <= query_time) {
if (rec.end_time >= query_time) {
result_ids.push_back(rec.id);
}
} else {
// 因为已经按 start_time 排序,后续的记录 start_time 只会更大
break;
}
}
if (result_ids.empty()) {
cout << "null" << endl;
} else {
sort(result_ids.begin(), result_ids.end());
for (int id : result_ids) {
cout << id << endl;
}
}
return 0;
}
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Scanner;
class Record {
int id;
long startTime;
long endTime;
public Record(int id, long startTime, long endTime) {
this.id = id;
this.startTime = startTime;
this.endTime = endTime;
}
}
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
long queryTime = sc.nextLong();
List<Record> records = new ArrayList<>();
for (int i = 0; i < n; i++) {
records.add(new Record(sc.nextInt(), sc.nextLong(), sc.nextLong()));
}
// 按 start_time 排序
records.sort(Comparator.comparingLong(r -> r.startTime));
List<Integer> resultIds = new ArrayList<>();
// 遍历所有 start_time <= queryTime 的记录
for (Record rec : records) {
if (rec.startTime <= queryTime) {
if (rec.endTime >= queryTime) {
resultIds.add(rec.id);
}
} else {
// 后续记录的 startTime 只会更大,无需继续遍历
break;
}
}
if (resultIds.isEmpty()) {
System.out.println("null");
} else {
Collections.sort(resultIds);
for (int id : resultIds) {
System.out.println(id);
}
}
sc.close();
}
}
import sys
def solve():
try:
n_str = sys.stdin.readline()
if not n_str: return
n = int(n_str)
query_time_str = sys.stdin.readline()
if not query_time_str: return
query_time = int(query_time_str)
records = []
for _ in range(n):
record_data = list(map(int, sys.stdin.readline().split()))
records.append({
"id": record_data[0],
"start": record_data[1],
"end": record_data[2]
})
# 按 start_time 排序
records.sort(key=lambda r: r["start"])
result_ids = []
# 遍历所有 start <= query_time 的记录
for rec in records:
if rec["start"] <= query_time:
if rec["end"] >= query_time:
result_ids.append(rec["id"])
else:
# 后续记录的 start 只会更大
break
if not result_ids:
print("null")
else:
result_ids.sort()
for rec_id in result_ids:
print(rec_id)
except (IOError, ValueError):
return
solve()
算法及复杂度
-
算法:排序 + 二分查找 + 线性筛选
-
时间复杂度: 预处理
O(N log N)
。查询O(logN + k + K log K)
,其中k
是start_time <= A
的记录数,K
是最终结果数。最坏情况下查询为O(N)
。 -
空间复杂度:
O(N)
,用于存储所有记录。