对于每次询问,我们建一棵“虚树”
这棵虚树只包括根节点 1、询问点，以及按 dfs 序相邻的询问点之间的 lca
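这样在虚树上 dp 的复杂度为 $O(k)$（$k$ 为本次询问的点数，虚树至多有 $2k$ 个点）；建虚树需要按 dfs 序排序并求 lca，复杂度为 $O(k\log n)$
总复杂度为 $O\left(n\log n + \sum k\log n\right)$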
代码如下:
#include <bits/stdc++.h>
typedef long long ll;
const int N = 250010;  // capacity for vertex ids (per-vertex arrays are indexed by vertex)
using namespace std;

int n, q;  // number of vertices, number of queries

// Adjacency lists stored as linked lists inside one edge array:
// h[u] is the index of the most recently added edge leaving u (-1 = none),
// and E[k].nxt links to the next edge leaving the same vertex.
struct edge {
    int e;    // vertex the edge leads to
    int w;    // edge weight
    int nxt;  // next edge index in the same list
} E[N << 1];
int cc;     // number of edges stored so far (next free slot in E)
int h[N];

// Pushes the directed edge u -> v with weight w onto the front of u's list.
void add(int u, int v, int w) {
    edge &slot = E[cc];
    slot.e = v;
    slot.w = w;
    slot.nxt = h[u];
    h[u] = cc++;
}
int dfn[N], id, fa[N][20], dep[N], mn[N][20];
void dfs(int u, int pre) {
dfn[u] = ++ id; fa[u][0] = pre; dep[u] = dep[pre] + 1;
for(int i = 1; i <= 19; ++i) {
fa[u][i] = fa[fa[u][i-1]][i-1];
mn[u][i] = min(mn[u][i-1], mn[fa[u][i-1]][i-1]);
}
for(int i = h[u]; ~i; i = E[i].nxt) {
int v = E[i].e, w = E[i].w;
if(v == pre) continue;
mn[v][0] = w; dfs(v, u);
}
}
// Lowest common ancestor of u and v by binary lifting over fa[][] and dep[].
int lca(int u, int v) {
    if (dep[v] < dep[u]) swap(u, v);  // make v the deeper vertex
    for (int k = 19; k >= 0; --k) {
        int up = fa[v][k];
        if (dep[up] >= dep[u]) v = up;  // lift v up to u's depth
    }
    if (v == u) return v;  // u was an ancestor of v
    for (int k = 19; k >= 0; --k) {
        if (fa[u][k] == fa[v][k]) continue;  // this jump would reach or pass the LCA
        u = fa[u][k];
        v = fa[v][k];
    }
    return fa[u][0];  // u and v are now distinct children of the LCA
}
// Minimum edge weight on the upward path from the deeper of (u, v) to the
// depth of the shallower one.  Intended for u being an ancestor of v (either
// argument order works); if they are not ancestor-related, the result covers
// only the deeper vertex's upward path, not the u-v path.
// Returns INT_MAX when both have the same depth (empty path).
int calmn(int u, int v) {
    if (dep[v] < dep[u]) swap(u, v);
    int best = INT_MAX;
    for (int k = 19; k >= 0; --k) {
        if (dep[fa[v][k]] < dep[u]) continue;  // jump would overshoot
        best = min(best, mn[v][k]);
        v = fa[v][k];
    }
    return best;
}
// Orders vertices by their DFS preorder timestamp.
bool cmp(int u, int v) {
    return dfn[v] > dfn[u];
}
// Query scratch: a[1..m] holds the query vertices, mk[] is set by the caller
// for query vertices, stc[1..top] is the stack used while building the
// virtual tree (stc[0] stays 0 and dfn[0] == 0 acts as a sentinel).
int a[N], mk[N], stc[N], top;

// Builds the virtual tree of the query vertices a[1..m], rooted at vertex 1.
// Precondition: dfs(1, 0) has filled dfn/fa/dep/mn.
// a[1..m] is sorted by dfn in place.  The stack holds vertices on the current
// root path; LCAs of consecutive vertices are inserted as branching vertices.
// Each virtual edge x -> y (x an ancestor of y) gets weight calmn(x, y), the
// minimum original edge weight on that path.
// The result is written into E/h: cc is reset to 0, so the original adjacency
// lists are overwritten (only fa/dep/mn/dfn are needed from here on), and h[]
// is reset to -1 only for vertices that enter the virtual tree.
void build(int a[], int m) {
    sort(a + 1, a + 1 + m, cmp);
    stc[top = 1] = 1; cc = 0; h[1] = -1;
    for (int i = 1; i <= m; ++i) {
        // The root is already on the stack.  A repeated vertex (adjacent after
        // the sort) would be pushed twice, producing a self-loop edge with
        // weight INT_MAX that makes the tree DP never terminate.
        if (a[i] == 1 || (i > 1 && a[i] == a[i - 1])) continue;
        int p = lca(stc[top], a[i]);
        if (p != stc[top]) {
            // Link and pop every stack vertex lying strictly below p.
            while (dfn[p] < dfn[stc[top - 1]]) {
                int w = calmn(stc[top - 1], stc[top]);
                add(stc[top - 1], stc[top], w);
                --top;
            }
            if (dfn[p] > dfn[stc[top - 1]]) {
                // p is a new branching vertex: hang stc[top] under it and
                // replace stc[top] with p.
                int w = calmn(p, stc[top]);
                h[p] = -1;
                add(p, stc[top], w);
                stc[top] = p;
            } else {
                // p == stc[top-1]: hang stc[top] under it and pop.
                int w = calmn(p, stc[top]);
                add(p, stc[top--], w);
            }
        }
        h[a[i]] = -1;
        stc[++top] = a[i];
    }
    // The remaining stack is a chain; link consecutive entries.
    for (int i = 1; i < top; ++i) {
        int w = calmn(stc[i], stc[i + 1]);
        add(stc[i], stc[i + 1], w);
    }
}
ll dp[N];
void Dp(int u) {
dp[u] = 0;
for(int i = h[u]; ~i; i = E[i].nxt) {
Dp(E[i].e);
if(mk[E[i].e]) dp[u] += E[i].w;
else dp[u] += min(1ll*E[i].w, dp[E[i].e]);
}
}
// Reads one query (m vertices), marks them, builds their virtual tree, prints
// dp[1], then clears only the marks it set (keeps each query O(m) to reset).
static void answer_query() {
    int m;
    scanf("%d", &m);
    for (int i = 1; i <= m; ++i) {
        scanf("%d", &a[i]);
        mk[a[i]] = 1;
    }
    build(a, m);
    Dp(1);
    printf("%lld\n", dp[1]);
    for (int i = 1; i <= m; ++i) mk[a[i]] = 0;
}

// Reads the tree (n vertices, n - 1 weighted edges), preprocesses it from
// root 1, then answers q queries.
int main() {
    scanf("%d", &n);
    memset(h, -1, sizeof(h));  // every adjacency list starts empty
    for (int i = 1; i < n; ++i) {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        add(u, v, w);  // undirected: store both directions
        add(v, u, w);
    }
    dfs(1, 0);
    scanf("%d", &q);
    while (q--) answer_query();
}


京公网安备 11010502036488号