Tree Partition

题意:

给一棵树,n个点n-1条边,没给点都有点权。
要求剪掉k-1条边形成由k颗树组成的森林。
树的权值为中所有点权之和
问:怎样剪使得树的权值的最大值最小, 大小为多少?

解题:

看到“最大值最小”这类词, 先是想到了二分,二分出权值,再根据权值对树进行剪边操作,觉得可行就继续操作:

  1. 二分出权值mid;
  2. 在树中由下向上进行操作,每个节点记录节点及其儿子权值和,如果当前节点权值>mid,让当前节点由大到小减去其儿子节点直到当前节点<=mid,没见一次计数(count)加1
  3. 如果操作完后count > k - 1 说明mid小 增大mid继续操作2, 否则减小mid,继续2,直到==mid.

让当前节点由大到小减去其儿子节点直到当前节点<=mid, 我刚看开始没有想到.(所以我觉得这需要注意一下 弱弱的说一下)

#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>

using namespace std;

const int maxn = 2e5 + 7;

int h[maxn], e[maxn << 1], ne[maxn << 1], w[maxn], cnt;

void add(int u, int v){
    e[cnt] = v;
    ne[cnt] = h[u];
    h[u] = cnt ++;
}
vector<long long> vv;
int counter, n, m;
long long sf[maxn];
bool flag;
void dfs(int u, int f, long long mid){
    sf[u] = w[u];
    if (sf[u] > mid || flag){
        flag = true;  // 算小小的剪枝吧
        return;
    }
    for (int i = h[u]; ~i; i = ne[i]){
        int v = e[i];
        if (v == f) continue;
        dfs(v, u, mid);
        sf[u] += sf[v]; 
    }
    if (sf[u] > mid){
        vv.clear();
        for (int i = h[u]; ~i; i = ne[i]){
            int v = e[i];
            if (v == f) continue;
            vv.push_back(sf[v]);
        }
        sort(vv.begin(), vv.end());
        while (sf[u] > mid){
            counter ++; sf[u] -= vv.back();vv.pop_back();
        }
    }
    if (counter >= m) flag = true;
}
bool check(long long mid){
    flag = false;
    counter = 0;
    dfs(1, 1, mid);
    if (flag) return false;
    return true;
}
int main (){
    int T, tol = 1;
    scanf ("%d", &T);
    while (T -- ){
        scanf ("%d%d", &n, &m);
        for (int i = 0; i <= n; i ++ ) h[i] = -1, sf[i] = 0;
        cnt = 0;
        for (int i = 1, u, v; i < n; i ++ ){
            scanf ("%d%d", &u, &v);
            add(u, v);
            add(v, u);
        }

        for (int i = 1; i <= n; i ++ ) scanf ("%d", &w[i]);
        long long l = 0, r = 1e18;
        while (l < r){
            long long mid = l + r >> 1;
            if (check(mid)) r = mid;
            else l = mid + 1;
        }
        printf ("Case #%d: %lld\n", tol++, l);
    }
}