这题太变态了吧。
分析
我们预处理出每个帮派的lca节点,当帮派合并的时候,我们就可以求各个帮派的lca的节点。
假设首都节点是u,各个帮派的lca是pos,分两种情况
当lca(u,pos) != pos的时候,说明u不在pos的子树下,答案就是dis(u,pos)。
当lca(u,pos) == pos的时候,说明u在pos的子树下,首先我们明确一点,对于帮派控制的节点,它与u的lca一定是pos的子节点,我们对每个联盟的帮派求离u最近的dfs序的节点,答案就是dis(u,lca(u,v))。
求最近的dfs序可以利用二分。
#include <bits/stdc++.h> using namespace std; #define mem(a,b) memset(a,b,sizeof(a)) #define pii pair<int,int> #define int long long const int inf = 0x3f3f3f3f; const int maxn = 501110; const int M = 1e9+7; int n,k,q,t = 20; int read() { int x=0,f=1; char c=getchar(); while(c<'0'||c>'9') {if(c=='-') f=-1;c=getchar();} while(c>='0'&&c<='9') x=(x<<1)+(x<<3)+c-'0',c=getchar(); return f*x; } void print(int x) { if(x < 0) {putchar('-');x = -x;} if(x/10) print(x/10); putchar(x%10+'0'); } int head[maxn],to[maxn*2],Next[maxn*2],cnt = 2; void add(int u,int v) { to[cnt] = v;Next[cnt] = head[u];head[u] = cnt;cnt++; } vector<int> v[maxn]; //帮派 int fa[maxn][25],d[maxn],top[maxn]; int dfn[maxn],tim; //dfs序列,第几个访问到的 void dfs(int u,int pre) { dfn[u] = ++tim; d[u] = d[pre]+1; fa[u][0] = pre; for(int i = 1; (1ll<<i) <= d[u]; i++) { fa[u][i] = fa[fa[u][i-1]][i-1]; } for(int i = head[u]; i ; i = Next[i]) { int v = to[i]; if(v == pre) continue; dfs(v,u); } } int lca(int x,int y) { if(d[x] > d[y]) swap(x,y); for(int i = t; i >= 0; i--) { if(d[fa[y][i]] >= d[x]) y = fa[y][i]; } if(x == y) return x; for(int i = t; i >= 0; i--) { if(fa[y][i] != fa[x][i]) { y = fa[y][i]; x = fa[x][i]; } } return fa[x][0]; } int dist(int x,int y) { return d[x]+d[y]-2*d[lca(x,y)]; } bool cmp(int x,int y) { return dfn[x] < dfn[y]; } signed main() { n = read(); for(int i = 1,x,y; i < n; i++) { x = read(), y = read(); add(x,y); add(y,x); } dfs(1,0); k = read(); for(int i = 1,sz; i <= k; i++) { sz = read(); for(int j = 1,x; j <= sz; j++) { x = read(); v[i].push_back(x); if(top[i] == 0) top[i] = x; else top[i] = lca(top[i],x); } sort(v[i].begin(),v[i].end(),cmp); } q = read(); vector<int> vt; for(int i = 1,u,sz; i <= q; i++) { u = read(),sz = read(); int pos = 0; vt.clear(); for(int j = 1,x; j <= sz; j++) { x = read(); vt.push_back(x); if(pos == 0) pos = top[x]; else pos = lca(pos,top[x]); } int ff = lca(u,pos); int ans = inf; if(ff != pos) //不在子树 { ans = dist(u,pos); } else //在子树里面 { for(auto node : vt) { int l = 0,r = v[node].size()-1; pos = r; while(l <= r) { int mid = (l+r)/2; if(dfn[v[node][mid]] >= dfn[u]) { pos = mid; r = mid-1; } else l = mid+1; } ans = min(ans,dist(u,lca(u,v[node][pos]))); if(pos) ans = min(ans,dist(u,lca(u,v[node][pos-1]))); } } print(ans);putchar('\n'); } return 0; }