这个题目不需要题解视频里三遍bfs那么麻烦,实际上只需要一次bfs或者两次最短路就可以。
1.判断某点是否在起点和终点的简单路径上:只需要求出起点和终点之间的距离,起点到该点的距离,终点到该点的距离。将后两者相加看是否等于前者就可以了。
2.求其他点到简单路径的最短距离:只需要将该点到起点和终点距离相加,减去起点和终点之间的距离,之后再除以2就可以了。
以下是我的参考代码: #include<bits/stdc++.h>
using namespace std;
const int maxn=1e6+103;
int n,m,s,dis[maxn],head[maxn],vis[maxn],cnt,ss,diss[maxn],headd[maxn],viss[maxn],cnnt,flag[maxn],ans;
priority_queue<pair<int,int> , vector<pair<int,int> >,greater<pair<int,int> > >q;
priority_queue<pair<int,int> , vector<pair<int,int> >,greater<pair<int,int> > >qq;
int qdd,zdd;
struct edge{
int to,nx,w;//dian bian quanzhi
}edge[maxn];
struct edgee{
int to,nx,w;//dian bian quanzhi
}edgee[maxn];
void add(int u , int v , int w)
{
edge[++cnt].to=v;
edge[cnt].nx=head[u];
head[u]=cnt;
edge[cnt].w=w;
}
void addd(int u , int v , int w)
{
edgee[++cnnt].to=v;
edgee[cnnt].nx=headd[u];
headd[u]=cnnt;
edgee[cnnt].w=w;
}
void dij()
{
for(int i=1;i<=n;i++) dis[i]=2147483644;
dis[s]=0;
q.push(make_pair(0,s));
while(q.size())
{
int zd=q.top().second;
q.pop();
if(vis[zd]==1) continue;
vis[zd]=1;
for(int j=head[zd] ; j ; j=edge[j].nx)
{
int y=edge[j].to ;int z=edge[j].w;
if(dis[zd]+z<dis[y])
{
dis[y]=dis[zd]+z;
q.push(make_pair(dis[y] , y));
}
}
}
}
void dijj()
{
for(int i=1;i<=n;i++) diss[i]=2147483644;
diss[ss]=0;
qq.push(make_pair(0,ss));
while(qq.size())
{
int zdd=qq.top().second;
qq.pop();
if(viss[zdd]==1) continue;
viss[zdd]=1;
for(int j=headd[zdd] ; j ; j=edgee[j].nx)
{
int y=edgee[j].to ;int z=edgee[j].w;
if(diss[zdd]+z<diss[y])
{
diss[y]=diss[zdd]+z;
qq.push(make_pair(diss[y] , y));
}
}
}
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
int qu;
cin>>n>>qu;
for(int i=1;i<n;i++)
{
int u,v;
cin>>u>>v;
add(u,v,1);
add(v,u,1);
addd(u,v,1);
addd(v,u,1);
}
cin>>s>>ss;
dij();
dijj();
for(int i=1;i<=n;i++)
{
if(dis[i]+diss[i]==dis[ss])
{
flag[i]=1;
}
}
for(int i=1;i<=n;i++)
if(flag[i]!=1)
ans+=(dis[i]+diss[i]-dis[ss])/2;
cout<<ans;
return 0;
}