题目链接
题目描述
在 N
个坑中选择最多 K
个进行种树,要求任意两个种树的坑不能相邻。每个坑 i
对应一个收益 p_i
(可能为负)。目标是最大化总收益。
解题思路
这是一个经典的贪心选择问题,但简单的贪心(如每次选收益最高的坑)会因为选择的局部最优性影响后续选择而失效。我们需要一种能够“反悔”的贪心策略。
核心贪心策略:带悔贪心
我们进行 K
轮选择,每一轮都选择当前能带来的最大收益增量。
-
初始选择:最开始,最大收益增量就是所有坑中收益最高的那一个。假设我们选择了坑
i
,获得了p_i
的收益。 -
选择后的影响:一旦选择了坑
i
,根据规则,其相邻的坑i-1
和i+1
就不能再被选择了。这个选择动作可以看作是将i-1
,i
,i+1
这三个坑“合并”成了一个整体,我们已经从中获取了p_i
的收益。 -
“反悔”机制:现在,如果我们想获得下一次的最大收益增量,选择范围是什么?除了那些独立的、未被影响的坑之外,我们还多了一个新的选项:反悔我们刚才选择
i
的这个决定。- 如果我们反悔,我们就要退回
p_i
的收益。 - 但反悔后,原先不能选的
i-1
和i+1
就被“解锁”了,我们可以选择它们,从而获得p_{i-1} + p_{i+1}
的收益。 - 因此,“反悔并选择
i
的邻居”这个新操作,能带来的收益增量是p_{i-1} + p_{i+1} - p_i
。
- 如果我们反悔,我们就要退回
我们可以把这个“反悔操作”看作一个收益为 p_{i-1} + p_{i+1} - p_i
的新“虚拟”坑,它占据了原来 i-1, i, i+1
的位置。在下一轮选择中,这个虚拟坑将和所有其他可用的坑一起参与竞争。
数据结构与实现
为了高效实现上述过程,我们需要:
- 一个优先队列(最大堆):用于随时获取当前收益最大的可用坑(无论是真实的还是虚拟的)。
- 一个双向链表:用于维护坑之间的邻接关系。当一个坑
i
及其邻居被合并时,我们需要快速找到i-1
的前一个邻居和i+1
的后一个邻居,并将它们连接起来。
完整算法
- 将所有
N
个坑的收益和编号(p_i, i)
放入一个最大堆。 - 用数组
L[]
和R[]
初始化一个双向链表,表示N
个坑的排列。 - 进行
K
次循环: a. 从堆顶取出一个坑i
,其收益为p
。如果该坑已被合并(无效),则跳过并取下一个。 b. 如果p <= 0
,则后续选择不可能增加总收益,直接结束。 c. 将p
加入总收益total_profit
。 d. 从链表中找到i
的左右邻居l = L[i]
和r = R[i]
。 e. 将l
和r
标记为无效(因为它们也被合并了)。 f. 创建一个新的“虚拟”坑,其位置仍然是i
,但收益更新为p_new = p_l + p_r - p
。 g. 更新双向链表,将l
的前驱和r
的后继连接到新的虚拟坑i
上。 h. 将这个新的虚拟坑(p_new, i)
推入堆中。 - 循环结束后,
total_profit
即为最大总收益。
代码
#include <iostream>
#include <vector>
#include <queue>
using namespace std;
struct Node {
long long profit;
int id;
bool operator<(const Node& other) const {
return profit < other.profit;
}
};
const int MAXN = 100005;
long long p[MAXN];
int l[MAXN], r[MAXN];
bool valid[MAXN];
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n, k;
cin >> n >> k;
priority_queue<Node> pq;
for (int i = 1; i <= n; ++i) {
cin >> p[i];
l[i] = i - 1;
r[i] = i + 1;
valid[i] = true;
pq.push({p[i], i});
}
r[n] = 0; // Sentinel for right boundary
long long total_profit = 0;
for (int i = 0; i < k; ++i) {
while (!pq.empty() && !valid[pq.top().id]) {
pq.pop();
}
if (pq.empty()) break;
Node top = pq.top();
pq.pop();
if (top.profit <= 0) break;
total_profit += top.profit;
int current_id = top.id;
int left_id = l[current_id];
int right_id = r[current_id];
p[current_id] = p[left_id] + p[right_id] - p[current_id];
valid[left_id] = false;
valid[right_id] = false;
l[current_id] = l[left_id];
r[current_id] = r[right_id];
if (l[current_id] != 0) {
r[l[current_id]] = current_id;
}
if (r[current_id] != 0) {
l[r[current_id]] = current_id;
}
pq.push({p[current_id], current_id});
}
cout << total_profit << endl;
return 0;
}
import java.io.*;
import java.util.*;
public class Main {
static class Node {
long profit;
int id;
Node(long profit, int id) {
this.profit = profit;
this.id = id;
}
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
int n = Integer.parseInt(st.nextToken());
int k = Integer.parseInt(st.nextToken());
long[] p = new long[n + 2];
int[] l = new int[n + 2];
int[] r = new int[n + 2];
boolean[] valid = new boolean[n + 2];
PriorityQueue<Node> pq = new PriorityQueue<>((a, b) -> Long.compare(b.profit, a.profit));
st = new StringTokenizer(br.readLine());
for (int i = 1; i <= n; i++) {
p[i] = Long.parseLong(st.nextToken());
l[i] = i - 1;
r[i] = i + 1;
valid[i] = true;
pq.add(new Node(p[i], i));
}
r[n] = 0; // Sentinel
long totalProfit = 0;
for (int i = 0; i < k; i++) {
while (!pq.isEmpty() && !valid[pq.peek().id]) {
pq.poll();
}
if (pq.isEmpty()) break;
Node top = pq.poll();
if (top.profit <= 0) break;
totalProfit += top.profit;
int currentId = top.id;
int leftId = l[currentId];
int rightId = r[currentId];
p[currentId] = p[leftId] + p[rightId] - p[currentId];
valid[leftId] = false;
valid[rightId] = false;
l[currentId] = l[leftId];
r[currentId] = r[rightId];
if (l[currentId] != 0) {
r[l[currentId]] = currentId;
}
if (r[currentId] != 0) {
l[r[currentId]] = currentId;
}
pq.add(new Node(p[currentId], currentId));
}
PrintWriter out = new PrintWriter(System.out);
out.println(totalProfit);
out.flush();
}
}
import sys
import heapq
def solve():
try:
n_str, k_str = sys.stdin.readline().split()
n, k = int(n_str), int(k_str)
profits_str = sys.stdin.readline().split()
profits = [0] + [int(p) for p in profits_str]
except (IOError, ValueError):
return
# Max-heap (store negative profits in min-heap)
pq = []
l = [i - 1 for i in range(n + 2)]
r = [i + 1 for i in range(n + 2)]
valid = [True] * (n + 2)
for i in range(1, n + 1):
heapq.heappush(pq, (-profits[i], i))
total_profit = 0
for _ in range(k):
while pq and not valid[pq[0][1]]:
heapq.heappop(pq)
if not pq:
break
neg_p, current_id = heapq.heappop(pq)
p = -neg_p
if p <= 0:
break
total_profit += p
left_id = l[current_id]
right_id = r[current_id]
# Invalidate neighbors
valid[left_id] = False
valid[right_id] = False
# Create new "regret" node
profits[current_id] = profits[left_id] + profits[right_id] - profits[current_id]
# Update linked list
l[current_id] = l[left_id]
r[current_id] = r[right_id]
# Connect new neighbors to current node
if l[current_id] != 0:
r[l[current_id]] = current_id
if r[current_id] != 0:
l[r[current_id]] = current_id
heapq.heappush(pq, (-profits[current_id], current_id))
print(total_profit)
solve()
算法及复杂度
- 算法:带悔贪心 + 优先队列 + 双向链表
- 时间复杂度:
- 初始化优先队列为
或
(取决于实现)。
- 循环
K
次,每次循环中,对优先队列的操作(弹出、推入)为。
- 总时间复杂度主要由
K
次堆操作决定。
- 初始化优先队列为
- 空间复杂度:
- 用于存储收益、双向链表数组以及优先队列。