http://acm.hdu.edu.cn/showproblem.php?pid=4547

C++版本一

题解:

用树上倍增求 LCA(最近公共祖先)。

注意:向上只能逐级返回,每上一层算一步;向下可以一步直接到达任意后代目录。所以从 A 到 B 的答案是 dep[A]-dep[lca],若 lca 不是 B 再加 1。

/*
*@Author:   STZG
*@Language: C++
*/
#include <bits/stdc++.h>
#include<iostream>
#include<algorithm>
#include<cstdlib>
#include<cstring>
#include<cstdio>
#include<string>
#include<vector>
#include<bitset>
#include<queue>
#include<deque>
#include<stack>
#include<cmath>
#include<list>
#include<map>
#include<set>
//#define DEBUG
// Shorthand for `register int`; not used anywhere in this program.
#define RI register int
// Every later `endl` token expands to the string "\n", so `cout<<endl`
// writes a plain newline instead of using the std::endl manipulator.
#define endl "\n"
using namespace std;
typedef long long ll;
//typedef __int128 lll;
// Capacity of the per-node arrays below.
const int N=100000+10;
// M, MOD, PI, EXP and INF are template constants; this program does not use them.
const int M=100000+10;
const int MOD=1e9+7;
const double PI = acos(-1.0);
const double EXP = 1E-8;
const int INF = 0x3f3f3f3f;
// t: remaining test cases, n: node count, m: query count (all read in main);
// u/v: ids of the two names on the current input line. The rest are unused here.
int t,n,m,k,p,l,r,u,v,c;
// tot: last id handed out to a directory name; cnt/num are only reset in init().
int ans,cnt,flag,temp,tot,sum,num;
// pre[i]: disjoint-set parent of node i (see find/marge).
int pre[N];
// vis[i]: node i has already been reached by dfs().
bool vis[N];
// dep[i]: depth assigned by dfs(); a node dfs() is started from keeps depth 0.
int dep[N];
// dp[i][j]: binary-lifting table. dfs() stores the dfs parent in dp[i][0] and
// main() fills dp[i][j+1]=dp[dp[i][j]][j] for j=0..19.
int dp[N][21];
// Scratch buffers for the two names read from each input line.
string str,str1;
// Plain pair of integer fields; `x` is a file-scope instance of it.
struct node {
    int v;
    int w;
};
node x;
// Maps a directory name to its id; ids start at 1, a stored 0 means "no id yet".
map<string,int>mp;
// Adjacency lists of the tree, indexed by node id.
vector<int>G[N];
// Disjoint-set lookup: returns the representative of x's set and re-points
// every node visited on the way directly at it (path compression).
int find(int x){
    if(pre[x]==x){
        return x;
    }
    const int root=find(pre[x]);
    pre[x]=root;
    return root;
}
// Union step of the disjoint set: hangs the representative of u's set under
// the representative of v's set; does nothing when both are already joined.
void marge(int u,int v){
    const int rootU=find(u);
    const int rootV=find(v);
    if(rootU==rootV){
        return;
    }
    pre[rootU]=rootV;
}
void dfs(int u){
    vis[u]=1;
    for(int i=0,j=G[u].size();i<j;i++){
            int v=G[u][i];
        if(!vis[v]){
            dep[v]=dep[u]+1;
            dp[v][0]=u;
            dfs(v);
        }
    }
}
// Resets all per-test-case state. Must run after n has been read, because
// the adjacency lists and disjoint-set parents are reset for ids 1..n only.
void init(){
    mp.clear();
    tot=0;
    cnt=0;
    num=0;
    memset(dp,0,sizeof(dp));
    memset(dep,0,sizeof(dep));
    memset(vis,0,sizeof(vis));
    for(int id=1;id<=n;++id){
        pre[id]=id;      // every node starts as its own set
        G[id].clear();
    }
}
// Lowest common ancestor of x and y using the binary-lifting table dp.
//   x, y    : node ids (id 0 is the "no such node" id; its table row is all 0)
//   returns : id of the deepest node that is an ancestor of both x and y
// Only integer bit operations are used: taking the level with floating-point
// log2() is undefined when its argument is 0, which happens for two distinct
// ids that both have depth 0.
int LCA(int x,int y){
    if(dep[x]<dep[y])
        swap(x,y);
    // Lift the deeper node by the binary decomposition of the depth gap.
    int gap=dep[x]-dep[y];
    for(int bit=0;gap!=0;++bit,gap>>=1){
        if(gap&1)
            x=dp[x][bit];
    }
    if(x==y)
        return x;
    // Highest level stored in dp. Levels above a node's depth hold 0 for both
    // nodes, compare equal and are therefore skipped.
    const int top=static_cast<int>(sizeof(dp[0])/sizeof(dp[0][0]))-1;
    for(int bit=top;bit>=0;--bit){
        if(dp[x][bit]!=dp[y][bit]){
            x=dp[x][bit];
            y=dp[y][bit];
        }
    }
    // x and y are now distinct children of the answer.
    return dp[x][0];
}
// Entry point. For every test case: reads the tree, runs dfs from the
// disjoint-set representative(s), builds the binary-lifting table and prints
// one answer per "A B" query.
int main()
{
#ifdef DEBUG
	freopen("input.in", "r", stdin);
	//freopen("output.out", "w", stdout);
#endif
    scanf("%d",&t);
    for(;t--;){
        scanf("%d%d",&n,&m);
        init();
        // n-1 lines with two names each: hand out ids on first sight, link the
        // two nodes in both directions and join their disjoint sets.
        for(int line=1;line<n;++line){
            cin>>str>>str1;
            int& firstId=mp[str];
            if(firstId==0)
                firstId=++tot;
            int& secondId=mp[str1];
            if(secondId==0)
                secondId=++tot;
            u=firstId;
            v=secondId;
            G[u].push_back(v);
            G[v].push_back(u);
            marge(u,v);
        }
        // Every node that is still its own disjoint-set parent starts a dfs.
        for(int node=1;node<=n;++node){
            if(pre[node]!=node)
                continue;
            dfs(node);
        }
        // Ancestor 2^(level+1) above = ancestor 2^level above the 2^level one.
        for(int level=0;level<20;++level){
            for(int node=1;node<=n;++node){
                const int half=dp[node][level];
                dp[node][level+1]=dp[half][level];
            }
        }
        // One step per level climbed from u to the LCA, plus a single step
        // down unless the LCA already is the target v.
        for(int ask=1;ask<=m;++ask){
            cin>>str>>str1;
            u=mp[str];
            v=mp[str1];
            const int lca=LCA(u,v);
            const int climb=dep[u]-dep[lca];
            if(lca==v){
                cout<<climb<<endl;
            }else{
                cout<<climb+1<<endl;
            }
        }
    }

#ifdef DEBUG
	printf("Time cost : %lf s\n",(double)clock()/CLOCKS_PER_SEC);
#endif
    return 0;
}

