REALHW91 神经网络信号传播预测

题目链接

神经网络信号传播预测

题目描述

在一个二维神经元矩阵中,每个单元可以是非导电介质(值为 0),也可以是具有特定激活延迟神经元(值为 1 到 100 的正整数)。

当一个神经元被激活后,它会经历一段等于其激活延迟时间的内部处理。处理完成后,它会立即向其上、下、左、右四个相邻的神经元发送激活信号。

您的任务是,给定一个初始被激活的神经元集合 Sources 和一个特定的目标神经元 T,计算出目标神经元 T首次激活的最早时间。如果目标神经元无法被激活,则返回 -1。

解题思路

这是一个在网格图上寻找最短路径的典型问题。我们可以将信号的传播时间视为路径的“成本”或“权重”。由于所有权重(即神经元的激活延迟)都是非负的,解决此类问题的最经典、最高效的算法是 Dijkstra 算法

1. 问题建模

  • 图的节点:矩阵中每一个值为正整数的神经元都可以看作是图中的一个节点。
  • 图的边:每个神经元与其上、下、左、右四个相邻的神经元之间存在连接的边。
  • 路径成本:信号从一个神经元 A 传播到其相邻的神经元 B,可以看作是从节点 A 移动到节点 B。这个过程的成本是 B 自身的激活延迟 delay_B。信号到达 B 的总时间是信号到达 A 的时间加上 A 的延迟,再加上 B 的延迟。但为了Dijkstra算法的应用,我们定义“距离”为神经元完成处理的时间。

2. Dijkstra 算法的应用

  • 距离/时间数组: 我们创建一个与输入矩阵大小相同的二维数组 times[M][N],用于存储信号传播到每个神经元并完成其内部处理的最早时间。所有位置的初始值设为无穷大。
  • 优先队列 (Min-Heap): 为了能够总是优先处理当前已知“完成时间”最早的神经元,我们使用一个最小优先队列。队列中存储 (time, x, y) 形式的元组,代表坐标为 (x, y) 的神经元在 time 时刻完成了处理,并准备向其邻居传播信号。
  • 初始化:
    1. 对于每一个初始激活源 (sx, sy),它被激活的时间是 0,完成处理的时间就是它自身的延迟 delay_s
    2. 因此,我们更新 times[sx][sy] = matrix[sx][sy]
    3. 然后,将初始状态 (matrix[sx][sy], sx, sy) 压入优先队列。
  • 主循环:
    1. 从优先队列中弹出拥有最小 time 的节点 (currentTime, x, y)
    2. 如果 currentTime 大于 times[x][y] 中已记录的时间,说明我们之前已经找到了一条更快的路径来处理这个节点,因此跳过当前状态。
    3. 遍历该节点的四个邻居 (nx, ny)
    4. 对于每个合法的、未被更快路径访问过的邻居,计算其完成处理的时间 newTime = currentTime + matrix[nx][ny](即,当前节点完成处理的时刻 currentTime,加上邻居节点自身的处理延迟)。
    5. 如果 newTime 小于 times[nx][ny] 中记录的时间,说明我们找到了一条更快的路径。更新 times[nx][ny] = newTime,并将新状态 (newTime, nx, ny) 压入优先队列。

3. 结果处理

  • 循环结束后,times[x][y] 存储的是从源点到 (x, y) 并完成处理的最短时间。目标神经元 (tx, ty)首次激活的时间,是其某个邻居 (px, py) 完成处理并向它传播信号的时刻。这个时刻就是 times[px][py]。而目标神经元自身的处理时间不应计算在内。
  • 因此,目标神经元被激活的时间是 times[tx][ty] - matrix[tx][ty]
  • 如果目标位置本身是 0,或者最终计算出的 times[tx][ty] 仍为无穷大,则目标不可达,返回 -1。

代码

#include <iostream>
#include <vector>
#include <queue>
#include <limits>

using namespace std;

const long long INF = numeric_limits<long long>::max();

struct State {
    long long time;
    int r, c;

    bool operator>(const State& other) const {
        return time > other.time;
    }
};

int main() {
    int m, n;
    cin >> m >> n;

    vector<vector<int>> matrix(m, vector<int>(n));
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
            cin >> matrix[i][j];
        }
    }

    vector<pair<int, int>> sources;
    int r, c;
    char peek_char;
    while (cin >> r >> c) {
        sources.push_back({r, c});
        peek_char = cin.peek();
        if (peek_char == '\n' || peek_char == EOF) {
            break;
        }
    }

    int target_r, target_c;
    cin >> target_r >> target_c;

    if (matrix[target_r][target_c] == 0) {
        cout << -1 << endl;
        return 0;
    }

    vector<vector<long long>> times(m, vector<long long>(n, INF));
    priority_queue<State, vector<State>, greater<State>> pq;

    for (const auto& src : sources) {
        int sr = src.first;
        int sc_val = src.second;
        if (matrix[sr][sc_val] > 0 && (long long)matrix[sr][sc_val] < times[sr][sc_val]) {
            times[sr][sc_val] = matrix[sr][sc_val];
            pq.push({(long long)matrix[sr][sc_val], sr, sc_val});
        }
    }

    int dr[] = {-1, 1, 0, 0};
    int dc[] = {0, 0, -1, 1};

    while (!pq.empty()) {
        State current = pq.top();
        pq.pop();

        long long current_time = current.time;
        int cr = current.r;
        int cc = current.c;

        if (current_time > times[cr][cc]) {
            continue;
        }
        
        if (cr == target_r && cc == target_c) {
            break;
        }

        for (int i = 0; i < 4; ++i) {
            int nr = cr + dr[i];
            int nc = cc + dc[i];

            if (nr >= 0 && nr < m && nc >= 0 && nc < n && matrix[nr][nc] > 0) {
                long long new_time = current_time + matrix[nr][nc];
                if (new_time < times[nr][nc]) {
                    times[nr][nc] = new_time;
                    pq.push({new_time, nr, nc});
                }
            }
        }
    }

    long long result_time = times[target_r][target_c];
    if (result_time == INF) {
        cout << -1 << endl;
    } else {
        cout << result_time - matrix[target_r][target_c] << endl;
    }

    return 0;
}
import java.util.*;

