思路:在树上找2点的最短距离,很容易想到LCA,那么我们在树上找a,b的最短距离,因为缆车的存在所以有3种找法,一种是直接从a节点到b,第二种是a先到缆车x加上b到缆车y的距离,第三种就是a到缆车y加上b到缆车x的距离,我们取个最小值(~ 用C++11会TLE ~)
#include <cstdio> #include <cstring> #include <algorithm> #include <set> #include<iostream> #include<vector> #include<queue> //#include<bits/stdc++.h> using namespace std; typedef long long ll; #define SIS std::ios::sync_with_stdio(false) #define space putchar(' ') #define enter putchar('\n') #define lson root<<1 #define rson root<<1|1 typedef pair<int,int> PII; const int mod=1e9+7; const int N=2e6+10; const int M=1e5+10; const int inf=0x7f7f7f7f; const int maxx=2e5+7; ll gcd(ll a,ll b) { return b==0?a:gcd(b,a%b); } ll lcm(ll a,ll b) { return a*(b/gcd(a,b)); } template <class T> void read(T &x) { char c; bool op = 0; while(c = getchar(), c < '0' || c > '9') if(c == '-') op = 1; x = c - '0'; while(c = getchar(), c >= '0' && c <= '9') x = x * 10 + c - '0'; if(op) x = -x; } template <class T> void write(T x) { if(x < 0) x = -x, putchar('-'); if(x >= 10) write(x / 10); putchar('0' + x % 10); } ll qsm(int a,int b,int p) { ll res=1%p; while(b) { if(b&1) res=res*a%p; a=1ll*a*a%p; b>>=1; } return res; } const int MX=3e5+10; struct node { int ch,dis; }; vector<node> G[MX*2]; int rdis[MX]; int fa[22][MX],dep[MX]; void dfs(int u,int f,int dis) { fa[0][u]=f; rdis[u]=dis; dep[u]=dep[f]+1; for(int i=1;i<=20;i++) { fa[i][u]=fa[i-1][fa[i-1][u]]; } int len=G[u].size(); for(int i=0;i<len;i++) { int v=G[u][i].ch; if(v==f)continue; fa[0][v]=u; dfs(v,u,dis+G[u][i].dis); } } void init(int n) { memset(fa,0,sizeof(fa)); dep[0]=0; dfs(1,0,0); } int lca(int u,int v) { if(dep[u]>dep[v]){ swap(u,v); } for(int k=0;k<=20;k++) { if((dep[v]-dep[u])>>k&1){ v=fa[k][v]; } } if(u==v)return u; for(int k=20;k>=0;k--){ if(fa[k][u]!=fa[k][v]){ u=fa[k][u]; v=fa[k][v]; } } return fa[0][u]; } int getdis(int u,int v) { return rdis[u]+rdis[v]-rdis[lca(u,v)]*2; } int main() { int t,n,m; read(n); //for(int i=0;i<=n;i++) G[i].clear(); int i,j,k; for(int c=1;c<=n-1;c++){ read(i),read(j); G[i].push_back({j,1}); G[j].push_back({i,1}); } int x,y; read(x),read(y); //G[x].push_back({y,0}); //G[y].push_back({x,0}); init(n); int a,b; read(m); while(m--){ read(a),read(b); int d1,d2,d3; d1=getdis(a,b); d2=getdis(a,x)+getdis(b,y); d1=min(d1,d2); d2=getdis(a,y)+getdis(b,x); d1=min(d1,d2); printf("%d\n",d1); } return 0; } /* 1 3 2 1 1 2 3 4 1 0 1 2 3 2 1 2 3 4 */