晶体能量阱储量计算

题意

给定一个 的矩阵,每个位置 有一个势能值(可以为负数)。矩阵外部的空间势能恒为

对于矩阵中任意一点,定义逃逸势能为:从该点出发到矩阵外部的所有路径中,路径上最大势能的最小值。该点的储能量 = 逃逸势能 - 该点势能。

求整个矩阵的总储能量。

思路

这道题本质是二维接雨水问题的变体,区别在于:

  1. 势能值可以为负数
  2. 矩阵外部势能为 ,相当于在矩阵四周围了一圈高度为 的"墙"

为什么负值会产生储能?以样例 2 为例: 矩阵只有一个值 ,它直接与外部(势能 )相邻,逃逸势能为 ,储能 = 。相当于外部的"零势能墙"把负势能区域"兜住"了。

算法:优先队列 BFS

经典的二维接雨水思路:从边界向内灌水

  1. 初始化:把矩阵所有边界格子放入最小堆。由于外部势能为 ,边界格子的逃逸势能至少为 ——它必须"经过"外部的 势能才能逃出去。
  2. BFS 扩展:每次从堆中取出逃逸势能最小的格子 ,该格子的储能为 。然后向四个方向扩展,相邻未访问格子的逃逸势能为 ,加入堆中。
  3. 累加所有格子的储能即为答案。

为什么这样做是对的?最小堆保证我们总是先处理"最容易逃出去"的格子。当一个格子从堆中弹出时,它的逃逸势能已经是所有可能路径中的最优值(类似 Dijkstra 的松弛过程)。

复杂度

  • 时间复杂度:,每个格子进出堆一次
  • 空间复杂度:

代码

#include <bits/stdc++.h>
using namespace std;

int main(){
    int m, n;
    scanf("%d%d", &m, &n);
    vector<vector<int>> grid(n, vector<int>(m));
    for(int i = 0; i < n; i++)
        for(int j = 0; j < m; j++)
            scanf("%d", &grid[i][j]);

    priority_queue<tuple<int,int,int>, vector<tuple<int,int,int>>, greater<>> pq;
    vector<vector<bool>> visited(n, vector<bool>(m, false));

    for(int i = 0; i < n; i++){
        for(int j = 0; j < m; j++){
            if(i == 0 || i == n-1 || j == 0 || j == m-1){
                pq.push({max(0, grid[i][j]), i, j});
                visited[i][j] = true;
            }
        }
    }

    int dx[] = {0,0,1,-1};
    int dy[] = {1,-1,0,0};
    long long total = 0;

    while(!pq.empty()){
        auto [h, x, y] = pq.top();
        pq.pop();
        total += h - grid[x][y];
        for(int d = 0; d < 4; d++){
            int nx = x + dx[d], ny = y + dy[d];
            if(nx >= 0 && nx < n && ny >= 0 && ny < m && !visited[nx][ny]){
                visited[nx][ny] = true;
                pq.push({max(h, grid[nx][ny]), nx, ny});
            }
        }
    }

    printf("%lld\n", total);
    return 0;
}
import java.util.*;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int m = sc.nextInt(), n = sc.nextInt();
        int[][] grid = new int[n][m];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < m; j++)
                grid[i][j] = sc.nextInt();

        PriorityQueue<int[]> pq = new PriorityQueue<>((a, b) -> a[0] - b[0]);
        boolean[][] visited = new boolean[n][m];

        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                if (i == 0 || i == n - 1 || j == 0 || j == m - 1) {
                    pq.offer(new int[]{Math.max(0, grid[i][j]), i, j});
                    visited[i][j] = true;
                }
            }
        }

        int[] dx = {0, 0, 1, -1};
        int[] dy = {1, -1, 0, 0};
        long total = 0;

        while (!pq.isEmpty()) {
            int[] cur = pq.poll();
            int h = cur[0], x = cur[1], y = cur[2];
            total += h - grid[x][y];
            for (int d = 0; d < 4; d++) {
                int nx = x + dx[d], ny = y + dy[d];
                if (nx >= 0 && nx < n && ny >= 0 && ny < m && !visited[nx][ny]) {
                    visited[nx][ny] = true;
                    pq.offer(new int[]{Math.max(h, grid[nx][ny]), nx, ny});
                }
            }
        }

        System.out.println(total);
    }
}
import heapq
import sys
input = sys.stdin.readline

def main():
    m, n = map(int, input().split())
    grid = []
    for _ in range(n):
        grid.append(list(map(int, input().split())))

    visited = [[False] * m for _ in range(n)]
    heap = []

    for i in range(n):
        for j in range(m):
            if i == 0 or i == n - 1 or j == 0 or j == m - 1:
                heapq.heappush(heap, (max(0, grid[i][j]), i, j))
                visited[i][j] = True

    dx = [0, 0, 1, -1]
    dy = [1, -1, 0, 0]
    total = 0

    while heap:
        h, x, y = heapq.heappop(heap)
        total += h - grid[x][y]
        for d in range(4):
            nx, ny = x + dx[d], y + dy[d]
            if 0 <= nx < n and 0 <= ny < m and not visited[nx][ny]:
                visited[nx][ny] = True
                heapq.heappush(heap, (max(h, grid[nx][ny]), nx, ny))

    print(total)

main()
const readline = require('readline');
const rl = readline.createInterface({ input: process.stdin });
const lines = [];
rl.on('line', line => lines.push(line.trim()));
rl.on('close', () => {
    const [m, n] = lines[0].split(' ').map(Number);
    const grid = [];
    for (let i = 0; i < n; i++) {
        grid.push(lines[i + 1].split(' ').map(Number));
    }

    class MinHeap {
        constructor() { this.data = []; }
        push(val) {
            this.data.push(val);
            let i = this.data.length - 1;
            while (i > 0) {
                const p = (i - 1) >> 1;
                if (this.data[p][0] <= this.data[i][0]) break;
                [this.data[p], this.data[i]] = [this.data[i], this.data[p]];
                i = p;
            }
        }
        pop() {
            const top = this.data[0];
            const last = this.data.pop();
            if (this.data.length > 0) {
                this.data[0] = last;
                let i = 0;
                while (true) {
                    let s = i, l = 2*i+1, r = 2*i+2;
                    if (l < this.data.length && this.data[l][0] < this.data[s][0]) s = l;
                    if (r < this.data.length && this.data[r][0] < this.data[s][0]) s = r;
                    if (s === i) break;
                    [this.data[s], this.data[i]] = [this.data[i], this.data[s]];
                    i = s;
                }
            }
            return top;
        }
        get size() { return this.data.length; }
    }

    const visited = Array.from({length: n}, () => new Uint8Array(m));
    const heap = new MinHeap();

    for (let i = 0; i < n; i++) {
        for (let j = 0; j < m; j++) {
            if (i === 0 || i === n-1 || j === 0 || j === m-1) {
                heap.push([Math.max(0, grid[i][j]), i, j]);
                visited[i][j] = 1;
            }
        }
    }

    const dx = [0,0,1,-1], dy = [1,-1,0,0];
    let total = 0;

    while (heap.size > 0) {
        const [h, x, y] = heap.pop();
        total += h - grid[x][y];
        for (let d = 0; d < 4; d++) {
            const nx = x + dx[d], ny = y + dy[d];
            if (nx >= 0 && nx < n && ny >= 0 && ny < m && !visited[nx][ny]) {
                visited[nx][ny] = 1;
                heap.push([Math.max(h, grid[nx][ny]), nx, ny]);
            }
        }
    }

    console.log(total);
});