题意
有一个 个结点的树,每条边有边权,结点度数就是与之相连的边数量。对于 ,删掉一些边使每个结点的度数不大于 ,求出删掉的边的权值和最小值。
分析
简化问题
求解 时的问题。考虑 ,定义 为,结点 为根子树满足 ,且与父亲的边是否断掉的最小代价。那么对于一个度数大于 的节点,如果不断父亲的边,我们就要选择 个儿子断掉,如果断父亲的边则选择 个儿子。那么贪心的考虑哪些儿子是最优的。我们 的儿子的值放在一个平衡树中,对于小于 的,直接加 。然后贪心选取前几个就好了。时间复杂度为 。
原问题
考虑由小到大枚举 。如果一个节点的度数已经小于或者等于 ,那么这个节点再也不会在后面的决策中考虑到了。那么只需要在去掉这个节点的森林中每个联通块做一次 。那么每个节点考虑的次数为 ,而 ,对于一个联通块的 一定不会从一个节点转移到一个度数小于等于 的节点,那么可以对边和节点度数排序降低复杂度。所以加上平衡树的复杂度,总的复杂度为 。
代码
#include<bits/stdc++.h> using namespace std; #define LL long long #define P pair<LL,LL> #define It multiset<LL>::iterator const int N = 250005; LL ans,sum[N],f[N][2]; vector<P> G[N<<1]; P d[N]; multiset<LL> H[N]; void insert(int x,LL val) { H[x].insert(val); } void erase(LL x,LL val) { It it = H[x].lower_bound(val);H[x].erase(it); } bool vis[N]; LL Max,du[N],n; void die(LL x) { for(LL i = 0;i < G[x].size();i++) { LL y = G[x][i].second; if(du[y] <= Max) break; insert(y,G[x][i].first);sum[y] += G[x][i].first; } } bool cmp(P a,P b){ return du[a.second] > du[b.second]; } vector<LL> ins,del; void dfs(int x,int fa) { vis[x] = 1; LL m = du[x] - Max; while(H[x].size() && H[x].size() > m) { It it = H[x].end();--it;sum[x]-= *it;H[x].erase(it); } for(LL i = 0;i < G[x].size();i++) { LL y = G[x][i].second; if(y == fa) continue; if(du[y] <= Max) break; dfs(y,x); } ins.clear();del.clear(); LL Ans = 0; for(int i = 0;i < G[x].size();i++) { int y = G[x][i].second; if(y == fa) continue; if(du[y] <= Max) break; LL val = f[y][1] + G[x][i].first - f[y][0]; if(val <= 0) {Ans += f[y][1] + G[x][i].first;m--;continue;} insert(x,val);del.push_back(val); Ans += f[y][0];sum[x] += val; } while(H[x].size() && H[x].size() > m) { It it = H[x].end();--it; sum[x] -= *it,ins.push_back(*it),H[x].erase(it); } f[x][0] = Ans + sum[x]; while(H[x].size() && H[x].size() > m-1) { It it = H[x].end();--it; sum[x] -= *it,ins.push_back(*it),H[x].erase(it); } f[x][1] = Ans + sum[x]; for(LL i = 0;i < ins.size();i++) insert(x,ins[i]),sum[x] += ins[i]; for(LL i = 0;i < del.size();i++) erase(x,del[i]),sum[x] -= del[i]; } int main() { scanf("%lld",&n); for(int i = 1,a,b,c;i < n;i++) { scanf("%lld%lld%lld",&a,&b,&c);ans += c; G[a].push_back(P(c,b)); G[b].push_back(P(c,a)); du[a]++;du[b]++; } for(int i = 1;i <= n;i++) { d[i] = P(du[i],i); sort(G[i].begin(),G[i].end(),cmp); } sort(d+1,d+1+n); printf("%lld ",ans); LL pos = 1; for(Max = 1;Max < n;Max++) { while(pos <= n && d[pos].first <= Max) die(d[pos].second),++pos; ans = 0;memset(vis,0,sizeof(vis)); for(int j = pos;j <= n;j++) { if(vis[d[j].second]) continue; dfs(d[j].second,0); ans += f[d[j].second][0]; } printf("%lld ",ans); } printf("\n"); return 0; }