REALHW91 神经网络信号传播预测
题目链接
题目描述
在一个二维神经元矩阵中,每个单元可以是非导电介质(值为 0),也可以是具有特定激活延迟的神经元(值为 1 到 100 的正整数)。
当一个神经元被激活后,它会经历一段等于其激活延迟时间的内部处理。处理完成后,它会立即向其上、下、左、右四个相邻的神经元发送激活信号。
您的任务是,给定一个初始被激活的神经元集合 Sources 和一个特定的目标神经元 T,计算出目标神经元 T 被首次激活的最早时间。如果目标神经元无法被激活,则返回 -1。
解题思路
这是一个在网格图上寻找最短路径的典型问题。我们可以将信号的传播时间视为路径的“成本”或“权重”。由于所有权重(即神经元的激活延迟)都是非负的,解决此类问题的最经典、最高效的算法是 Dijkstra 算法。
1. 问题建模
- 图的节点:矩阵中每一个值为正整数的神经元都可以看作是图中的一个节点。
- 图的边:每个神经元与其上、下、左、右四个相邻的神经元之间存在连接的边。
- 路径成本:信号从一个神经元
A传播到其相邻的神经元B,可以看作是从节点A移动到节点B。这个过程的成本是B自身的激活延迟delay_B。信号到达B的总时间是信号到达A的时间加上A的延迟,再加上B的延迟。但为了Dijkstra算法的应用,我们定义“距离”为神经元完成处理的时间。
2. Dijkstra 算法的应用
- 距离/时间数组: 我们创建一个与输入矩阵大小相同的二维数组
times[M][N],用于存储信号传播到每个神经元并完成其内部处理的最早时间。所有位置的初始值设为无穷大。 - 优先队列 (Min-Heap): 为了能够总是优先处理当前已知“完成时间”最早的神经元,我们使用一个最小优先队列。队列中存储
(time, x, y)形式的元组,代表坐标为(x, y)的神经元在time时刻完成了处理,并准备向其邻居传播信号。 - 初始化:
- 对于每一个初始激活源
(sx, sy),它被激活的时间是 0,完成处理的时间就是它自身的延迟delay_s。 - 因此,我们更新
times[sx][sy] = matrix[sx][sy]。 - 然后,将初始状态
(matrix[sx][sy], sx, sy)压入优先队列。
- 对于每一个初始激活源
- 主循环:
- 从优先队列中弹出拥有最小
time的节点(currentTime, x, y)。 - 如果
currentTime大于times[x][y]中已记录的时间,说明我们之前已经找到了一条更快的路径来处理这个节点,因此跳过当前状态。 - 遍历该节点的四个邻居
(nx, ny)。 - 对于每个合法的、未被更快路径访问过的邻居,计算其完成处理的时间
newTime = currentTime + matrix[nx][ny](即,当前节点完成处理的时刻currentTime,加上邻居节点自身的处理延迟)。 - 如果
newTime小于times[nx][ny]中记录的时间,说明我们找到了一条更快的路径。更新times[nx][ny] = newTime,并将新状态(newTime, nx, ny)压入优先队列。
- 从优先队列中弹出拥有最小
3. 结果处理
- 循环结束后,
times[x][y]存储的是从源点到(x, y)并完成处理的最短时间。目标神经元(tx, ty)被首次激活的时间,是其某个邻居(px, py)完成处理并向它传播信号的时刻。这个时刻就是times[px][py]。而目标神经元自身的处理时间不应计算在内。 - 因此,目标神经元被激活的时间是
times[tx][ty] - matrix[tx][ty]。 - 如果目标位置本身是 0,或者最终计算出的
times[tx][ty]仍为无穷大,则目标不可达,返回 -1。
代码
#include <iostream>
#include <vector>
#include <queue>
#include <limits>
using namespace std;
const long long INF = numeric_limits<long long>::max();
struct State {
long long time;
int r, c;
bool operator>(const State& other) const {
return time > other.time;
}
};
int main() {
int m, n;
cin >> m >> n;
vector<vector<int>> matrix(m, vector<int>(n));
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
cin >> matrix[i][j];
}
}
vector<pair<int, int>> sources;
int r, c;
char peek_char;
while (cin >> r >> c) {
sources.push_back({r, c});
peek_char = cin.peek();
if (peek_char == '\n' || peek_char == EOF) {
break;
}
}
int target_r, target_c;
cin >> target_r >> target_c;
if (matrix[target_r][target_c] == 0) {
cout << -1 << endl;
return 0;
}
vector<vector<long long>> times(m, vector<long long>(n, INF));
priority_queue<State, vector<State>, greater<State>> pq;
for (const auto& src : sources) {
int sr = src.first;
int sc_val = src.second;
if (matrix[sr][sc_val] > 0 && (long long)matrix[sr][sc_val] < times[sr][sc_val]) {
times[sr][sc_val] = matrix[sr][sc_val];
pq.push({(long long)matrix[sr][sc_val], sr, sc_val});
}
}
int dr[] = {-1, 1, 0, 0};
int dc[] = {0, 0, -1, 1};
while (!pq.empty()) {
State current = pq.top();
pq.pop();
long long current_time = current.time;
int cr = current.r;
int cc = current.c;
if (current_time > times[cr][cc]) {
continue;
}
if (cr == target_r && cc == target_c) {
break;
}
for (int i = 0; i < 4; ++i) {
int nr = cr + dr[i];
int nc = cc + dc[i];
if (nr >= 0 && nr < m && nc >= 0 && nc < n && matrix[nr][nc] > 0) {
long long new_time = current_time + matrix[nr][nc];
if (new_time < times[nr][nc]) {
times[nr][nc] = new_time;
pq.push({new_time, nr, nc});
}
}
}
}
long long result_time = times[target_r][target_c];
if (result_time == INF) {
cout << -1 << endl;
} else {
cout << result_time - matrix[target_r][target_c] << endl;
}
return 0;
}
import java.util.*;
public class Main {
static class State implements Comparable<State> {
long time;
int r, c;
public State(long time, int r, int c) {
this.time = time;
this.r = r;
this.c = c;
}
@Override
public int compareTo(State other) {
return Long.compare(this.time, other.time);
}
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int m = sc.nextInt();
int n = sc.nextInt();
sc.nextLine(); // Consume newline
int[][] matrix = new int[m][n];
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
matrix[i][j] = sc.nextInt();
}
}
sc.nextLine(); // Consume newline
String[] sourceCoords = sc.nextLine().split("\\s+");
List<int[]> sources = new ArrayList<>();
for (int i = 0; i < sourceCoords.length; i += 2) {
sources.add(new int[]{Integer.parseInt(sourceCoords[i]), Integer.parseInt(sourceCoords[i + 1])});
}
int targetR = sc.nextInt();
int targetC = sc.nextInt();
if (matrix[targetR][targetC] == 0) {
System.out.println(-1);
return;
}
long[][] times = new long[m][n];
for (long[] row : times) {
Arrays.fill(row, Long.MAX_VALUE);
}
PriorityQueue<State> pq = new PriorityQueue<>();
for (int[] src : sources) {
int sr = src[0];
int sc_val = src[1];
if (matrix[sr][sc_val] > 0 && (long) matrix[sr][sc_val] < times[sr][sc_val]) {
times[sr][sc_val] = matrix[sr][sc_val];
pq.offer(new State(times[sr][sc_val], sr, sc_val));
}
}
int[] dr = {-1, 1, 0, 0};
int[] dc = {0, 0, -1, 1};
while (!pq.isEmpty()) {
State current = pq.poll();
long currentTime = current.time;
int r = current.r;
int c = current.c;
if (currentTime > times[r][c]) {
continue;
}
if (r == targetR && c == targetC) {
break;
}
for (int i = 0; i < 4; i++) {
int nr = r + dr[i];
int nc = c + dc[i];
if (nr >= 0 && nr < m && nc >= 0 && nc < n && matrix[nr][nc] > 0) {
long newTime = currentTime + matrix[nr][nc];
if (newTime < times[nr][nc]) {
times[nr][nc] = newTime;
pq.offer(new State(newTime, nr, nc));
}
}
}
}
long resultTime = times[targetR][targetC];
if (resultTime == Long.MAX_VALUE) {
System.out.println(-1);
} else {
System.out.println(resultTime - matrix[targetR][targetC]);
}
}
}
import heapq
def solve():
m, n = map(int, input().split())
matrix = [list(map(int, input().split())) for _ in range(m)]
source_coords = list(map(int, input().split()))
sources = []
for i in range(0, len(source_coords), 2):
sources.append((source_coords[i], source_coords[i+1]))
target_r, target_c = map(int, input().split())
if matrix[target_r][target_c] == 0:
print(-1)
return
times = [[float('inf')] * n for _ in range(m)]
pq = []
for r, c_val in sources:
if matrix[r][c_val] > 0:
delay = matrix[r][c_val]
if delay < times[r][c_val]:
times[r][c_val] = delay
heapq.heappush(pq, (delay, r, c_val))
dr = [-1, 1, 0, 0]
dc = [0, 0, -1, 1]
while pq:
current_time, r, c_val = heapq.heappop(pq)
if current_time > times[r][c_val]:
continue
if r == target_r and c_val == target_c:
break
for i in range(4):
nr, nc = r + dr[i], c_val + dc[i]
if 0 <= nr < m and 0 <= nc < n and matrix[nr][nc] > 0:
new_time = current_time + matrix[nr][nc]
if new_time < times[nr][nc]:
times[nr][nc] = new_time
heapq.heappush(pq, (new_time, nr, nc))
result_time = times[target_r][target_c]
if result_time == float('inf'):
print(-1)
else:
# 激活时间是上一个节点完成处理的时间
print(result_time - matrix[target_r][target_c])
solve()
算法及复杂度
- 算法: Dijkstra 算法
- 时间复杂度:
。在网格图中,节点数
,边数
最多为
。使用二叉堆实现的优先队列,Dijkstra 算法的时间复杂度为
,代入后得到
。
- 空间复杂度:
。主要用于存储
times数组,以及在最坏情况下优先队列可能存储所有节点。

京公网安备 11010502036488号