首先,题目有个条件:n个节点,n-1条路径,所以,这是一棵树。之后,会再在u,v之间加入一条边权为0的边,让我们求图上任意两点的最短路。
若我们忽略新加入的uv边,那直接求lca求距离就行了。但是,现在加入了uv边,但是,我们可以想到,如果真正的最短路径不经过uv边,那么答案依旧还是lca求得的结果,因此,我们可以分两种情况:经过uv边和不经过uv边。
不经过uv边上面已经说过了,若经过uv边,肯定经过u,v中的某一个节点,则,我们可以把这个节点(u和v任选一个)作为中转节点(这里假设选了u),那么,很显然,在经过u的情况下的最短路径对dis[x]+disy,因此,我们可以以u点为原点跑一次单源点最短路,以此求得在经过u点情况下的最短路径。
之后,让通过树上距离求得的距离和最短路跑出来的距离取个min就可以了。
类似题目:https://ac.nowcoder.com/acm/problem/19814,都是先处理树上距离,再处理另外边,只是处理方式不一样。
AC代码
#include <iostream> #include <map> #include <ctime> #include <vector> #include <climits> #include <algorithm> #include <random> #include <cstring> #include <cstdio> #include <map> #include <set> #include <bitset> #include <queue> #define inf 0x3f3f3f3f #define IOS ios_base::sync_with_stdio(0); cin.tie(0); #define rep(i, a, n) for(register int i = a; i <= n; ++ i) #define per(i, a, n) for(register int i = n; i >= a; -- i) #define ONLINE_JUDGE using namespace std; typedef long long ll; const int mod=1e9+7; template<typename T>void write(T x) { if(x<0) { putchar('-'); x=-x; } if(x>9) { write(x/10); } putchar(x%10+'0'); } template<typename T> void read(T &x) { x = 0;char ch = getchar();ll f = 1; while(!isdigit(ch)){if(ch == '-')f*=-1;ch=getchar();} while(isdigit(ch)){x = x*10+ch-48;ch=getchar();}x*=f; } int gcd(int a,int b){return b==0?a:gcd(b,a%b);} int lcm(int a,int b){return a/gcd(a,b)*b;}; ll ksm(ll a,ll n){//看是否要mod ll ans=1; while(n){ if(n&1) ans=((ans%mod)*(a%mod))%mod; a=((a%mod)*(a%mod))%mod; n>>=1; } return ans%mod; } //============================================================== #define int ll const int maxn=3e5+10; struct Edge{ int to,next,w; }e[maxn<<1]; int n; int cnt,head[maxn]; void add(int x,int y,int w){ e[cnt].to=y; e[cnt].next=head[x]; e[cnt].w=w; head[x]=cnt++; } int par[maxn][20],lg[maxn],depth[maxn]; void dfs(int u,int fa,int de){ par[u][0]=fa; depth[u]=de; rep(i,1,lg[de]-1){ par[u][i]=par[par[u][i-1]][i-1]; } for(int i=head[u];~i;i=e[i].next){ int v=e[i].to; if(v==fa) continue; dfs(v,u,de+1); } } int lca(int a,int b){ if(depth[a]<depth[b]) swap(a,b); while(depth[a]>depth[b]) a=par[a][lg[depth[a]-depth[b]]-1]; if(a==b) return a; for(int i=lg[depth[a]]-1;i>=0;i--){ if(par[a][i]!=par[b][i]){ a=par[a][i]; b=par[b][i]; } } return par[a][0]; } int dis[maxn],vis[maxn]; struct node{ int u,dis; node(int _u,int _d){ u=_u; dis=_d; } }; auto cmp=[](node&a,node&b){ return a.dis>b.dis; }; priority_queue<node,vector<node>,decltype(cmp)> que(cmp); void dij(int s){ memset(dis,inf,sizeof(dis)); que.push(node(s,0)); dis[s]=0; while(!que.empty()){ node a=que.top();que.pop(); int u=a.u; if(vis[u]) continue; vis[u]=1; for(int i=head[u];~i;i=e[i].next){ int v=e[i].to; if(dis[v]>dis[u]+e[i].w){ dis[v]=dis[u]+e[i].w; que.push(node(v,dis[v])); } } } } signed main() { #ifndef ONLINE_JUDGE freopen("in.txt","r",stdin); freopen("out.txt","w",stdout); #endif //=========================================================== rep(i,1,maxn-1) lg[i]=lg[i-1]+((1<<lg[i-1])==i); memset(head,-1,sizeof(head)); read(n); rep(i,2,n){ int a,b;read(a),read(b); add(a,b,1),add(b,a,1); } dfs(1,-1,0); int u,v;read(u),read(v); add(u,v,0);add(v,u,0); dij(u); int q;read(q); //rep(i,1,n) cerr<<dis[i]<<" "; //cerr<<endl; while(q--){ int x,y;read(x),read(y); int res1=depth[x]+depth[y]-2*depth[lca(x,y)]; int res2=dis[x]+dis[y]; write(min(res1,res2)),putchar('\n'); } //=========================================================== return 0; }