题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=6065

题意:给一颗有根树,根节点为1,再给定一个排列,长度为n,要求将排列切分成K段,定义每段的价值为该排列所有点及两两点之间lca中最浅节点的深度。要求输出K段区间所有可能的价值和中的最小值。n*K<=3e5。

解法:参考http://blog.csdn.net/u013944294/article/details/76601946

这里主要考虑LCA有几个强力的性质:

1,定义“一段排列所有点及两两点之间lca中最浅节点的深度”为T,当在排列末尾加上一个节点ai的时候,只需要求一下ai-1与ai的lca,再与之前的lca比较谁的深度小,维护深度的最小值即可。

2,处理出相邻点间lca深度之后,比如 7 4 5 6 3 9 10,若最优切分方式里 4 与 3 在不同区间里,4 与 3 之间的任何数,即 5 6,划分给 3 区间或者 4 区间,都不会影响最终答案

3,区间末尾增加新的节点时,价值T一定是不增的。


所以当前的DP值就之和相邻的前2两个DP值有关,因为划分出来的每个区间的答案,其实就是连续两个的lca的最小值。

所以:DP转移有两种:

1) 将该点放到前一个区间里,dp[i][j]=dp[i-1][j];

2) 将该点放到下一个区间里,dp[i][j]=dp[i-1][j-1]+depth[j] 或者 dp[i][j]=dp[i-2][j-1]+depth[lca(i-1,i)]。

复杂度:O(n*k)


109ms代码:


#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
struct FastIO
{
    static const int S = 1310720;
    int wpos;
    char wbuf[S];
    FastIO() : wpos(0) {}
    inline int xchar()
    {
        static char buf[S];
        static int len = 0, pos = 0;
        if (pos == len)
            pos = 0, len = fread(buf, 1, S, stdin);
        if (pos == len) exit(0);
        return buf[pos ++];
    }
    inline int xuint()
    {
        int c = xchar(), x = 0;
        while (c <= 32) c = xchar();
        for (; '0' <= c && c <= '9'; c = xchar()) x = x * 10 + c - '0';
        return x;
    }
    inline int xint()
    {
        int s = 1, c = xchar(), x = 0;
        while (c <= 32) c = xchar();
        if (c == '-') s = -1, c = xchar();
        for (; '0' <= c && c <= '9'; c = xchar()) x = x * 10 + c - '0';
        return x * s;
    }
    inline void xstring(char *s)
    {
        int c = xchar();
        while (c <= 32) c = xchar();
        for (; c > 32; c = xchar()) * s++ = c;
        *s = 0;
    }
    inline void wchar(int x)
    {
        if (wpos == S) fwrite(wbuf, 1, S, stdout), wpos = 0;
        wbuf[wpos ++] = x;
    }
    inline void wint(LL x)
    {
        if (x < 0) wchar('-'), x = -x;
        char s[24];
        int n = 0;
        while (x || !n) s[n ++] = '0' + x % 10, x /= 10;
        while (n--) wchar(s[n]);
    }
    inline void wstring(const char *s)
    {
        while (*s) wchar(*s++);
    }
    ~FastIO()
    {
        if (wpos) fwrite(wbuf, 1, wpos, stdout), wpos = 0;
    }
} io;
const int maxn = 3e5+5;
const int inf = 0x3f3f3f3f;
int n,k,head[maxn],p[maxn],edgecnt;
int fa[maxn][20],dep[maxn],d[maxn];
struct edge{
    int to,next;
}E[maxn*2];
void init(){
    memset(head,-1,sizeof(head));
    edgecnt=0;
}
void add(int u, int v){
    E[edgecnt].to = v, E[edgecnt].next = head[u], head[u] = edgecnt++;
}
void dfs(int x, int d, int pre){
    dep[x] = d;
    fa[x][0] = pre;
    for(int i=1; i<20; i++){
        fa[x][i] = fa[fa[x][i-1]][i-1];
    }
    for(int i=head[x]; ~i; i=E[i].next){
        int v = E[i].to;
        if(v == pre) continue;
        dfs(v, d+1, x);
    }
}
int LCA(int u, int v){
    if(dep[u]<dep[v]) swap(u,v);
    for(int i=19; i>=0; i--){
        if(dep[fa[u][i]]>=dep[v]){
            u=fa[u][i];
        }
    }
    if(u==v) return u;
    for(int i=19; i>=0; i--){
        if(fa[u][i]!=fa[v][i]){
            u=fa[u][i];
            v=fa[v][i];
        }
    }
    return fa[u][0];
}
int main()
{
    while(1)
    {
        n = io.xint();
        k = io.xint();
        for(int i=1; i<=n; i++) p[i] = io.xint();
        init();
        for(int i=1; i<n; i++){
            int u, v;
            u = io.xint();
            v = io.xint();
            add(u,v);
            add(v,u);
        }
        dfs(1,1,0);
        for(int i=2; i<=n; i++){
            int _lca=LCA(p[i],p[i-1]);
            d[i]=dep[_lca];
        }
        vector<vector<int> >dp(n+1,vector<int>(k+1,0));
        for(int i=1; i<=n; i++){
            for(int j=1; j<=min(i,k); j++){
                int ret=inf;
                if(i>=2&&j-1<=i-2) ret=min(ret,dp[i-2][j-1]+d[i]);
                if(j<i) ret = min(ret, dp[i-1][j]);
                ret = min(ret, dp[i-1][j-1]+dep[p[i]]);
                dp[i][j] = ret;
            }
        }
        printf("%d\n", dp[n][k]);
    }
    return 0;
}