题目链接

种树

题目描述

N 个坑中选择最多 K 个进行种树,要求任意两个种树的坑不能相邻。每个坑 i 对应一个收益 p_i(可能为负)。目标是最大化总收益。

解题思路

这是一个经典的贪心选择问题,但简单的贪心(如每次选收益最高的坑)会因为选择的局部最优性影响后续选择而失效。我们需要一种能够“反悔”的贪心策略。

核心贪心策略:带悔贪心

我们进行 K 轮选择,每一轮都选择当前能带来的最大收益增量

  1. 初始选择:最开始,最大收益增量就是所有坑中收益最高的那一个。假设我们选择了坑 i,获得了 p_i 的收益。

  2. 选择后的影响:一旦选择了坑 i,根据规则,其相邻的坑 i-1i+1 就不能再被选择了。这个选择动作可以看作是将 i-1, i, i+1 这三个坑“合并”成了一个整体,我们已经从中获取了 p_i 的收益。

  3. “反悔”机制:现在,如果我们想获得下一次的最大收益增量,选择范围是什么?除了那些独立的、未被影响的坑之外,我们还多了一个新的选项:反悔我们刚才选择 i 的这个决定

    • 如果我们反悔,我们就要退回 p_i 的收益。
    • 但反悔后,原先不能选的 i-1i+1 就被“解锁”了,我们可以选择它们,从而获得 p_{i-1} + p_{i+1} 的收益。
    • 因此,“反悔并选择 i 的邻居”这个新操作,能带来的收益增量是 p_{i-1} + p_{i+1} - p_i

我们可以把这个“反悔操作”看作一个收益为 p_{i-1} + p_{i+1} - p_i新“虚拟”坑,它占据了原来 i-1, i, i+1 的位置。在下一轮选择中,这个虚拟坑将和所有其他可用的坑一起参与竞争。

数据结构与实现

为了高效实现上述过程,我们需要:

  1. 一个优先队列(最大堆):用于随时获取当前收益最大的可用坑(无论是真实的还是虚拟的)。
  2. 一个双向链表:用于维护坑之间的邻接关系。当一个坑 i 及其邻居被合并时,我们需要快速找到 i-1 的前一个邻居和 i+1 的后一个邻居,并将它们连接起来。

完整算法

  1. 将所有 N 个坑的收益和编号 (p_i, i) 放入一个最大堆。
  2. 用数组 L[]R[] 初始化一个双向链表,表示 N 个坑的排列。
  3. 进行 K 次循环: a. 从堆顶取出一个坑 i,其收益为 p。如果该坑已被合并(无效),则跳过并取下一个。 b. 如果 p <= 0,则后续选择不可能增加总收益,直接结束。 c. 将 p 加入总收益 total_profit。 d. 从链表中找到 i 的左右邻居 l = L[i]r = R[i]。 e. 将 lr 标记为无效(因为它们也被合并了)。 f. 创建一个新的“虚拟”坑,其位置仍然是 i,但收益更新为 p_new = p_l + p_r - p。 g. 更新双向链表,将 l 的前驱和 r 的后继连接到新的虚拟坑 i 上。 h. 将这个新的虚拟坑 (p_new, i) 推入堆中。
  4. 循环结束后,total_profit 即为最大总收益。

代码

#include <iostream>
#include <vector>
#include <queue>

using namespace std;

struct Node {
    long long profit;
    int id;

    bool operator<(const Node& other) const {
        return profit < other.profit;
    }
};

const int MAXN = 100005;
long long p[MAXN];
int l[MAXN], r[MAXN];
bool valid[MAXN];

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n, k;
    cin >> n >> k;

    priority_queue<Node> pq;
    for (int i = 1; i <= n; ++i) {
        cin >> p[i];
        l[i] = i - 1;
        r[i] = i + 1;
        valid[i] = true;
        pq.push({p[i], i});
    }
    r[n] = 0; // Sentinel for right boundary

    long long total_profit = 0;
    for (int i = 0; i < k; ++i) {
        while (!pq.empty() && !valid[pq.top().id]) {
            pq.pop();
        }
        if (pq.empty()) break;

        Node top = pq.top();
        pq.pop();

        if (top.profit <= 0) break;

        total_profit += top.profit;
        
        int current_id = top.id;
        int left_id = l[current_id];
        int right_id = r[current_id];

        p[current_id] = p[left_id] + p[right_id] - p[current_id];
        
        valid[left_id] = false;
        valid[right_id] = false;

        l[current_id] = l[left_id];
        r[current_id] = r[right_id];
        
        if (l[current_id] != 0) {
            r[l[current_id]] = current_id;
        }
        if (r[current_id] != 0) {
            l[r[current_id]] = current_id;
        }
        
        pq.push({p[current_id], current_id});
    }

    cout << total_profit << endl;

    return 0;
}
import java.io.*;
import java.util.*;