C++版本二

题解:LCA 转 RMQ(欧拉序 + ST 表)

这份代码是卡着内存限制通过的(ST 表 dp 只开了 18 层)。

/*
*@Author:   STZG
*@Language: C++
*/
#include <bits/stdc++.h>
#include<iostream>
#include<algorithm>
#include<cstdlib>
#include<cstring>
#include<cstdio>
#include<string>
#include<vector>
#include<bitset>
#include<queue>
#include<deque>
#include<stack>
#include<cmath>
#include<list>
#include<map>
#include<set>
//#define DEBUG
// Shorthand for `register int`; not used anywhere in this program.
#define RI register int
// Every later `endl` token expands to the string "\n", so `cout<<endl`
// writes a plain newline instead of using the std::endl manipulator.
#define endl "\n"
using namespace std;
typedef long long ll;
//typedef __int128 lll;
// Capacity of the per-node arrays; tour-indexed arrays use N<<1.
const int N=100000+10;
// M, MOD, PI, EXP and INF are template constants; this program does not use them.
const int M=100000+10;
const int MOD=1e9+7;
const double PI = acos(-1.0);
const double EXP = 1E-8;
const int INF = 0x3f3f3f3f;
// t: remaining test cases, n: node count, m: query count (all read in main);
// u/v: ids of the two names on the current input line.
int t,n,m,k,p,l,r,u,v,c;
// tot: last id handed out to a name; num: length of the tour written by dfs().
int ans,cnt,flag,temp,tot,sum,num;
// pre[i]: disjoint-set parent of node i (see find/marge).
int pre[N];
// vis[i]: node i has already been entered by dfs().
bool vis[N];
// dep[p]: depth of the node stored at tour position p (indexed by position,
// not by node id).
int dep[N<<1];
// dp[p][j]: sparse table over tour positions, built by ST() and read by RMQ().
int dp[N<<1][18];
// que[p]: node id stored at tour position p.
int que[N<<1];
// frist[u]: tour position written when dfs() enters node u.
int frist[N];
// Scratch buffers for the two names read from each input line.
string str,str1;
// Maps a directory name to its id; ids start at 1, a stored 0 means "no id yet".
map<string,int>mp;
// G[p]: ids pushed for node p by main() (the first name of each line is
// appended to the list of the second).
vector<int>G[N];
// Returns the representative of the disjoint set containing x, compressing
// the path from x to that representative on the way back.
int find(int x){
    const int parent=pre[x];
    if(parent==x)
        return x;
    return pre[x]=find(parent);
}
// Joins the set of u into the set of v: the representative of u's set gets
// the representative of v's set as its parent (no change if they coincide).
void marge(int u,int v){
    const int from=find(u);
    const int into=find(v);
    if(from!=into)
        pre[from]=into;
}
// Writes the tour of the subtree of u into que/dep.
//   u    : node being entered
//   deep : depth to record for u
// u is appended on entry and again after each child returns; frist[u] keeps
// the position of the entry record.
void dfs(int u,int deep){
    vis[u]=true;
    ++num;
    que[num]=u;
    dep[num]=deep;
    frist[u]=num;
    const int degree=G[u].size();
    for(int idx=0;idx<degree;++idx){
        const int child=G[u][idx];
        if(vis[child])
            continue;
        dfs(child,deep+1);
        // back in u after finishing the child
        ++num;
        que[num]=u;
        dep[num]=deep;
    }
}
// Resets all per-test-case state. Must run after n has been read, because
// the adjacency lists and disjoint-set parents are reset for ids 1..n only.
void init(){
    mp.clear();
    tot=0;
    cnt=0;
    num=0;
    memset(que,0,sizeof(que));
    memset(frist,0,sizeof(frist));
    memset(dp,0,sizeof(dp));
    memset(dep,0,sizeof(dep));
    memset(vis,0,sizeof(vis));
    for(int id=1;id<=n;++id){
        pre[id]=id;      // every node starts as its own set
        G[id].clear();
    }
}
// Builds the sparse table over tour positions 1..n.
//   n : number of tour positions (shadows the global node count)
// After the call dp[j][k] is the position with the smallest dep[] inside the
// window [j, j+2^k-1], for every window that lies completely inside [1, n].
// On equal depth the right half wins, matching RMQ().
// Only complete windows are filled. Filling partial ones would index level
// i+1 past the last column of dp once n exceeds 2^17, which lands on
// dp[j+1][0] and breaks the length-1 lookups RMQ() performs.
void ST(int n){
    for(int i=1;i<=n;i++){
        dp[i][0]=i;
    }
    // Number of levels dp can hold; guards the i+1 index below.
    const int levels=static_cast<int>(sizeof(dp[0])/sizeof(dp[0][0]));
    for(int i=0;i+1<levels&&(1<<(i+1))<=n;i++){
        const int half=1<<i;
        for(int j=1;j+2*half-1<=n;j++){
            const int a=dp[j][i];
            const int b=dp[j+half][i];
            dp[j][i+1]=dep[a]<dep[b]?a:b;
        }
    }
}
// Range-minimum query on the sparse table.
//   l, r    : tour positions with l <= r
//   returns : a position in [l, r] whose dep[] is minimal; on a tie between
//             the two covering windows the right one is returned
int RMQ(int l,int r){
    const int level=log2(r-l+1);
    const int span=1<<level;
    const int leftBest=dp[l][level];
    const int rightBest=dp[r-span+1][level];
    if(dep[leftBest]<dep[rightBest]){
        return leftBest;
    }
    return rightBest;
}

