树形dp

给定一棵有 N 个节点的树(无根树或有根树),可以任选一个节点作为根进行 dfs,求出树的结构;状态转移按由深到浅的顺序进行,即从叶子节点向根节点做 dp。

1 没有上司的舞会

第二维为 1 代表选取当前节点,为 0 代表不选取当前节点。对 fa 的每个儿子 son,状态转移为:

$$dp[fa][0] \mathrel{+}= \max(dp[son][1],\ dp[son][0])$$

$$dp[fa][1] \mathrel{+}= dp[son][0]$$

#include<bits/stdc++.h>
using namespace std;
// Capacity of every per-node table below (node ids are used 1-based).
const int maxn = 6000+10;
// dp[x][1]: total accumulated by dfs() for x's subtree when x itself is taken;
// dp[x][0]: the same when x is not taken.
int dp[maxn][2];
// G[b] lists every node a that was read on an input line "a b" (a hangs below b).
vector<int> G[maxn];
// deg[a] becomes true once a has appeared as the first number of an edge line;
// main() uses a node whose flag is still false as the root.
bool deg[maxn];
// v[x]: value read from input for node x.
int v[maxn];
// Post-order pass over the subtree hanging below `node`.
// On return dp[node][1] holds v[node] plus, for every child, the child's
// "not taken" total, and dp[node][0] has been increased by the better of the
// two totals of every child. `fa` is the node we came from; it is skipped.
void dfs(int node,int fa){
    int &taken = dp[node][1];
    int &skipped = dp[node][0];
    taken = v[node];
    const vector<int> &kids = G[node];
    for(size_t k = 0; k < kids.size(); ++k){
        const int child = kids[k];
        if(child == fa) continue;
        dfs(child,node);
        // node skipped: each child is free to pick its better option
        skipped += max(dp[child][1],dp[child][0]);
        // node taken: no child may be taken
        taken += dp[child][0];
    }
}
// Driver: reads N, the N node values, then N-1 lines "a b" (a is stored
// below b), runs the tree DP from a node that never appeared as `a`, and
// prints the better of the root's two totals.
int main(void){
    int N;
    cin>>N;

    int id = 1;
    while(id <= N){
        cin>>v[id];
        ++id;
    }

    for(int line = 1; line < N; ++line){
        int child,parent;
        cin>>child>>parent;
        G[parent].push_back(child);
        deg[child] = true;
    }

    // Highest-numbered node that never appeared as a child; falls back to 1.
    int root = 1;
    for(int cand = N; cand >= 1; --cand){
        if(!deg[cand]){
            root = cand;
            break;
        }
    }

    dfs(root,-1);
    const int without_root = dp[root][0];
    const int with_root = dp[root][1];
    cout<<max(without_root,with_root)<<endl;

    return 0;
}

2 树上背包

洛谷P2014 选课
把0号节点作为超级节点,将森林转化成树,求树上背包

#include<bits/stdc++.h>
using namespace std;
// Capacity of the per-node tables and of the knapsack dimension.
const int maxn = 300+10;
// dp[x][k]: knapsack value computed by dfs() for x's subtree with capacity k.
int dp[maxn][maxn];
// v[x]: value read from input for node x (v[0] stays 0 for the virtual root).
int v[maxn];
// G[u] lists the nodes i whose input line named u as their predecessor.
vector<int> G[maxn];
// Not referenced by the code visible in this section.
bool deg[maxn];
// n: number of real nodes; m: capacity (main() increments it by one after reading).
int n,m;
// Tree knapsack over the subtree of `node` (`fa` is the node we came from).
// Phase 1 merges every child's table into dp[node][*] as a grouped knapsack;
// phase 2 shifts the table by one slot and adds v[node], because `node`
// itself always occupies one unit of capacity.
void dfs(int node,int fa = -1){
    int *mine = dp[node];
    mine[0] = 0;

    const vector<int> &kids = G[node];
    for(size_t k = 0; k < kids.size(); ++k){
        const int child = kids[k];
        if(child == fa) continue;
        dfs(child,node);

        const int *theirs = dp[child];
        // Capacities go from high to low so smaller capacities still hold
        // the values from before this child was merged.
        for(int cap = m; cap >= 0; --cap){
            for(int give = cap; give >= 0; --give){
                const int cand = mine[cap-give] + theirs[give];
                if(cand > mine[cap]) mine[cap] = cand;
            }
        }
    }

    // Descending order again: mine[cap-1] must still be the pre-shift value.
    for(int cap = m; cap >= 1; --cap)
        mine[cap] = mine[cap-1] + v[node];
}
// Driver: reads n and m, then for each node i its predecessor u and value.
// Node 0 acts as a virtual root joining the forest, so the capacity is
// raised by one to pay for it before the knapsack runs.
int main(void){
    cin>>n>>m;
    m += 1;

    int id = 1;
    while(id <= n){
        int pre;
        cin>>pre>>v[id];
        G[pre].push_back(id);
        ++id;
    }

    dfs(0);

    int best = 0;
    for(int cap = 1; cap <= m; ++cap){
        if(dp[0][cap] > best) best = dp[0][cap];
    }
    cout<<best<<endl;
    return 0;
}

3 二次扫描和换根法

POJ Accumulation Degree


#include<bits/stdc++.h>
using namespace std;

// Capacity of every per-node table below.
const int maxn = 3e5+10;
// d[x]: value computed by dp() for x's subtree with the tree rooted at main()'s root.
// v[x]: visited marker shared by dp() and dfs(); main() clears it before each pass.
// f[x]: rerooted value computed by dfs(); main() prints the maximum f.
int d[maxn],v[maxn],f[maxn];
// Adjacency entry: (neighbour, edge weight).
typedef pair<int,int> P;
// Undirected adjacency lists; every edge is stored in both directions.
vector<P> G[maxn];
// deg[x]: number of edges incident to x; deg == 1 is used as the leaf test.
int deg[maxn];
void dp(int x){
    v[x] = 1;
    d[x] = 0;
    for(auto c: G[x]){
        int y = c.first;
        int w = c.second;
        if(v[y]) continue;
        dp(y);
        if(deg[y] == 1) d[x] += w;
        else d[x] += min(w,d[y]);
    }
}
// Second scan (rerooting): given f[x] for the current node, derives f[y] for
// every unvisited neighbour y and recurses.
//
// f[y] = d[y] + (what can go from y up through x into the rest of the tree):
//   * if x has degree 1, x is itself an endpoint, so only the edge limits it;
//   * otherwise the rest of the tree offers f[x] minus whatever y contributed
//     to f[x], capped by the edge weight w.
//
// Fix: dp() credits a degree-1 child with the full edge weight w, not with
// min(d[y], w) (which is 0 for a leaf). The amount removed from f[x] must be
// that same contribution, otherwise f[leaf] comes out as min(f[x], w) instead
// of min(f[x] - w, w).
void dfs(int x){
    v[x] = 1;
    for(auto c: G[x]){
        int y = c.first;
        int w = c.second;
        if(v[y]) continue;
        if(deg[x] == 1){
            f[y] = d[y]+w;
        }else{
            // Exactly what dp() added to d[x] on behalf of y.
            int contributed = (deg[y] == 1) ? w : min(d[y],w);
            f[y] = d[y]+min(f[x]-contributed,w);
        }
        dfs(y);
    }
}
#define Pb push_back
// Driver: for each of T test cases, rebuilds the weighted tree, runs the
// bottom-up scan from node 1, reroots with the second scan, and prints the
// largest f over all nodes.
int main(void){
    int T;
    cin>>T;
    while(T--){
        int n;
        cin>>n;

        // Reset per-test state.
        memset(d,0,sizeof(d));
        memset(f,0,sizeof(f));
        memset(deg,0,sizeof(deg));
        for(int node = 1; node <= n; ++node)
            G[node].clear();

        for(int e = 1; e < n; ++e){
            int a,b,cap;
            scanf("%d%d%d",&a,&b,&cap);
            G[a].push_back(P(b,cap));
            G[b].push_back(P(a,cap));
            ++deg[a];
            ++deg[b];
        }

        const int root = 1;
        memset(v,0,sizeof(v));
        dp(root);

        // The visited marks are reused by the second scan.
        memset(v,0,sizeof(v));
        f[root] = d[root];
        dfs(root);

        int best = 0;
        for(int node = 1; node <= n; ++node){
            if(f[node] > best) best = f[node];
        }
        cout<<best<<endl;
    }

    return 0;
}

例题

1. bzoj4033 [HAOI2015] 树上染色(树形 dp)

2. C 小G砍树

C 小G砍树
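本题的公式:设以 $r$ 为根时节点 $u$ 的子树大小为 $siz_r(u)$,记 $f_r=\prod_{u=1}^{n} siz_r(u)$,则答案为

$$\sum_{r=1}^{n} \frac{n!}{f_r} \pmod{mod}$$

换根时,设 $siz_c$ 为以 1 为根时儿子 $c$ 的子树大小,从父节点 $fa$ 换到 $c$ 有

$$f_c = f_{fa}\cdot\frac{n-siz_c}{siz_c}$$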

// Capacity of every per-node table below.
// NOTE(review): LL, mod, qpow and Pb are used in this program but defined outside this section.
const int maxn = 1e5+10;
// Undirected adjacency lists; main() stores every edge in both directions.
vector<int> G[maxn];
// siz[x]: subtree size of x as computed by dfs1() (started from node 1).
// mut[x]: product, modulo mod, of siz over x's subtree (including x).
LL siz[maxn],mut[maxn];
// ans[x]: value derived by dfs2() from the parent's ans; ans[1] is seeded with mut[1].
LL ans[maxn];
// inv[i]: table filled in main() with the usual linear modular-inverse recurrence
// (presumably mod is prime — confirm where mod is defined).
LL inv[maxn];
// Number of nodes, read in main().
int n;
void dfs1(int node,int fa = -1){
    siz[node] = 1;
    mut[node] = 1;
    for(auto c: G[node]){
        if(c == fa) continue;
        dfs1(c,node);
        siz[node] += siz[c];
        mut[node] = 1ll*mut[node]*mut[c]%mod;
    }
    mut[node] = mut[node]*siz[node]%mod;
}
void dfs2(int node,int fa = -1){
    for(auto c: G[node]){
        if(c == fa) continue;
        ans[c] = 1ll*(n-siz[c])*ans[node]%mod*inv[siz[c]]%mod;
        dfs2(c,node);
    }
}
int main(void)
{
    inv[1] = 1;
    for(int i = 2;i < maxn; ++i)
      inv[i] = (mod-1ll*inv[mod%i]*(mod/i)%mod)%mod;

    cin>>n;
    for(int i = 1,u,v;i < n; ++i){
        scanf("%d%d",&u,&v);
        G[u].Pb(v);
        G[v].Pb(u);
    }
    dfs1(1);
    ans[1] = mut[1];
    dfs2(1);
    LL tmp = 0;
    for(int i = 1;i <= n; ++i)
        tmp = (tmp+qpow(ans[i],mod-2))%mod;
    for(int i = 1;i <= n; ++i)
        tmp = tmp*i%mod;
    cout<<tmp<<endl;
    return 0;
}