对于每次询问,我们建一棵“虚树”
这棵虚树只包含根节点 $1$、所有询问点,以及按 dfs 序相邻的询问点之间的 lca(这些 lca 就足以保持原树中的祖先关系);虚树上每条边的边权取原树对应路径上的最小边权。
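这样虚树至多有 $2k$ 个点($k$ 为本次询问的点数),在虚树上 dp 的复杂度为 $O(k)$;
建虚树需要排序以及 $O(k)$ 次倍增求 lca / 路径最小值,每次 $O(\log n)$,再加上 $O(n\log n)$ 的倍增预处理,总复杂度为 $O\left(n\log n + \sum k\log n\right)$。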
代码如下:
#include <bits/stdc++.h>
// Virtual-tree ("auxiliary tree") DP. For every query we build a small tree on
// the query vertices, vertex 1 and the needed LCAs, then print dp[1]: the
// minimum total weight of tree edges to delete so that no marked (query)
// vertex remains connected to vertex 1.
typedef long long ll;
const int N = 250010;
using namespace std;

int n, q;

// Adjacency lists stored as a linked edge pool. The same pool (E, h, cc) first
// holds the input tree (read only by dfs) and is then cleared and reused for
// each query's virtual tree.
struct edge { int e, w, nxt; } E[N << 1];
int cc, h[N];

// Append directed edge u -> v with weight w to u's list.
void add(int u, int v, int w) {
    E[cc].e = v;
    E[cc].w = w;
    E[cc].nxt = h[u];
    h[u] = cc;
    ++cc;
}

// dfn[u]    : preorder index of u.
// fa[u][i]  : 2^i-th ancestor of u (0 once we go above the root).
// dep[u]    : depth, with dep[root] = 1 and dep[0] = 0.
// mn[u][i]  : minimum edge weight on the 2^i edges going up from u;
//             mn[v][0] is the weight of edge (v, parent of v).
int dfn[N], id, fa[N][20], dep[N], mn[N][20];

// Explicit stack used by dfs and Dp (they never run at the same time), and the
// visiting order recorded by Dp. Each vertex is pushed at most once, so N
// entries suffice.
int work[N], seq[N];

// Fill dfn/fa/dep/mn for the tree rooted at u, whose parent is pre.
// Iterative so that a path-shaped tree with ~2.5e5 vertices cannot overflow
// the call stack. Popping from the stack yields a valid DFS preorder, so every
// subtree still occupies a contiguous dfn range (build relies on this).
void dfs(int u, int pre) {
    int sp = 0;
    fa[u][0] = pre;
    work[sp++] = u;
    while (sp) {
        int x = work[--sp];
        int p = fa[x][0];
        dfn[x] = ++id;
        dep[x] = dep[p] + 1;
        // Ancestors were popped earlier, so their tables are already filled.
        for (int i = 1; i <= 19; ++i) {
            fa[x][i] = fa[fa[x][i - 1]][i - 1];
            mn[x][i] = min(mn[x][i - 1], mn[fa[x][i - 1]][i - 1]);
        }
        for (int i = h[x]; ~i; i = E[i].nxt) {
            int v = E[i].e;
            if (v == p) continue;
            fa[v][0] = x;
            mn[v][0] = E[i].w;
            work[sp++] = v;
        }
    }
}

// Lowest common ancestor of u and v via binary lifting.
int lca(int u, int v) {
    if (dep[u] > dep[v]) swap(u, v);
    for (int i = 19; i >= 0; --i)
        if (dep[fa[v][i]] >= dep[u]) v = fa[v][i];
    if (u == v) return v;
    for (int i = 19; i >= 0; --i)
        if (fa[u][i] != fa[v][i]) u = fa[u][i], v = fa[v][i];
    return fa[u][0];
}

// Minimum edge weight on the path between u and v, where one of them must be
// an ancestor of the other (callers always pass such pairs). Returns INT_MAX
// when u == v (empty path).
int calmn(int u, int v) {
    int ans = INT_MAX;
    if (dep[u] > dep[v]) swap(u, v);
    for (int i = 19; i >= 0; --i)
        if (dep[fa[v][i]] >= dep[u]) ans = min(ans, mn[v][i]), v = fa[v][i];
    return ans;
}

// Order vertices by preorder index.
bool cmp(int u, int v) { return dfn[u] < dfn[v]; }

// a[1..m]: query vertices of the current query; mk[v] = 1 iff v is queried.
// stc[1..top]: the current root-to-vertex chain while building; stc[0] stays 0
// (dfn[0] == 0) and acts as a sentinel smaller than every real dfn.
int a[N], mk[N], stc[N], top;

// Build the virtual tree of vertex 1 plus a[1..m] into (E, h), sorting a[1..m]
// by dfn in place. Each virtual edge stores the minimum original edge weight
// on the path it replaces.
void build(int a[], int m) {
    sort(a + 1, a + 1 + m, cmp);
    stc[top = 1] = 1;
    cc = 0;
    h[1] = -1;
    for (int i = 1; i <= m; ++i) {
        // The root is already on the stack. A repeated vertex (adjacent after
        // sorting) would be pushed twice and create a self-loop edge, making
        // Dp loop forever, so skip it.
        if (a[i] == 1 || (i > 1 && a[i] == a[i - 1])) continue;
        int p = lca(stc[top], a[i]);
        if (p != stc[top]) {
            // Pop every chain vertex strictly below p, linking it to its
            // predecessor on the chain.
            while (dfn[p] < dfn[stc[top - 1]]) {
                add(stc[top - 1], stc[top], calmn(stc[top - 1], stc[top]));
                --top;
            }
            if (dfn[p] > dfn[stc[top - 1]]) {
                // p is a new branching vertex: it replaces the popped top.
                h[p] = -1;
                add(p, stc[top], calmn(p, stc[top]));
                stc[top] = p;
            } else {
                // p is already on the chain, right below the top.
                add(p, stc[top], calmn(p, stc[top]));
                --top;
            }
        }
        h[a[i]] = -1;
        stc[++top] = a[i];
    }
    // Link whatever chain remains.
    for (int i = 1; i < top; ++i)
        add(stc[i], stc[i + 1], calmn(stc[i], stc[i + 1]));
}

// dp[u]: minimum cost to separate every marked vertex below u (in the virtual
// tree) from u.
ll dp[N];

// Compute dp over the virtual tree rooted at u. A marked child must be cut off
// by its own edge; otherwise we pay the cheaper of cutting the edge or solving
// inside the child's subtree. Iterative: record a preorder, then process it in
// reverse so children are finished before their parent.
void Dp(int u) {
    int sp = 0, cnt = 0;
    work[sp++] = u;
    while (sp) {
        int x = work[--sp];
        seq[cnt++] = x;
        for (int i = h[x]; ~i; i = E[i].nxt) work[sp++] = E[i].e;
    }
    while (cnt) {
        int x = seq[--cnt];
        dp[x] = 0;
        for (int i = h[x]; ~i; i = E[i].nxt) {
            int v = E[i].e;
            if (mk[v]) dp[x] += E[i].w;
            else dp[x] += min(1ll * E[i].w, dp[v]);
        }
    }
}

// Input: n, then n-1 weighted edges; q, then each query as m followed by m
// vertices. Output: one answer per query.
int main() {
    scanf("%d", &n);
    memset(h, -1, sizeof(h));
    for (int u, v, w, i = 1; i < n; ++i) {
        scanf("%d%d%d", &u, &v, &w);
        add(u, v, w), add(v, u, w);
    }
    dfs(1, 0);
    scanf("%d", &q);
    while (q--) {
        int m;
        scanf("%d", &m);
        for (int i = 1; i <= m; ++i) scanf("%d", &a[i]), mk[a[i]] = 1;
        build(a, m);
        Dp(1);
        printf("%lld\n", dp[1]);
        // build only permuted a[1..m], so this clears exactly the marks set above.
        for (int i = 1; i <= m; ++i) mk[a[i]] = 0;
    }
    return 0;
}