解题思路

使用差分数组优化:

  1. 对于每个路由器 ,它能覆盖的范围是
  2. 在差分数组中:
    • 位置
    • 位置
  3. 最后通过前缀和还原,得到每个位置收到的信号数量

代码

#include <iostream>
#include <vector>
using namespace std;

int main() {
    int n, k;
    cin >> n >> k;
    
    // 差分数组
    vector<int> diff(n, 0);
    
    // 处理每个路由器的信号范围
    for(int i = 0; i < n; i++) {
        int signal;
        cin >> signal;
        
        // 起始位置
        int start = max(i - signal, 0);
        diff[start]++;
        
        // 结束位置
        if(signal + i + 1 < n) {
            diff[signal + i + 1]--;
        }
    }
    
    // 通过前缀和还原并统计结果
    int sum = 0, result = 0;
    for(int i = 0; i < n; i++) {
        sum += diff[i];
        if(sum >= k) {
            result++;
        }
    }
    
    cout << result << endl;
    return 0;
}
import java.util.*;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int k = sc.nextInt();
        
        // 差分数组
        int[] diff = new int[n];
        
        // 处理每个路由器的信号范围
        for(int i = 0; i < n; i++) {
            int signal = sc.nextInt();
            
            // 起始位置
            int start = Math.max(i - signal, 0);
            diff[start]++;
            
            // 结束位置
            if(signal + i + 1 < n) {
                diff[signal + i + 1]--;
            }
        }
        
        // 通过前缀和还原并统计结果
        int sum = 0, result = 0;
        for(int i = 0; i < n; i++) {
            sum += diff[i];
            if(sum >= k) {
                result++;
            }
        }
        
        System.out.println(result);
    }
}
def solve():
    n, k = map(int, input().split())
    signals = list(map(int, input().split()))
    
    # 差分数组
    diff = [0] * n
    
    # 处理每个路由器的信号范围
    for i in range(n):
        signal = signals[i]
        
        # 起始位置
        start = max(i - signal, 0)
        diff[start] += 1
        
        # 结束位置
        if signal + i + 1 < n:
            diff[signal + i + 1] -= 1
    
    # 通过前缀和还原并统计结果
    sum_val = 0
    result = 0
    for i in range(n):
        sum_val += diff[i]
        if sum_val >= k:
            result += 1
    
    print(result)

if __name__ == "__main__":
    solve()

算法及复杂度

  • 算法:差分数组
  • 时间复杂度:,只需要遍历两次数组
  • 空间复杂度:,需要存储差分数组