解题思路
-
核心思想:
- 使用状态压缩DP解决,状态表示为 ,其中 表示当前点, 表示已访问点集合
- 表示当前在点 ,已经访问了 表示的点集合时的最小花费
- 通过二进制表示已访问的点集合
-
状态转移:
- 对于当前状态 ,枚举下一个要访问的点
代码
#include <bits/stdc++.h>
using namespace std;
const int N = 205;
const int M = 16;
const int INF = 0x3f3f3f3f;
int n, m, R;
int dist[N][N];
int points[M];
long long dp[M][1 << M];
void floyd() {
for(int k = 1; k <= n; k++) {
for(int i = 1; i <= n; i++) {
for(int j = 1; j <= n; j++) {
if(dist[i][k] != INF && dist[k][j] != INF) {
dist[i][j] = min(dist[i][j], dist[i][k] + dist[k][j]);
}
}
}
}
}
int main() {
cin >> n >> m >> R;
for(int i = 0; i < R; i++) {
cin >> points[i];
}
// 初始化距离数组
memset(dist, 0x3f, sizeof(dist));
for(int i = 1; i <= n; i++) {
dist[i][i] = 0;
}
// 读入边
for(int i = 0; i < m; i++) {
int u, v, w;
cin >> u >> v >> w;
dist[u][v] = min(dist[u][v], w);
dist[v][u] = min(dist[v][u], w);
}
// Floyd求最短路
floyd();
// 状态压缩DP
memset(dp, 0x3f, sizeof(dp));
int ALL = (1 << R) - 1;
// 初始化:从任意一点开始
for(int i = 0; i < R; i++) {
dp[i][1 << i] = 0;
}
// 状态转移
for(int state = 0; state < (1 << R); state++) {
for(int i = 0; i < R; i++) {
if(!(state & (1 << i))) continue;
for(int j = 0; j < R; j++) {
if(state & (1 << j)) continue;
int next_state = state | (1 << j);
dp[j][next_state] = min(dp[j][next_state],
dp[i][state] + dist[points[i]][points[j]]);
}
}
}
// 找最小值
long long ans = INF;
for(int i = 0; i < R; i++) {
ans = min(ans, dp[i][ALL]);
}
cout << ans << endl;
return 0;
}
import java.util.*;
public class Main {
static final int N = 205;
static final int M = 16;
static final int INF = 0x3f3f3f3f;
static int n, m, R;
static int[][] dist = new int[N][N];
static int[] points = new int[M];
static long[][] dp = new long[M][1 << M];
static void floyd() {
for(int k = 1; k <= n; k++) {
for(int i = 1; i <= n; i++) {
for(int j = 1; j <= n; j++) {
if(dist[i][k] != INF && dist[k][j] != INF) {
dist[i][j] = Math.min(dist[i][j], dist[i][k] + dist[k][j]);
}
}
}
}
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
n = sc.nextInt();
m = sc.nextInt();
R = sc.nextInt();
// 读入需要访问的点
for(int i = 0; i < R; i++) {
points[i] = sc.nextInt();
}
// 初始化距离数组
for(int i = 1; i <= n; i++) {
Arrays.fill(dist[i], INF);
dist[i][i] = 0;
}
// 读入边
for(int i = 0; i < m; i++) {
int u = sc.nextInt();
int v = sc.nextInt();
int w = sc.nextInt();
dist[u][v] = Math.min(dist[u][v], w);
dist[v][u] = Math.min(dist[v][u], w);
}
// Floyd求最短路
floyd();
// 初始化dp数组
for(int i = 0; i < M; i++) {
Arrays.fill(dp[i], INF);
}
// 初始状态:从任意点开始
for(int i = 0; i < R; i++) {
dp[i][1 << i] = 0;
}
// 状态压缩DP
int ALL = (1 << R) - 1;
for(int state = 0; state < (1 << R); state++) {
for(int i = 0; i < R; i++) {
if((state & (1 << i)) == 0) continue; // i不在当前状态中
for(int j = 0; j < R; j++) {
if((state & (1 << j)) != 0) continue; // j已经访问过
int nextState = state | (1 << j);
dp[j][nextState] = Math.min(dp[j][nextState],
dp[i][state] + dist[points[i]][points[j]]);
}
}
}
// 找最小值
long ans = INF;
for(int i = 0; i < R; i++) {
ans = Math.min(ans, dp[i][ALL]);
}
System.out.println(ans);
sc.close();
}
}
import sys
input = sys.stdin.readline # 优化输入
def floyd(n, edges, dist):
"""Floyd算法求任意两点间最短路"""
for k in range(1, n+1):
for i in range(1, n+1):
for j in range(1, n+1):
if dist[i][k] != INF and dist[k][j] != INF:
dist[i][j] = min(dist[i][j], dist[i][k] + dist[k][j])
def main():
# 输入
n, m, R = map(int, input().split())
points = list(map(int, input().split()))
# 初始化距离数组
global INF
INF = float('inf')
dist = [[INF] * (n+1) for _ in range(n+1)]
for i in range(1, n+1):
dist[i][i] = 0
# 读入边
for _ in range(m):
u, v, w = map(int, input().split())
dist[u][v] = min(dist[u][v], w)
dist[v][u] = min(dist[v][u], w)
# Floyd求最短路
floyd(n, None, dist)
# 状态压缩DP
ALL = (1 << R) - 1
dp = [[INF] * (1 << R) for _ in range(R)]
# 初始化:从任意点开始
for i in range(R):
dp[i][1 << i] = 0
# 状态转移
for state in range(ALL + 1):
for i in range(R):
if not (state & (1 << i)): # i不在当前状态中
continue
cur = dp[i][state]
if cur == INF:
continue
for j in range(R):
if state & (1 << j): # j已经访问过
continue
next_state = state | (1 << j)
dp[j][next_state] = min(dp[j][next_state],
cur + dist[points[i]][points[j]])
# 找最小值
ans = min(dp[i][ALL] for i in range(R))
print(ans)
if __name__ == "__main__":
main()
算法及复杂度
- 算法:Floyd最短路 + 状态压缩DP
- 时间复杂度:
- Floyd算法:
- 状态压缩DP:
- 空间复杂度: