百度之星初赛第三场B题-最短路2(魔改dijstra)
题目:
小 A 是社团里的工具人,有一天他的朋友给了他一个 n 个点, m 条边的正权连通无向图,要他计算所有点两两之间的最短路。
作为一个工具人,小 A 熟练掌握着 floyd 算法,设 w[i][j]*为原图中$ (i,j)$ 之间的权值最小的边的权值,若没有边则 w[i][j]=无穷大。特别地,若 i=j,则 w[i][j]=0
Floyd 的 C++ 实现如下:
for(int k=1;k<=p;k++)
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
w[i][j]=min(w[i][j],w[i][k]+w[k][j]);
当$ p=n$时,该代码就是我们所熟知的 floyd,然而小 A 为了让代码跑的更快点,所以想减少 p 的值。
令 Di,j为最小的非负整数 x,满足当$ p=x$时,点 i与点 j之间的最短路被正确计算了。
现在你需要求 ∑i=1n∑j=1nDi,j,虽然答案不会很大,但为了显得本题像个计数题,你还是需要将答案对 998244353取模后输出。
Input
第一行一个正整数 T(T≤30) 表示数据组数
对于每组数据:
第一行两个正整数 n,m(1≤n≤1000,m≤2000),表示点数和边数。
保证最多只有 55 组数据满足 max(n,m)>200接下来m行,每行三个正整数 u,v,w描述一条边权为 w的边 (u,v)其中 1≤w≤109输出:
Output
输出 T行,第 i 行一个非负整数表示第 i 组数据的答案
思路:
根据Floyd算法原理可以知道 D[i][j]的值就是 i 到 j 的最短路径中除端点外的最大顶点编号。这里暴力很是不行的。
我们考虑可以用n次dijstra算法,对于第 i次,我们求 i到其他点的最短路。在松弛顶点u的临接边中,对于u的每一条邻接顶点v,如果 dis[v]>dis[u]+w(u,v),就更新 dis[v],另外更新 D[i][v]=max(D[i][u],u)即可。
但是可能求 D[i][j]过程中可能 i到 j 有多条最短路,但不同最短路的最大顶点编号不同,此时我们肯定要取最小的。所以还额外需要加个等最短路的处理。如果 dis[v]==dis[v]+w(u,v) and D[i][v]>max(D[i][u],u).
令 D[i][v]=max(D[i][u],u),并将新的状态 (dis[v],v)加入最小堆。时间复杂度 O(n∗m∗logn)
PS:
注意开long long,INF也要long long 的INF啊喂
比赛中两个小时才做出来,首先没开long long,其次INF不是long long ,另外对于多条最短路没有考虑。比赛要在细心点啊喂
给个数据:
1
5 5
1 4 1
4 5 2
1 2 1
2 3 1
3 5 1
ans: 22
代码:
#include<bits/stdc++.h>
#define mset(a,b) memset(a,b,sizeof(a))
using namespace std;
typedef long long ll;
typedef pair<ll,ll> P;
const ll N=1e3+10;
const ll mod=998244353;
vector<P> g[N];
ll p[N][N];
ll dis[N];
void debug(){
while(true);
}
void dijstr(ll s,ll n)
{
priority_queue<P,vector<P>,greater<P> > Q;
for(int i=1;i<=n;++i) dis[i]=1e18;
dis[s]=0;
for(P &e:g[s])
{
ll v=e.first,w=e.second;
if(dis[v]> dis[s]+w)
{
dis[v]=dis[s]+w;
Q.push({dis[v],v});
}
}
while(!Q.empty())
{
P pp=Q.top();
Q.pop();
if(pp.first>dis[pp.second]) continue;
ll u=pp.second;
for(P &e:g[u])
{
ll v=e.first,w=e.second;
if(dis[v] > dis[u]+w)
{
dis[v]=dis[u]+w;
p[s][v]=max(p[s][u],u);
Q.push({dis[v],v});
}
if((dis[v] ==dis[u]+w) &&p[s][v]>max(p[s][u],u))
{
p[s][v]=max(p[s][u],u);
Q.push({dis[v],v});
}
}
}
}
int main()
{
ios::sync_with_stdio(false);cin.tie(0);
ll t,n,m;
cin>>t;
while(t--)
{
cin>>n>>m;
for(ll i=1; i<=n; ++i) for(ll j=1; j<=n; ++j) p[i][j]=-1;
for(ll i=1; i<=n; ++i)
{
p[i][i]=0;
g[i].clear();
}
for(ll i=1; i<=m; ++i)
{
ll u,v,ww;
cin>>u>>v>>ww;
p[u][v]=p[v][u]=0;
g[u].push_back({v,ww});
g[v].push_back({u,ww});
}
for(ll i=1; i<=n; ++i)
dijstr(i,n);
ll sum=0;
for(ll i=1; i<=n; ++i)
{
for(ll j=1; j<=n; ++j){
sum=(p[i][j]+sum)%mod;
}
}
cout<<sum<<endl;
}
return 0;
}