神经网络信号传播预测
题目分析
给定一个 的二维神经元矩阵,其中 0 表示不导通的材料,正整数 1-9 表示神经元的激活延迟。信号从多个初始激活源同时开始传播,每个神经元被激活后,经过等于其延迟值的时间完成内部处理,然后向上下左右四个相邻神经元发出信号。求目标神经元最早被激活的时间,若不可达则返回 -1。
思路
多源 Dijkstra 最短路
本题本质上是一个带权最短路问题。将每个神经元看作图中的节点,激活延迟看作边权:从节点 到相邻节点的边权为
(即当前神经元的延迟值)。目标是求从任意一个初始激活源到目标位置的最短时间。
具体做法:
- 所有初始激活源的激活时间设为 0,同时加入优先队列(小根堆)。
- 每次从堆中取出当前激活时间最小的神经元
,它在
时刻被激活,经过
的延迟后(即
时刻),向四个方向传播信号。
- 如果相邻神经元
非零且
比之前记录的激活时间更早,就更新并入堆。
- 当目标位置第一次从堆中弹出时,就是最早激活时间。
这就是经典的多源 Dijkstra 算法,因为边权都是正整数(1-9),保证了算法的正确性。
代码
#include <bits/stdc++.h>
using namespace std;
int main() {
int m, n;
scanf("%d%d", &m, &n);
vector<vector<int>> grid(m, vector<int>(n));
for (int i = 0; i < m; i++)
for (int j = 0; j < n; j++)
scanf("%d", &grid[i][j]);
string line;
getline(cin, line); // consume newline
getline(cin, line);
vector<pair<int,int>> sources;
istringstream iss(line);
int x, y;
while (iss >> x >> y) {
sources.push_back({x, y});
}
int ta, tb;
scanf("%d%d", &ta, &tb);
if (grid[ta][tb] == 0) {
printf("-1\n");
return 0;
}
vector<vector<long long>> dist(m, vector<long long>(n, LLONG_MAX));
priority_queue<tuple<long long,int,int>, vector<tuple<long long,int,int>>, greater<>> pq;
for (auto& [sx, sy] : sources) {
if (grid[sx][sy] > 0 && dist[sx][sy] > 0) {
dist[sx][sy] = 0;
pq.push({0, sx, sy});
}
}
int dx[] = {-1, 1, 0, 0};
int dy[] = {0, 0, -1, 1};
while (!pq.empty()) {
auto [d, r, c] = pq.top(); pq.pop();
if (d > dist[r][c]) continue;
if (r == ta && c == tb) {
printf("%lld\n", d);
return 0;
}
long long nd = d + grid[r][c];
for (int i = 0; i < 4; i++) {
int nr = r + dx[i], nc = c + dy[i];
if (nr >= 0 && nr < m && nc >= 0 && nc < n && grid[nr][nc] > 0) {
if (nd < dist[nr][nc]) {
dist[nr][nc] = nd;
pq.push({nd, nr, nc});
}
}
}
}
printf("-1\n");
return 0;
}
import java.util.*;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int m = sc.nextInt(), n = sc.nextInt();
int[][] grid = new int[m][n];
for (int i = 0; i < m; i++)
for (int j = 0; j < n; j++)
grid[i][j] = sc.nextInt();
sc.nextLine();
String line = sc.nextLine().trim();
String[] parts = line.split("\\s+");
List<int[]> sources = new ArrayList<>();
for (int i = 0; i + 1 < parts.length; i += 2) {
sources.add(new int[]{Integer.parseInt(parts[i]), Integer.parseInt(parts[i + 1])});
}
int ta = sc.nextInt(), tb = sc.nextInt();
if (grid[ta][tb] == 0) {
System.out.println(-1);
return;
}
long[][] dist = new long[m][n];
for (long[] row : dist) Arrays.fill(row, Long.MAX_VALUE);
PriorityQueue<long[]> pq = new PriorityQueue<>((a, b) -> Long.compare(a[0], b[0]));
for (int[] s : sources) {
if (grid[s[0]][s[1]] > 0 && dist[s[0]][s[1]] > 0) {
dist[s[0]][s[1]] = 0;
pq.offer(new long[]{0, s[0], s[1]});
}
}
int[] dx = {-1, 1, 0, 0};
int[] dy = {0, 0, -1, 1};
while (!pq.isEmpty()) {
long[] cur = pq.poll();
long d = cur[0];
int r = (int) cur[1], c = (int) cur[2];
if (d > dist[r][c]) continue;
if (r == ta && c == tb) {
System.out.println(d);
return;
}
long nd = d + grid[r][c];
for (int i = 0; i < 4; i++) {
int nr = r + dx[i], nc = c + dy[i];
if (nr >= 0 && nr < m && nc >= 0 && nc < n && grid[nr][nc] > 0) {
if (nd < dist[nr][nc]) {
dist[nr][nc] = nd;
pq.offer(new long[]{nd, nr, nc});
}
}
}
}
System.out.println(-1);
}
}
import heapq
import sys
input = sys.stdin.readline
def main():
m, n = map(int, input().split())
grid = []
for _ in range(m):
grid.append(list(map(int, input().split())))
src_line = list(map(int, input().split()))
sources = [(src_line[i], src_line[i + 1]) for i in range(0, len(src_line), 2)]
ta, tb = map(int, input().split())
if grid[ta][tb] == 0:
print(-1)
return
INF = float('inf')
dist = [[INF] * n for _ in range(m)]
heap = []
for sx, sy in sources:
if grid[sx][sy] > 0 and dist[sx][sy] > 0:
dist[sx][sy] = 0
heapq.heappush(heap, (0, sx, sy))
dirs = [(-1, 0), (1, 0), (0, -1), (0, 1)]
while heap:
d, r, c = heapq.heappop(heap)
if d > dist[r][c]:
continue
if r == ta and c == tb:
print(d)
return
nd = d + grid[r][c]
for dr, dc in dirs:
nr, nc = r + dr, c + dc
if 0 <= nr < m and 0 <= nc < n and grid[nr][nc] > 0:
if nd < dist[nr][nc]:
dist[nr][nc] = nd
heapq.heappush(heap, (nd, nr, nc))
print(-1)
main()
const readline = require('readline');
const rl = readline.createInterface({ input: process.stdin });
const lines = [];
rl.on('line', l => lines.push(l.trim()));
rl.on('close', () => {
let idx = 0;
const [m, n] = lines[idx++].split(' ').map(Number);
const grid = [];
for (let i = 0; i < m; i++) {
grid.push(lines[idx++].split(' ').map(Number));
}
const srcParts = lines[idx++].split(' ').map(Number);
const sources = [];
for (let i = 0; i + 1 < srcParts.length; i += 2) {
sources.push([srcParts[i], srcParts[i + 1]]);
}
const [ta, tb] = lines[idx++].split(' ').map(Number);
if (grid[ta][tb] === 0) {
console.log(-1);
return;
}
const dist = Array.from({ length: m }, () => new Array(n).fill(Infinity));
class MinHeap {
constructor() { this.data = []; }
push(item) {
this.data.push(item);
let i = this.data.length - 1;
while (i > 0) {
const p = (i - 1) >> 1;
if (this.data[p][0] <= this.data[i][0]) break;
[this.data[p], this.data[i]] = [this.data[i], this.data[p]];
i = p;
}
}
pop() {
const top = this.data[0];
const last = this.data.pop();
if (this.data.length > 0) {
this.data[0] = last;
let i = 0;
while (true) {
let smallest = i;
const l = 2 * i + 1, r = 2 * i + 2;
if (l < this.data.length && this.data[l][0] < this.data[smallest][0]) smallest = l;
if (r < this.data.length && this.data[r][0] < this.data[smallest][0]) smallest = r;
if (smallest === i) break;
[this.data[smallest], this.data[i]] = [this.data[i], this.data[smallest]];
i = smallest;
}
}
return top;
}
get size() { return this.data.length; }
}
const pq = new MinHeap();
for (const [sx, sy] of sources) {
if (grid[sx][sy] > 0 && dist[sx][sy] > 0) {
dist[sx][sy] = 0;
pq.push([0, sx, sy]);
}
}
const dx = [-1, 1, 0, 0];
const dy = [0, 0, -1, 1];
while (pq.size > 0) {
const [d, r, c] = pq.pop();
if (d > dist[r][c]) continue;
if (r === ta && c === tb) {
console.log(d);
return;
}
const nd = d + grid[r][c];
for (let i = 0; i < 4; i++) {
const nr = r + dx[i], nc = c + dy[i];
if (nr >= 0 && nr < m && nc >= 0 && nc < n && grid[nr][nc] > 0) {
if (nd < dist[nr][nc]) {
dist[nr][nc] = nd;
pq.push([nd, nr, nc]);
}
}
}
}
console.log(-1);
});
复杂度分析
- 时间复杂度:
,其中
和
是矩阵的行数和列数。每个节点最多入堆一次(由 dist 数组剪枝),每次堆操作
。
- 空间复杂度:
,用于存储距离数组和优先队列。

京公网安备 11010502036488号