public class Main {
    static class Node {
        long profit;
        int id;

        Node(long profit, int id) {
            this.profit = profit;
            this.id = id;
        }
    }

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        int n = Integer.parseInt(st.nextToken());
        int k = Integer.parseInt(st.nextToken());

        long[] p = new long[n + 2];
        int[] l = new int[n + 2];
        int[] r = new int[n + 2];
        boolean[] valid = new boolean[n + 2];

        PriorityQueue<Node> pq = new PriorityQueue<>((a, b) -> Long.compare(b.profit, a.profit));
        
        st = new StringTokenizer(br.readLine());
        for (int i = 1; i <= n; i++) {
            p[i] = Long.parseLong(st.nextToken());
            l[i] = i - 1;
            r[i] = i + 1;
            valid[i] = true;
            pq.add(new Node(p[i], i));
        }
        r[n] = 0; // Sentinel

        long totalProfit = 0;
        for (int i = 0; i < k; i++) {
            while (!pq.isEmpty() && !valid[pq.peek().id]) {
                pq.poll();
            }
            if (pq.isEmpty()) break;

            Node top = pq.poll();
            if (top.profit <= 0) break;

            totalProfit += top.profit;

            int currentId = top.id;
            int leftId = l[currentId];
            int rightId = r[currentId];
            
            p[currentId] = p[leftId] + p[rightId] - p[currentId];

            valid[leftId] = false;
            valid[rightId] = false;

            l[currentId] = l[leftId];
            r[currentId] = r[rightId];
            
            if (l[currentId] != 0) {
                r[l[currentId]] = currentId;
            }
            if (r[currentId] != 0) {
                l[r[currentId]] = currentId;
            }

            pq.add(new Node(p[currentId], currentId));
        }

        PrintWriter out = new PrintWriter(System.out);
        out.println(totalProfit);
        out.flush();
    }
}
import sys
import heapq

def solve():
    try:
        n_str, k_str = sys.stdin.readline().split()
        n, k = int(n_str), int(k_str)
        profits_str = sys.stdin.readline().split()
        profits = [0] + [int(p) for p in profits_str]
    except (IOError, ValueError):
        return

    # Max-heap (store negative profits in min-heap)
    pq = []
    
    l = [i - 1 for i in range(n + 2)]
    r = [i + 1 for i in range(n + 2)]
    valid = [True] * (n + 2)
    
    for i in range(1, n + 1):
        heapq.heappush(pq, (-profits[i], i))

    total_profit = 0
    for _ in range(k):
        while pq and not valid[pq[0][1]]:
            heapq.heappop(pq)
        
        if not pq:
            break

        neg_p, current_id = heapq.heappop(pq)
        p = -neg_p

        if p <= 0:
            break

        total_profit += p
        
        left_id = l[current_id]
        right_id = r[current_id]

        # Invalidate neighbors
        valid[left_id] = False
        valid[right_id] = False
        
        # Create new "regret" node
        profits[current_id] = profits[left_id] + profits[right_id] - profits[current_id]
        
        # Update linked list
        l[current_id] = l[left_id]
        r[current_id] = r[right_id]
        
        # Connect new neighbors to current node
        if l[current_id] != 0:
            r[l[current_id]] = current_id
        if r[current_id] != 0:
            l[r[current_id]] = current_id
            
        heapq.heappush(pq, (-profits[current_id], current_id))

    print(total_profit)

solve()

算法及复杂度

  • 算法:带悔贪心 + 优先队列 + 双向链表
  • 时间复杂度:
    • 初始化优先队列为 (取决于实现)。
    • 循环 K 次,每次循环中,对优先队列的操作(弹出、推入)为
    • 总时间复杂度主要由 K 次堆操作决定。
  • 空间复杂度:
    • 用于存储收益、双向链表数组以及优先队列。