题意:
有一颗n个节点的树,现在有m个询问,每个询问给与三个节点,求这三个节点在哪个节点集合时总距离最小?输出集合节点和总距离。

思路:
我们知道在树上求任意两个节点可以用lca求取,求三个节点怎么转换成求两个节点呢,我们画一颗树然后任找三个点可以发现两两之间的lca的值要么三个相等,此时该点为集合点,要么就是有两点相等,一个点不同,此时该不同点为集合点时总距离最短,可参考https://blog.nowcoder.net/n/0424b84aa132430f81321b11fe63e804 这位大佬的。集合点知道了,直接用lca分别三点求出与集合点的距离再加起来。

代码:

#include<bits/stdc++.h>
#define ll long long
using namespace std;

int read()
{
    int x=0, g=1;
    char c=getchar();
    while(c<'0'||c>'9')
    {
        if(c=='-')
        {
            g=-1;
        }
        c=getchar();
    }
    while(c<='9'&&c>='0')
    {
        x=(x<<3)+(x<<1)+(c^48);
        c=getchar();
    }
    return x*g;
}

vector<int> g[500005];
int dep[1000005], vs[1000005], ji=0, id[1000005], st[22][1000005];

void dfs(int v,int d,int f)
{
    dep[v]=d;
    id[v]=ji;
    vs[ji]=v;
    ji++;
    for(int i=0; i<g[v].size(); i++)
    {
        if(f!=g[v][i])
        {
            dfs(g[v][i],d+1,v);
            dep[v]=d;
            vs[ji]=v;
            ji++;
        }
    }
}

int lca(int u,int v)
{
    int x=id[u], y=id[v];
    if(x>y)
    {
        swap(x,y);
    }
    int d=y-x;
    if(d==0)
    {
        return u;
    }
    int mi, k=(int)log2(d);
    if(dep[st[k][x]]<dep[st[k][y-(1<<k)+1]])
    {
        mi=st[k][x];
    }
    else
    {
        mi=st[k][y-(1<<k)+1];
    }
    return mi;
}

int main()
{
    int n, m;
    n=read();
    m=read();
    for(int i=0; i<n-1; i++)
    {
        int u=read(), v=read();
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs(1,0,-1);
    for(int i=0;i<22;i++)
    {
        if((1<<i)>ji)
        {
            break;
        }
        if(i==0)
        {
            for(int j=0;j<ji;j++)
            {
                st[i][j]=vs[j];
            }
        }
        else
        {
            for(int j=0;j<ji-(1<<i)+1;j++)
            {
                if(dep[st[i-1][j]]<dep[st[i-1][j+(1<<(i-1))]])
                {
                    st[i][j]=st[i-1][j];
                }
                else
                {
                    st[i][j]=st[i-1][j+(1<<(i-1))];
                }
            }
        }
    }
    while(m--)
    {
        int a=read(), b=read(), c=read(), v;
        int pa=lca(a,b), pb=lca(b,c), pc=lca(a,c);
        if(pa==pb&&pb==pc)
        {
            v=pa;
        }
        else
        {
            v=pa^pb^pc;
        }
        int sum=dep[a]+dep[v]-dep[lca(a,v)]*2;
        sum+=dep[b]+dep[v]-dep[lca(b,v)]*2;
        sum+=dep[c]+dep[v]-dep[lca(c,v)]*2;
        printf("%d %d\n",v,sum);
    }
    return 0;
}