public class Main {
    static class State implements Comparable<State> {
        long time;
        int r, c;

        public State(long time, int r, int c) {
            this.time = time;
            this.r = r;
            this.c = c;
        }

        @Override
        public int compareTo(State other) {
            return Long.compare(this.time, other.time);
        }
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        int m = sc.nextInt();
        int n = sc.nextInt();
        sc.nextLine(); // Consume newline

        int[][] matrix = new int[m][n];
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                matrix[i][j] = sc.nextInt();
            }
        }
        sc.nextLine(); // Consume newline

        String[] sourceCoords = sc.nextLine().split("\\s+");
        List<int[]> sources = new ArrayList<>();
        for (int i = 0; i < sourceCoords.length; i += 2) {
            sources.add(new int[]{Integer.parseInt(sourceCoords[i]), Integer.parseInt(sourceCoords[i + 1])});
        }

        int targetR = sc.nextInt();
        int targetC = sc.nextInt();

        if (matrix[targetR][targetC] == 0) {
            System.out.println(-1);
            return;
        }

        long[][] times = new long[m][n];
        for (long[] row : times) {
            Arrays.fill(row, Long.MAX_VALUE);
        }

        PriorityQueue<State> pq = new PriorityQueue<>();

        for (int[] src : sources) {
            int sr = src[0];
            int sc_val = src[1];
            if (matrix[sr][sc_val] > 0 && (long) matrix[sr][sc_val] < times[sr][sc_val]) {
                times[sr][sc_val] = matrix[sr][sc_val];
                pq.offer(new State(times[sr][sc_val], sr, sc_val));
            }
        }

        int[] dr = {-1, 1, 0, 0};
        int[] dc = {0, 0, -1, 1};

        while (!pq.isEmpty()) {
            State current = pq.poll();
            long currentTime = current.time;
            int r = current.r;
            int c = current.c;

            if (currentTime > times[r][c]) {
                continue;
            }
            if (r == targetR && c == targetC) {
                break;
            }

            for (int i = 0; i < 4; i++) {
                int nr = r + dr[i];
                int nc = c + dc[i];

                if (nr >= 0 && nr < m && nc >= 0 && nc < n && matrix[nr][nc] > 0) {
                    long newTime = currentTime + matrix[nr][nc];
                    if (newTime < times[nr][nc]) {
                        times[nr][nc] = newTime;
                        pq.offer(new State(newTime, nr, nc));
                    }
                }
            }
        }

        long resultTime = times[targetR][targetC];
        if (resultTime == Long.MAX_VALUE) {
            System.out.println(-1);
        } else {
            System.out.println(resultTime - matrix[targetR][targetC]);
        }
    }
}
import heapq

def solve():
    m, n = map(int, input().split())
    matrix = [list(map(int, input().split())) for _ in range(m)]
    
    source_coords = list(map(int, input().split()))
    sources = []
    for i in range(0, len(source_coords), 2):
        sources.append((source_coords[i], source_coords[i+1]))
        
    target_r, target_c = map(int, input().split())

    if matrix[target_r][target_c] == 0:
        print(-1)
        return

    times = [[float('inf')] * n for _ in range(m)]
    pq = []

    for r, c_val in sources:
        if matrix[r][c_val] > 0:
            delay = matrix[r][c_val]
            if delay < times[r][c_val]:
                times[r][c_val] = delay
                heapq.heappush(pq, (delay, r, c_val))

    dr = [-1, 1, 0, 0]
    dc = [0, 0, -1, 1]

    while pq:
        current_time, r, c_val = heapq.heappop(pq)

        if current_time > times[r][c_val]:
            continue
        
        if r == target_r and c_val == target_c:
            break

        for i in range(4):
            nr, nc = r + dr[i], c_val + dc[i]

            if 0 <= nr < m and 0 <= nc < n and matrix[nr][nc] > 0:
                new_time = current_time + matrix[nr][nc]
                if new_time < times[nr][nc]:
                    times[nr][nc] = new_time
                    heapq.heappush(pq, (new_time, nr, nc))
    
    result_time = times[target_r][target_c]
    
    if result_time == float('inf'):
        print(-1)
    else:
        # 激活时间是上一个节点完成处理的时间
        print(result_time - matrix[target_r][target_c])

solve()

算法及复杂度

  • 算法: Dijkstra 算法
  • 时间复杂度: 。在网格图中,节点数 ,边数 最多为 。使用二叉堆实现的优先队列,Dijkstra 算法的时间复杂度为 ,代入后得到
  • 空间复杂度: 。主要用于存储 times 数组,以及在最坏情况下优先队列可能存储所有节点。