int LCA(int u,int v){
    int x=frist[u],y=frist[v];
    if(x>y)swap(x,y);
    int res=RMQ(x,y);
    return que[res];
}
// Entry point. For every test case: reads the tree, writes the tour with
// dfs(), builds the sparse table over it and prints one answer per query.
int main()
{
#ifdef DEBUG
	freopen("input.in", "r", stdin);
	//freopen("output.out", "w", stdout);
#endif
    scanf("%d",&t);
    for(;t--;){
        scanf("%d%d",&n,&m);
        init();
        // n-1 lines with two names each: hand out ids on first sight, append
        // the first node to the list of the second and join their sets.
        for(int line=1;line<n;++line){
            cin>>str>>str1;
            int& firstId=mp[str];
            if(firstId==0)
                firstId=++tot;
            int& secondId=mp[str1];
            if(secondId==0)
                secondId=++tot;
            u=firstId;
            v=secondId;
            G[v].push_back(u);
            marge(u,v);
        }

        // Every node that is still its own disjoint-set parent starts a tour.
        for(int node=1;node<=n;++node){
            if(pre[node]!=node)
                continue;
            dfs(node,0);
        }
        ST(2*n-1);
        // Depths are stored per tour position, so look them up through frist[].
        for(int ask=1;ask<=m;++ask){
            cin>>str>>str1;
            u=mp[str];
            v=mp[str1];
            const int lca=LCA(u,v);
            const int climb=dep[frist[u]]-dep[frist[lca]];
            if(lca==v){
                cout<<climb<<endl;
            }else{
                cout<<climb+1<<endl;
            }
        }
    }

#ifdef DEBUG
	printf("Time cost : %lf s\n",(double)clock()/CLOCKS_PER_SEC);
#endif
    return 0;
}

