题目链接

能量辐射

题目描述

个能量发射站排成一行,每个发射站 都有一个不相同的高度 和一个能量值 。每个站会同时向左和向右发射能量,能量会被两侧最近的且比它高的发射站接收。请计算接收能量最多的发射站总共接收了多少能量。

解题思路

本题的核心是,对于每个发射站 ,我们需要找到它左侧第一个比它高的发射站(记为 )和右侧第一个比它高的发射站(记为 )。 就是分别接收来自 的左向和右向能量的站。

1. 朴素解法

一个直接的想法是,对每个发射站 ,都向左和向右进行线性扫描,以找到 。这种方法对于每个 都需要 的时间,总时间复杂度为 。对于 的数据规模,该方法会超时。

2. 单调栈优化

“寻找左/右侧第一个更大/更小的元素”是单调栈的经典应用场景。我们可以通过两次遍历,在 的时间内解决这个问题。

我们维护一个存储发射站下标的栈,并使其对应的高度保持单调递减

算法流程:

  1. 初始化:创建一个 received_energy 数组,长度为 ,所有元素初始化为 0。

  2. 第一遍遍历 (从左到右)

    • 此遍遍历的目的是,对于每个发射站 ,找到它右侧第一个比它高的接收站
    • 我们从左到右遍历所有发射站
    • 对于当前站 ,我们不断检查栈顶的站 。如果 ,说明站 是站 右侧第一个比它高的站。因此,站 会接收来自站 的能量。我们更新 received_energy[i] += V[j],然后将 从栈中弹出。
    • 重复此过程直到栈为空或栈顶站 的高度不小于
    • 最后,将当前站 的下标压入栈中。
  3. 第二遍遍历 (从右到左)

    • 此遍遍历的目的是,对于每个发射站 ,找到它左侧第一个比它高的接收站
    • 清空栈,然后从右到左遍历所有发射站
    • 逻辑与第一遍完全对称:如果 ,则站 是站 左侧第一个比它高的站,received_energy[i] += V[j]
  4. 求最终结果: 完成两次遍历后,received_energy 数组中就统计了每个站接收到的总能量。遍历该数组找到最大值即可。

由于每个发射站的下标在两次遍历中都最多入栈和出栈一次,因此总时间复杂度为

代码

#include <iostream>
#include <vector>
#include <stack>
#include <algorithm>

using namespace std;

struct Station {
    int h, v;
};

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n;
    cin >> n;

    vector<Station> stations(n);
    for (int i = 0; i < n; ++i) {
        cin >> stations[i].h >> stations[i].v;
    }

    vector<long long> received_energy(n, 0);
    stack<int> s;

    // 第一遍:从左到右,找到每个站右侧的接收者
    for (int i = 0; i < n; ++i) {
        while (!s.empty() && stations[s.top()].h < stations[i].h) {
            received_energy[i] += stations[s.top()].v;
            s.pop();
        }
        s.push(i);
    }

    // 清空栈
    while (!s.empty()) s.pop();

    // 第二遍:从右到左,找到每个站左侧的接收者
    for (int i = n - 1; i >= 0; --i) {
        while (!s.empty() && stations[s.top()].h < stations[i].h) {
            received_energy[i] += stations[s.top()].v;
            s.pop();
        }
        s.push(i);
    }

    long long max_energy = 0;
    for (long long energy : received_energy) {
        max_energy = max(max_energy, energy);
    }

    cout << max_energy << endl;

    return 0;
}
import java.util.Scanner;
import java.util.ArrayDeque;
import java.util.Deque;

public class Main {
    static class Station {
        int h, v;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();

        Station[] stations = new Station[n];
        for (int i = 0; i < n; i++) {
            stations[i] = new Station();
            stations[i].h = sc.nextInt();
            stations[i].v = sc.nextInt();
        }

        long[] receivedEnergy = new long[n];
        Deque<Integer> s = new ArrayDeque<>();

        // 第一遍:从左到右
        for (int i = 0; i < n; i++) {
            while (!s.isEmpty() && stations[s.peekLast()].h < stations[i].h) {
                receivedEnergy[i] += stations[s.pollLast()].v;
            }
            s.offerLast(i);
        }

        s.clear();

        // 第二遍:从右到左
        for (int i = n - 1; i >= 0; i--) {
            while (!s.isEmpty() && stations[s.peekLast()].h < stations[i].h) {
                receivedEnergy[i] += stations[s.pollLast()].v;
            }
            s.offerLast(i);
        }

        long maxEnergy = 0;
        for (long energy : receivedEnergy) {
            maxEnergy = Math.max(maxEnergy, energy);
        }

        System.out.println(maxEnergy);
    }
}
import sys

def solve():
    try:
        n_str = sys.stdin.readline()
        if not n_str: return
        n = int(n_str)
        
        stations = []
        for _ in range(n):
            h, v = map(int, sys.stdin.readline().split())
            stations.append({'h': h, 'v': v})

        received_energy = [0] * n
        s = []

        # 第一遍:从左到右
        for i in range(n):
            while s and stations[s[-1]]['h'] < stations[i]['h']:
                j = s.pop()
                received_energy[i] += stations[j]['v']
            s.append(i)
        
        s.clear()

        # 第二遍:从右到左
        for i in range(n - 1, -1, -1):
            while s and stations[s[-1]]['h'] < stations[i]['h']:
                j = s.pop()
                received_energy[i] += stations[j]['v']
            s.append(i)

        max_energy = 0
        if received_energy:
            max_energy = max(received_energy)
        
        print(max_energy)

    except (IOError, ValueError):
        return

solve()

算法及复杂度

  • 算法: 单调栈
  • 时间复杂度: 我们对输入数组进行了两次线性扫描,每次扫描中,每个元素最多入栈和出栈一次。因此,总的时间复杂度为
  • 空间复杂度: 需要一个数组来存储接收到的能量,以及一个栈。在最坏的情况下(例如,一个严格递减的高度序列),栈中会存储所有 个元素的下标。因此,空间复杂度为