解题思路

这是一个使用单调栈解决的环形山峰观察问题。关键点:

  1. 单调栈策略

    • 从最高山峰开始遍历
    • 维护单调递减栈
    • 处理相同高度的山峰
  2. 计算可见对数

    • 相同高度山峰之间的对数:$C_k^2 = k(k-1)/2$,其中 $k$ 为栈中该高度山峰的累计个数
    • 出栈时与栈内元素形成的对数
    • 最后栈内剩余元素的处理

代码

#include <iostream>
#include <stack>
#include <vector>
using namespace std;

/// Stack entry: a peak height together with the number of peaks merged into it.
struct Pair {
    int value;      ///< Height shared by every peak in this entry.
    int times = 1;  ///< Number of peaks counted in this entry; a new entry holds one.

    /// Creates an entry for a single peak of height `v`.
    Pair(int v) : value(v) {}
};

/// Returns the index of the largest element among the first `size` entries of `arr`.
/// When the maximum occurs more than once, the earliest index is returned;
/// when `size` is 1 or less the result is 0.
int maxIndex(int arr[], int size) {
    int best = 0;
    int pos = 1;
    while (pos < size) {
        // Strict comparison keeps the first occurrence on ties.
        if (arr[best] < arr[pos]) {
            best = pos;
        }
        ++pos;
    }
    return best;
}

/// Returns the index that follows `i` on a ring of `size` slots:
/// `i + 1` while `i` is before the last slot, otherwise 0.
int nextPosition(int i, int size) {
    if (i >= size - 1) {
        return 0;  // at (or past) the last slot: wrap to the start
    }
    return i + 1;
}

/// Returns the number of unordered pairs among `times` items, i.e. times*(times-1)/2,
/// computed in 64-bit arithmetic. A single item yields 0.
long long getPairWithinSameHeight(int times) {
    if (times == 1) {
        return 0;
    }
    const long long k = times;
    return k * (times - 1) / 2;
}

long long solve(int arr[], int size) {
    if(size < 2) return 0;
    
    int maxIdx = maxIndex(arr, size);
    stack<Pair> st;
    st.push(Pair(arr[maxIdx]));
    
    int index = nextPosition(maxIdx, size);
    long long ans = 0;
    
    while(index != maxIdx) {
        int value = arr[index];
        while(!st.empty() && value > st.top().value) {
            int times = st.top().times;
            st.pop();
            ans += getPairWithinSameHeight(times);
            ans += times * 2;
        }
        
        if(!st.empty() && st.top().value == value) {
            st.top().times++;
        } else {
            st.push(Pair(value));
        }
        index = nextPosition(index, size);
    }
    
    while(!st.empty()) {
        int times = st.top().times;
        st.pop();
        ans += getPairWithinSameHeight(times);
        
        if(!st.empty()) {
            ans += times;
            if(st.size() > 1) {
                ans += times;
            } else {
                ans += st.top().times == 1 ? 0 : times;
            }
        }
    }
    
    return ans;
}

/// Reads test cases from standard input until extraction of a count fails.
/// Each case is an integer n followed by n heights; the result of solve() is
/// printed on its own line.
int main() {
    int n;
    while (std::cin >> n) {
        // Heap-backed storage replaces the original variable-length array, which
        // is not standard C++, can overflow the stack for large n, and is
        // undefined for n <= 0. A non-positive n yields an empty buffer.
        std::vector<int> arr(n > 0 ? static_cast<std::size_t>(n) : 0);
        for (int i = 0; i < n; i++) {
            std::cin >> arr[i];
        }
        // '\n' avoids a stream flush on every test case.
        std::cout << solve(arr.data(), n) << '\n';
    }
    return 0;
}
import java.util.*;

/**
 * Counts peak pairs on a circular ridge with a monotonic stack and prints one
 * result per test case read from standard input.
 */
public class Main {
    /** Stack entry: a peak height and the number of peaks merged into it. */
    static class Pair {
        int value;
        int times;

        /** Creates an entry for a single peak of the given height. */
        Pair(int value) {
            this.value = value;
            this.times = 1;
        }
    }

    /**
     * Returns the index of the largest element; the earliest index wins on ties.
     * Returns 0 for an empty or single-element array.
     */
    static int maxIndex(int[] arr) {
        int index = 0;
        for (int i = 1; i < arr.length; i++) {
            if (arr[i] > arr[index]) {
                index = i;
            }
        }
        return index;
    }

    /** Returns the index after {@code i} on a ring of {@code size} slots, wrapping to 0. */
    static int nextPosition(int i, int size) {
        return i < size - 1 ? i + 1 : 0;
    }

    /** Returns the number of unordered pairs among {@code times} items, in 64-bit arithmetic. */
    static long getPairWithinSameHeight(int times) {
        return times == 1 ? 0L : (long) times * (times - 1) / 2L;
    }