C++版本三

题解:Tarjan 离线 LCA 算法(链式前向星存树)

#include<bits/stdc++.h>
using namespace std;
// Capacity of every fixed-size array in this program (nodes, edges, queries).
#define MAXN 100010
// Not referenced anywhere in this program.
#define MAXL 45

// One entry of the linked adjacency list stored in e[]: `to` is the edge's
// endpoint, `nxt` the index in e[] of the next edge of the same source
// (0 terminates the list).
struct Edge
{
        int to;
        int nxt;
};
Edge e[MAXN];

// i: loop counter shared by the loops in main(); T: remaining test cases;
// tot: number of edges stored in e[]; id: last id handed out to a name;
// n/m: node and query counts; root: the node without a parent.
int i,T,tot,id,n,m,root;
// head[u]: index in e[] of u's first outgoing edge (0 = none).
// dep[u]:  depth computed by init().
// f[u]:    disjoint-set parent used by tarjan()/find().
// lca[q]:  answer of tarjan() for query number q.
// fa[u]:   parent id read from the input (0 = no parent).
int head[MAXN],dep[MAXN],f[MAXN],lca[MAXN],fa[MAXN];
// Scratch buffers for the two names of an edge line.
string a,b;
// x[q], y[q]: the two names of query q, kept until the answers are printed.
string x[MAXN],y[MAXN];
// visited[u]: tarjan() has entered node u.
bool visited[MAXN];
// Maps a directory name to its id; ids start at 1, a stored 0 means "no id yet".
map<string,int> mp;
// query[u]: (other endpoint, query number) for every query that involves u.
vector< pair<int,int> > query[MAXN];

// Appends the directed edge u -> v to the front of u's adjacency list.
//   u : source node id
//   v : target node id
// The fields are assigned one by one: the C-style compound literal
// `(Edge){...}` is a compiler extension in C++, not standard syntax.
inline void add(int u,int v)
{
        tot++;
        e[tot].to = v;
        e[tot].nxt = head[u];   // previous first edge becomes the successor
        head[u] = tot;
}
// Fills dep[] for every node reachable from u through the adjacency lists:
// each edge target gets the depth of its source plus one. dep[u] itself must
// be set by the caller.
inline void init(int u)
{
        int edge = head[u];
        while (edge != 0)
        {
                const int child = e[edge].to;
                dep[child] = dep[u] + 1;
                init(child);
                edge = e[edge].nxt;
        }
}
// Disjoint-set lookup over f[] with path compression: returns the
// representative of x and stores it as x's direct parent.
inline int find(int x)
{
        if (f[x] != x)
        {
                f[x] = find(f[x]);
        }
        return f[x];
}
inline void tarjan(int u)
{
        int i,v,pos;
        visited[u] = true;
        f[u] = u;
        for (i = 0; i < query[u].size(); i++)
        {
              v = query[u][i].first;
              pos = query[u][i].second;
              if (visited[v]) lca[pos] = find(v);
        }
        for (i = head[u]; i; i = e[i].nxt)
        {
                v = e[i].to;
                tarjan(v);
                f[v] = u;
        }
}

