Solution
题目需要我们求解收益的最大值,那么我们就要拆分每一条边都来考虑一下,因为给出的是一棵树,所以对于每条边他会被计算的次数就是在它下方有几个黑节点乘上它上方有几个黑节点,再加上它下方有几个白节点乘上它上方有几个白节点。这样我们通过枚举根节点它的子树会存在几个黑节点进行动态规划处理。
我们使用代表节点以及它的全部子树中选择个黑点的最大收益。那么就变成了我一共有个大小的背包,最后要填满这个背包,每个子树都可以选择一定大小的物品进行填入,就转变成了一个树形背包了。
那么如果我们枚举某个节点它和它的子树中一共存在个节点,那么我们再枚举它的第一棵子树假设是节点,并且我们可以枚举中会包含全部可能的黑点数量,那么这样就可以对这条边全部的出现次数进行统计,并且更新数组。
那么这里我们给出转移方程:
上面的式子中代表这颗子树包含节点自己的节点个数。计算的就是这条边当前状态带来的贡献。
并且对于每个节点而言,我们的是允许等于的,那么这样的话,如果你直接倒序写子树的话,那么它每次都会有一次这样的计算,那就是可以发现这个转移是一定会发生的。那么在之前,会用,接下来又用这个子树没节点更新了,所以就会出现重复计算,对着之前更新的重新又加。
所以存在两种办法解决,第一种是枚举子树的时候先处理掉子树中全部都是白色节点的情况。第二种方法是直接使用正序的枚举去更新。
#include <bits/stdc++.h> using namespace std; #define js ios::sync_with_stdio(false);cin.tie(0); cout.tie(0) #define all(__vv__) (__vv__).begin(), (__vv__).end() #define endl "\n" #define pai pair<int, int> #define ms(__x__,__val__) memset(__x__, __val__, sizeof(__x__)) #define rep(i, sta, en) for(int i=sta; i<=en; ++i) #define repp(i, sta, en) for(int i=sta; i>=en; --i) typedef long long ll; typedef unsigned long long ull; typedef long double ld; inline ll read() { ll s = 0, w = 1; char ch = getchar(); for (; !isdigit(ch); ch = getchar()) if (ch == '-') w = -1; for (; isdigit(ch); ch = getchar()) s = (s << 1) + (s << 3) + (ch ^ 48); return s * w; } inline void print(ll x, int op = 10) { if (!x) { putchar('0'); if (op) putchar(op); return; } char F[40]; ll tmp = x > 0 ? x : -x; if (x < 0)putchar('-'); int cnt = 0; while (tmp > 0) { F[cnt++] = tmp % 10 + '0'; tmp /= 10; } while (cnt > 0)putchar(F[--cnt]); if (op) putchar(op); } inline ll gcd(ll x, ll y) { return y ? gcd(y, x % y) : x; } ll qpow(ll a, ll b) { ll ans = 1; while (b) { if (b & 1) ans *= a; b >>= 1; a *= a; } return ans; } ll qpow(ll a, ll b, ll mod) { ll ans = 1; while (b) { if (b & 1)(ans *= a) %= mod; b >>= 1; (a *= a) %= mod; }return ans % mod; } const int dir[][2] = { {0,1},{1,0},{0,-1},{-1,0},{1,1},{1,-1},{-1,1},{-1,-1} }; const int MOD = 1e9 + 7; const int INF = 0x3f3f3f3f; struct Node { ll val; int id; bool operator < (const Node& opt) const { return val < opt.val; } }; const int N = 2e3 + 7; int n, m, sz[N]; vector<pai> edge[N]; ll f[N][N >> 1]; void dfs(int u, int fa) { sz[u] = 1; f[u][0] = f[u][1] = 0; for (auto& it : edge[u]) { int v = it.first, w = it.second; if (v == fa) continue; dfs(v, u); sz[u] += sz[v]; repp(i, min(sz[u], m), 0) { if (f[u][i] != -1) f[u][i] += f[v][0] + 1ll * sz[v] * (n - m - sz[v]) * w; // 这个子树全部选白点,必须先处理 repp(j, min(i, sz[v]), 1) { // 注意边界 if (f[u][i - j] == -1) continue; // 前面的子树一共选i-j个 ll tmp = 1ll * (j * (m - j) + (sz[v] - j) * (n - sz[v] - (m - j))) * w; // 子树黑*上面黑 + 子树白*上面白 f[u][i] = max(f[u][i], f[u][i - j] + f[v][j] + tmp); } } } } void solve() { ms(f, -1); n = read(), m = read(); if (n - m < m) m = n - m; rep(i, 2, n) { int u = read(), v = read(), w = read(); edge[u].push_back({ v,w }); edge[v].push_back({ u,w }); } dfs(1, 1); print(f[1][m]); } int main() { //int T = read(); while (T--) solve(); return 0; }