树的最小支配集问题

题目大意:图片说明
思路:对于任意一个点,只有三种被覆盖的情况
1.被自己覆盖
2.被自己的子节点覆盖
3.被自己的父节点覆盖
故很容易推出状态表示
dp[i][j]:以i为根的子树的全部节点被覆盖且i节点的覆盖状态为j的所有方案的节点最小值
(其中j = 0表示被自己覆盖,j = 1表示被儿子覆盖, j = 2表示被父亲覆盖)
转移方程
1.i被自己覆盖:
此时与i直接相连的子树可以来自三种状态的任意一种,取最优的一种即可
dp[i][0] = ∑(min(dp[son, 0], dp[son, 1], dp[son, 2])) + 1
2.i被父亲覆盖(容易遗漏的情况)
此时与i直接相连的子树的状态不可能来自父节点覆盖,只能由从余下两种转移
dp[i][2] = ∑(min(dp(son, 0], dp[son, 1])
3.i被儿子覆盖
次情况的儿子也不可能来自父节点覆盖,此时只需要保证所有与i直接相连的儿子节点中有一个是被其本身所覆盖,若在取儿子节点的最优解的过程中全部都是由孙子节点转移过来,则必须保证有一个来自儿子,具体做法为补一个最小的差值

对比树的最小点覆盖(https://ac.nowcoder.com/acm/problem/51222)

题解链接()

AC代码

```
#include <bits stdc++.h>
using namespace std;
#define IO std::ios::sync_with_stdio(false); cin.tie(0)
#define ll long long
#define ull unsigned long long
#define SZ(x) ((int)(x).size())
#define all(x) (x).begin(), (x).end()
#define rep(i, l, r) for (int i = l; i <= r; ++i)
#define per(i, l, r) for (int i = l; i >= r; --i)
#define mset(s, _) memset(s, _, sizeof(s))
#define mcpy(s, _) memcpy(s, _, sizeof(s))
#define pb push_back
#define pii pair <int, int>
#define vi vector<int>
#define vpii vector<pii>
#define mp(a, b) make_pair(a, b)
#define pll pair <ll, ll>
#define fir first
#define sec second
#define inf 0x3f3f3f3f
inline int lowbit(int x) {return x & -x;}
template< typename T > inline void get_min(T &x, T y) {if(y < x) x = y;}
template< typename T > inline void get_max(T &x, T y) {if(x < y) x = y;}
inline int read() {
int x = 0, f = 0; char ch = getchar();
while (!isdigit(ch)) f |= ch == '-', ch = getchar();
while (isdigit(ch)) x = 10 * x + ch - '0', ch = getchar();
return f ? -x : x;
}
template<typename t=""> inline void print(T x) {
if (x < 0) putchar('-'), x = -x;
if (x >= 10) print(x / 10);
putchar(x % 10 + '0');
}
template<typename t=""> inline void print(T x, char let) {
print(x), putchar(let);
}</typename></typename></pii></int>

const int N = 2e4 + 10, mod = 1e9 + 7;
int n, m;
int h[N], nex[N], v[N], idx;
void add(int a, int b) {
v[idx] = b; nex[idx] = h[a]; h[a] = idx ++ ;
}

int f[N][5];
void dp(int u, int fa) {
bool fg = 0, mk = 0;
int gap = inf;
for(int i = h[u]; ~i; i = nex[i]) {
int e = v[i];
if(e == fa) continue;
fg = 1;
dp(e, u);
f[u][0] += min(f[e][0], min(f[e][1], f[e][2]));
f[u][2] += min(f[e][0], f[e][1]);
f[u][1] += min(f[e][0], f[e][1]);
gap = min(gap, abs(f[e][0] - f[e][1]));
if(f[e][0] <= f[e][1]) mk = 1;
}
f[u][0] += 1;
if(!mk) f[u][1] += gap;
if(!fg) {
f[u][1] = inf; f[u][0] = 1; f[u][2] = 0;
}
}

int main() {
mset(h, -1);
cin >> n;
rep(i, 1, n - 1) {
int a, b; cin >> a >> b;
add(a, b); add(b, a);
}

dp(1, -1);
cout << min(f[1][0], f[1][1]) << endl; 

return 0;

}

```</ll,></int,>