    /**
     * Counts peak pairs for heights given in ring order.
     *
     * The walk starts at the first maximum and visits every other position once.
     * The stack holds (height, count) entries; an entry is popped when a strictly
     * higher peak arrives, and equal heights are merged into the top entry.
     *
     * @param arr peak heights in ring order
     * @return the accumulated pair count; 0 when fewer than two peaks are given
     */
    static long solve(int[] arr) {
        if (arr.length < 2) {
            return 0;
        }

        int size = arr.length;
        int maxIdx = maxIndex(arr);
        // ArrayDeque is the recommended LIFO structure; java.util.Stack is a
        // legacy class whose every operation is synchronized.
        Deque<Pair> stack = new ArrayDeque<>();
        stack.push(new Pair(arr[maxIdx]));

        int index = nextPosition(maxIdx, size);
        long ans = 0;

        while (index != maxIdx) {
            int value = arr[index];
            while (!stack.isEmpty() && value > stack.peek().value) {
                int times = stack.pop().times;
                ans += getPairWithinSameHeight(times);
                // Widen before multiplying so the product cannot wrap in int.
                ans += 2L * times;
            }

            if (!stack.isEmpty() && stack.peek().value == value) {
                stack.peek().times++;
            } else {
                stack.push(new Pair(value));
            }
            index = nextPosition(index, size);
        }

        // Drain what is left once the whole ring has been visited.
        while (!stack.isEmpty()) {
            int times = stack.pop().times;
            ans += getPairWithinSameHeight(times);

            if (!stack.isEmpty()) {
                // One pair per peak with the entry directly beneath.
                ans += times;
                // A second pair per peak unless the only entry beneath is a single peak.
                if (stack.size() > 1 || stack.peek().times != 1) {
                    ans += times;
                }
            }
        }

        return ans;
    }

    /**
     * Reads test cases (a count n followed by n heights) until the next token
     * is not an integer, printing the result of {@link #solve(int[])} for each.
     */
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        // hasNextInt() ends the loop cleanly on a trailing non-integer token,
        // where hasNext() followed by nextInt() would throw.
        while (sc.hasNextInt()) {
            int n = sc.nextInt();
            int[] arr = new int[n];
            for (int i = 0; i < n; i++) {
                arr[i] = sc.nextInt();
            }
            System.out.println(solve(arr));
        }
    }
}
class Pair:
    """Stack entry: a peak height plus the number of peaks merged into it."""

    def __init__(self, value):
        # A new entry always represents exactly one peak.
        self.value, self.times = value, 1

def max_index(arr):
    """Return the index of the largest element of ``arr``.

    The earliest index wins on ties. An empty ``arr`` raises ``ValueError``
    (from ``max`` on an empty range).
    """
    positions = range(len(arr))
    return max(positions, key=arr.__getitem__)

def next_position(i, size):
    """Return the index after ``i`` on a ring of ``size`` slots, wrapping to 0."""
    if i >= size - 1:
        # At (or past) the last slot: go back to the start.
        return 0
    return i + 1

def get_pair_within_same_height(times):
    """Return the number of unordered pairs among ``times`` items: times*(times-1)//2."""
    if times == 1:
        return 0
    pair_count = times * (times - 1)
    return pair_count // 2

def solve(arr):
    """Count peak pairs on a circular ridge with a monotonic stack.

    The walk starts at the index returned by ``max_index`` and visits every
    other position once in ring order. The stack holds ``Pair`` entries
    (height, count); an entry is popped when a strictly higher peak arrives,
    and equal heights are merged into the top entry.

    Returns the accumulated pair count, or 0 when ``arr`` has fewer than two
    elements.
    """
    size = len(arr)
    if size < 2:
        return 0

    start = max_index(arr)
    groups = [Pair(arr[start])]
    ans = 0

    index = next_position(start, size)
    while index != start:
        height = arr[index]

        # Each popped entry contributes the pairs inside it plus two pairs per
        # peak (one with the entry beneath it, one with the incoming peak).
        while groups and groups[-1].value < height:
            count = groups.pop().times
            ans += get_pair_within_same_height(count) + count * 2

        if groups and groups[-1].value == height:
            groups[-1].times += 1
        else:
            groups.append(Pair(height))
        index = next_position(index, size)

    # Drain what is left once the whole ring has been visited.
    while groups:
        count = groups.pop().times
        ans += get_pair_within_same_height(count)

        if not groups:
            break  # bottom entry: only its internal pairs count

        # One pair per peak with the entry directly beneath.
        ans += count
        # A second pair per peak unless the only entry beneath is a single peak.
        if len(groups) > 1 or groups[-1].times != 1:
            ans += count

    return ans

# Driver: read test cases until input runs out. Each case is a count line
# followed by a line of space-separated heights; print solve() for each case.
while True:
    try:
        # The count line is consumed and validated; the heights are taken from
        # the whole following line.
        n = int(input())
        arr = list(map(int, input().split()))
    except (EOFError, ValueError):
        # EOFError: no more input. ValueError: blank or malformed line.
        # Only these end the loop; other exceptions are no longer swallowed.
        break
    print(solve(arr))

算法及复杂度

  • 算法:单调栈
  • 时间复杂度:$O(n)$,每个元素最多入栈出栈一次
  • 空间复杂度:$O(n)$,栈的空间