// Entry point. For every test case: reads the "child parent" lines, computes
// depths from the parentless node, answers all queries offline with tarjan()
// and prints one line per query.
int main()
{
        ios::sync_with_stdio(false);
        cin.tie(0);

        cin >> T;
        for (; T--; )
        {
                cin >> n >> m;
                id = 0;
                tot = 0;
                mp.clear();
                for (i = 1; i <= n; i++)
                {
                        query[i].clear();
                        visited[i] = false;
                        fa[i] = 0;
                        head[i] = 0;
                }
                // Each line names a node and its parent; ids are handed out on first sight.
                for (i = 1; i < n; i++)
                {
                        cin >> a >> b;
                        int& childId = mp[a];
                        if (childId == 0) childId = ++id;
                        int& parentId = mp[b];
                        if (parentId == 0) parentId = ++id;
                        add(parentId, childId);
                        fa[childId] = parentId;
                }
                // The last node without a recorded parent is used as the root.
                for (i = 1; i <= n; i++)
                {
                        if (fa[i] == 0)
                                root = i;
                }
                dep[root] = 0;
                init(root);
                // Register every query at both of its endpoints.
                for (i = 1; i <= m; i++)
                {
                        cin >> x[i] >> y[i];
                        const int from = mp[x[i]];
                        const int to = mp[y[i]];
                        query[from].push_back(make_pair(to, i));
                        query[to].push_back(make_pair(from, i));
                }
                tarjan(root);
                for (i = 1; i <= m; i++)
                {
                        const int from = mp[x[i]];
                        const int to = mp[y[i]];
                        int steps;
                        if (x[i] == y[i])
                                steps = 0;                              // already there
                        else if (lca[i] == from)
                                steps = 1;                              // target is below: one step down
                        else if (lca[i] == to)
                                steps = dep[from] - dep[to];            // target is an ancestor: climb only
                        else
                                steps = dep[from] - dep[lca[i]] + 1;    // climb to the LCA, then one step down
                        printf("%d\n", steps);
                }
        }

        return 0;

}

C++版本四

题解:Tarjan 离线 LCA 算法(并查集 + vector 邻接表)

/*
*@Author:   STZG
*@Language: C++
*/
#include <bits/stdc++.h>
#include<iostream>
#include<algorithm>
#include<cstdlib>
#include<cstring>
#include<cstdio>
#include<string>
#include<vector>
#include<bitset>
#include<queue>
#include<deque>
#include<stack>
#include<cmath>
#include<list>
#include<map>
#include<set>
//#define DEBUG
// Shorthand for `register int`; not used anywhere in this program.
#define RI register int
// Every later `endl` token expands to the string "\n", so `cout<<endl`
// writes a plain newline instead of using the std::endl manipulator.
#define endl "\n"
using namespace std;
typedef long long ll;
//typedef __int128 lll;
// Capacity of the arrays below.
const int N=200000+10;
// M, MOD, PI, EXP and INF are template constants; this program does not use them.
const int M=100000+10;
const int MOD=1e9+7;
const double PI = acos(-1.0);
const double EXP = 1E-8;
const int INF = 0x3f3f3f3f;
// t: remaining test cases, n: node count, m: query count (all read in main);
// u/v: ids of the two names on the current input line.
int t,n,m,k,q,w,v,u;
// tot: last id handed out to a directory name.
int ans,cnt,flag,temp,sum,tot;
// pre[i]: disjoint-set parent of node i (see find/marge).
int pre[N];
// vis[i]: Tarjan() has entered node i.
bool vis[N];
// dep[i]: depth assigned by Tarjan(); the node it is started from keeps depth 0.
int dep[N];
// ANS[q]: answer of query number q (1-based); init() fills it with 0x3f bytes.
int ANS[N];
// One stored half of a query: `v` is the other endpoint and `id` encodes the
// query number and direction (2*i-1 / 2*i, see main and Tarjan). `u` is never
// assigned in this program.
struct query{
    int u;
    int v;
    int id;
};
// e is not referenced by the functions of this program (Tarjan declares its
// own local `e`); y is the scratch record main() fills before each push_back.
query e[N];
query y;
// Scratch buffers for the two names read from each input line.
string str,str1;
// Maps a directory name to its id; ids start at 1, a stored 0 means "no id yet".
map<string,int>mp;
// Adjacency lists of the tree, indexed by node id.
vector<int>G[N];
// Q[u]: the stored query halves whose local endpoint is u.
vector<query>Q[N];
// Disjoint-set lookup with path compression: returns the representative of
// x's set and makes it the direct parent of x.
int find(int x){
    if(pre[x]!=x){
        pre[x]=find(pre[x]);
    }
    return pre[x];
}
// Disjoint-set union: the representative of u's set is attached below the
// representative of v's set, unless both are the same.
void marge(int u,int v){
    const int child=find(u);
    const int parent=find(v);
    if(child==parent){
        return;
    }
    pre[child]=parent;
}
void Tarjan(int u){//marge和find为并查集合并函数和查找函数
    //cout<<u<<endl;
    vis[u]=1;
    for(int i=0,j=G[u].size();i<j;i++) {   //访问所有u子节点v
        int v=G[u][i];
        if(vis[v])
            continue;
        dep[v]=dep[u]+1;
        Tarjan(v);        //继续往下遍历
        marge(v,u);   //合并v到u上
        //cout<<dep[v]<<endl;
    }
    for(int i=0,j=Q[u].size();i<j;i++){ //访问所有和u有询问关系的e
        int e=Q[u][i].v;
        int id=Q[u][i].id;
        int ID=(id+1)/2;
        if(vis[e]){
            int lca=find(e);
            //cout<<lca<<endl;
            if(id%2){
                if(lca!=e){
                    ANS[ID]=dep[u]-dep[lca]+1;
                }else{
                    ANS[ID]=dep[u]-dep[lca];
                }
            }else{
                if(lca!=u){
                    ANS[ID]=dep[e]-dep[lca]+1;
                }else{
                    ANS[ID]=dep[e]-dep[lca];
                }
            }
        }
    }
}
// Resets all per-test-case state. Must run after n has been read, because
// the per-node containers and disjoint-set parents are reset for ids 1..n only.
void init(){
    mp.clear();
    tot=0;
    memset(ANS,0x3f,sizeof(ANS));
    memset(dep,0,sizeof(dep));
    memset(vis,0,sizeof(vis));
    for(int id=1;id<=n;++id){
        Q[id].clear();
        G[id].clear();
        pre[id]=id;      // every node starts as its own set
    }
}
// Entry point. For every test case: reads the tree and all queries, picks
// the disjoint-set representative as root, runs the offline Tarjan pass and
// prints one answer per query.
int main()
{
#ifdef DEBUG
	freopen("input.in", "r", stdin);
	//freopen("output.out", "w", stdout);
#endif
    scanf("%d",&t);
    for(;t--;){
        scanf("%d%d",&n,&m);
        init();
        // n-1 lines with two names each: hand out ids on first sight, link the
        // two nodes in both directions and join their disjoint sets.
        for(int line=1;line<n;++line){
            cin>>str>>str1;
            int& firstId=mp[str];
            if(firstId==0)
                firstId=++tot;
            int& secondId=mp[str1];
            if(secondId==0)
                secondId=++tot;
            u=firstId;
            v=secondId;
            G[u].push_back(v);
            G[v].push_back(u);
            marge(u,v);
        }
        // Queries may introduce names never seen in an edge line (the n==1
        // case), so ids are handed out here as well. Every query is stored at
        // both endpoints: odd id at the source, even id at the target.
        for(int ask=1;ask<=m;++ask){
            cin>>str>>str1;
            int& fromId=mp[str];
            if(fromId==0)
                fromId=++tot;
            int& toId=mp[str1];
            if(toId==0)
                toId=++tot;
            u=fromId;
            v=toId;
            y.id=2*ask-1;
            y.v=v;
            Q[u].push_back(y);
            y.id=2*ask;
            y.v=u;
            Q[v].push_back(y);
        }
        // First node that is still its own disjoint-set parent becomes the root.
        int root=0;
        for(int node=1;node<=n&&root==0;++node){
            if(pre[node]==node)
                root=node;
        }
        // Tarjan() reuses the disjoint set, so start it from singletons again.
        for(int node=1;node<=n;++node){
            pre[node]=node;
        }
        Tarjan(root);
        for(int ask=1;ask<=m;++ask){
            cout<<ANS[ask]<<endl;
        }
    }

#ifdef DEBUG
	printf("Time cost : %lf s\n",(double)clock()/CLOCKS_PER_SEC);
#endif
    return 0;
}