Ancestor

题目描述:

给你两棵树A和B,点的编号从1n,根结点是1,且每个点都有一个价值,现在给你k个点,选任意k-1个不同的点,分别求这些点在两颗树上的最近公共祖先fa, fb,问存在多少种情况满足A树上fa的价值大于B树上fb的价值

思路1:前缀和 + 后缀和

这是最简单好写的一个思路,考虑一个树,我们删除第i个点后剩下k-1个点的最近公共祖先就是前i-1个点和后n-i个点的最近公共祖先,可以维护一个最近公共祖先的前缀和和后缀和来直接计算

具体的见代码

#include <bits/stdc++.h>
using namespace std;

#define endl '\n'
#define inf 0x3f3f3f3f
#define mod 998244353
#define m_p(a,b) make_pair(a, b)
#define mem(a,b) memset((a),(b),sizeof(a))
#define io ios::sync_with_stdio(false); cin.tie(0); cout.tie(0)

typedef long long ll;
typedef pair <int,int> pii;

#define MAX 300000 + 50
int n, m, k, op, x;
int kr[MAX];

int va[MAX];//a树的价值
int pra[MAX];//前缀和
int sua[MAX];//后缀和
vector<int>ar[MAX];//建图
int ha[MAX];//点在树上的高度
int fa[MAX][25];//倍增
void dfs(int u, int ff){
    for(int i = 1; i <= 20; ++i){
        fa[u][i] = fa[fa[u][i - 1]][i - 1];
    }
    for(auto v : ar[u]){
        if(v == ff)continue;
        ha[v] = ha[u] + 1;
        fa[v][0] = u;
        dfs(v, u);
    }
}
int lca(int x, int y){
    if(x == 0)return y;//用于解决i-1=0或者i+1=n+1的情况
    if(y == 0)return x;
    if(ha[x] < ha[y])swap(x, y);
    for(int i = 20; i >= 0; --i){
        if(ha[fa[x][i]] >= ha[y])x = fa[x][i];
    }
    if(x == y)return x;
    for(int i = 20; i >= 0; --i){
        if(fa[x][i] != fa[y][i]){
            x = fa[x][i];
            y = fa[y][i];
        }
    }
    return fa[x][0];
}

int vb[MAX];//含义同上
int prb[MAX];
int sub[MAX];
vector<int>br[MAX];
int hb[MAX];
int fb[MAX][25];
void ddfs(int u, int ff){
    for(int i = 1; i <= 20; ++i){
        fb[u][i] = fb[fb[u][i - 1]][i - 1];
    }
    for(auto v : br[u]){
        if(v == ff)continue;
        hb[v] = hb[u] + 1;
        fb[v][0] = u;
        ddfs(v, u);
    }
}
int lcb(int x, int y){
    if(x == 0)return y;
    if(y == 0)return x;
    if(hb[x] < hb[y])swap(x, y);
    for(int i = 20; i >= 0; --i){
        if(hb[fb[x][i]] >= hb[y])x = fb[x][i];
    }
    if(x == y)return x;
    for(int i = 20; i >= 0; --i){
        if(fb[x][i] != fb[y][i]){
            x = fb[x][i];
            y = fb[y][i];
        }
    }
    return fb[x][0];
}

void work(){
    cin >> n >> k;
    for(int i = 1; i <= k; ++i){
        cin >> kr[i];
    }
    for(int i = 1; i <= n; ++i){
        cin >> va[i];
    }
    for(int i = 2; i <= n; ++i){
        cin >> x;
        ar[x].push_back(i);
        ar[i].push_back(x);
    }
    for(int i = 1; i <= n; ++i)cin >> vb[i];
    for(int i = 2; i <= n; ++i){
        cin >> x;
        br[i].push_back(x);
        br[x].push_back(i);
    }
    //求lca和前缀和和后缀和
    ha[1] = 1;
    fa[1][0] = 0;
    dfs(1, -1);
    int p = kr[1];
    for(int i = 1; i <= k; ++i){
        p = lca(p, kr[i]);
        pra[i] = p;
    }
    p = kr[n];
    for(int i = k; i >= 1; --i){
        p = lca(p, kr[i]);
        sua[i] = p;
    }
    
    hb[1] = 1;
    fb[1][0] = 0;
    ddfs(1, -1);
    p = kr[1];
    for(int i = 1; i <= k; ++i){
        p = lcb(p, kr[i]);
        prb[i] = p;
    }
    p = kr[n];
    for(int i = k; i >= 1; --i){
        p = lcb(p, kr[i]);
        sub[i] = p;
    }
    int ans = 0;
    for(int i = 1; i <= k; ++i){
        if(va[lca(pra[i - 1], sua[i + 1])] > vb[lcb(prb[i - 1], sub[i + 1])])++ans;
    }
    cout << ans << endl;
}

int main(){
    io;
    work();
    return 0;
}

思路2:DFS序 + LCA

有一个结论是:求树上多个点的LCA的方法是求这些点在树上的dfs序的最小值的点和最大值的点的LCA

所以我们求出来两颗树的dfs序后,将k个点按照dfs序的大小来排序后,离线处理,再按照上面的结论进行计算就行

具体的看代码

#include <bits/stdc++.h>
using namespace std;

#define endl '\n'
#define inf 0x3f3f3f3f
#define mod 998244353
#define m_p(a,b) make_pair(a, b)
#define mem(a,b) memset((a),(b),sizeof(a))
#define io ios::sync_with_stdio(false); cin.tie(0); cout.tie(0)

typedef long long ll;
typedef pair <int,int> pii;

#define MAX 300000 + 50
int n, m, k, op, x;
int tr[MAX];
int kr[MAX];

int a;
int vva[MAX];
int ida[MAX];//记录k个数字删除时其余点的lca
vector<int>ar[MAX];
int ha[MAX];
int fa[MAX][25];
void dfs(int u, int ff){
    ida[u] = ++a;
    for(int i = 1; i <= 20; ++i){
        fa[u][i] = fa[fa[u][i - 1]][i - 1];
    }
    for(auto v : ar[u]){
        if(v == ff)continue;
        ha[v] = ha[u] + 1;
        fa[v][0] = u;
        dfs(v, u);
    }
}
int lca(int x, int y){
    if(x == 0)return y;
    if(y == 0)return x;
    if(ha[x] < ha[y])swap(x, y);
    for(int i = 20; i >= 0; --i){
        if(ha[fa[x][i]] >= ha[y])x = fa[x][i];
    }
    if(x == y)return x;
    for(int i = 20; i >= 0; --i){
        if(fa[x][i] != fa[y][i]){
            x = fa[x][i];
            y = fa[y][i];
        }
    }
    return fa[x][0];
}

int b;
int vvb[MAX];
int idb[MAX];
vector<int>br[MAX];
int hb[MAX];
int fb[MAX][25];
void ddfs(int u, int ff){
    idb[u] = ++b;
    for(int i = 1; i <= 20; ++i){
        fb[u][i] = fb[fb[u][i - 1]][i - 1];
    }
    for(auto v : br[u]){
        if(v == ff)continue;
        hb[v] = hb[u] + 1;
        fb[v][0] = u;
        ddfs(v, u);
    }
}
int lcb(int x, int y){
    if(x == 0)return y;
    if(y == 0)return x;
    if(hb[x] < hb[y])swap(x, y);
    for(int i = 20; i >= 0; --i){
        if(hb[fb[x][i]] >= hb[y])x = fb[x][i];
    }
    if(x == y)return x;
    for(int i = 20; i >= 0; --i){
        if(fb[x][i] != fb[y][i]){
            x = fb[x][i];
            y = fb[y][i];
        }
    }
    return fb[x][0];
}

int ***a[MAX];
int ***b[MAX];

void work(){
    cin >> n >> k;
    for(int i = 1; i <= k; ++i){
        cin >> kr[i];
    }
    for(int i = 1; i <= n; ++i){
        cin >> vva[i];
    }
    for(int i = 2; i <= n; ++i){
        cin >> x;
        ar[x].push_back(i);
        ar[i].push_back(x);
    }
    for(int i = 1; i <= n; ++i)cin >> vvb[i];
    for(int i = 2; i <= n; ++i){
        cin >> x;
        br[i].push_back(x);
        br[x].push_back(i);
    }
    
    ha[1] = 1;
    fa[1][0] = 0;
    dfs(1, -1);
    vector<pii>va;
    for(int i = 1; i <= k; ++i)va.push_back(m_p(ida[kr[i]], kr[i]));
    sort(va.begin(), va.end());
    
    hb[1] = 1;
    fb[1][0] = 0;
    ddfs(1, -1);
    vector<pii>vb;
    for(int i = 1; i <= k; ++i)vb.push_back(m_p(idb[kr[i]], kr[i]));
    sort(vb.begin(), vb.end());
    int ans = 0;
    for(int i = 0; i < k; ++i){
        int a = lca(i == 0 ? va[1].second : va[0].second, i == k - 1 ? va[(int)va.size() - 2].second : va.back().second);
        int b = lcb(i == 0 ? vb[1].second : vb[0].second, i == k - 1 ? vb[(int)vb.size() - 2].second : vb.back().second);
        a[va[i].second] = a;
        b[vb[i].second] = b;
    }
    for(int i = 1; i <= k; ++i){
        if(vva[a[kr[i]]] > vvb[b[kr[i]]])++ans;
    }
    cout << ans << endl;
    
}

int main(){
    io;
    work();
    return 0;
}

思路3:思维+dfs

很不幸,在比赛的时候上面的两个方法我们队伍都没想到,而是想到了下面的这个思路,并且由我写了一个非常非常臭的代码,最后是写了230行

k个点的最近公共祖先root我们可以直接求出来,而在绝大多数情况下,删任意一个点都不会改变这个root,会改变的情况只有两种:

  • k-1个节点在root的一颗子树上,另一个节点在另一个棵子树上的时候,此时删除这个单独的点可能会使的lca发生变化,例如这个图中的3,删掉后,45lca就变成了4,而不是之前的1

    alt

    image-20220725232510332
  • k个节点都在root的一颗子树上,而root一定是k个点中的一个,此时删除这个root,是会使得lca发生改变,例如如果删除图中的1号点,则46lca就是2,而不是之前的1

    image-20220725232655082alt

这里其实就不需要我们用什么倍增求LCA,我们只需要dfs一遍树,记录每个节点的子树中(包括这个点本身)含有的特殊点的数量,找到第一个数量等于k的节点,这个节点就是k个点的lca,我们同时记录一下这个点含有特殊点的子树的数量以及子树中含有特殊点的数量,判断一下数量等于1或者2或者其他,

  • 如果是1,就需要删除这个root后再跑一遍dfs求剩下k-1个点的lca来确定root的答案,而其他k-1个点的答案就root,(这里的答案指的是删除这个点后剩下的k-1个点的lca
  • 如果是2,就需要删除那个单独位于root的一个子树的点,来再跑一次dfs后确定他的答案,而其他点的答案就是root
  • 如果是其他,那k个点的答案就都是root

具体的看代码…(数组很多,因为不是很喜欢传一些参数来区分两种树,所以函数写的很多

#include <bits/stdc++.h>
using namespace std;

#define endl '\n'
#define inf 0x3f3f3f3f
#define mod7 1000000007
#define mod9 998244353
#define m_p(a,b) make_pair(a, b)
#define mem(a,b) memset((a),(b),sizeof(a))
#define io ios::sync_with_stdio(false); cin.tie(0); cout.tie(0)
#define debug(a) cout << "Debuging...|" << #a << ": " << a << "\n";
typedef long long ll;
typedef pair <int,int> pii;

#define MAX 300000 + 50
int n, m, k, x;
int tr[MAX];
vector<int>kr;//存k个点
vector<int>ar[MAX];//存图
int va[MAX];//存价值
vector<int>br[MAX];//
int vb[MAX];
int ffa[MAX], ffb[MAX];//记录每个点的父节点

set<int>se;//存特殊点,用于判断当前点是不是特殊点
int num1[MAX];//记录子树中特殊点的数量
vector<pii>aa, bb;//存root的各个含特殊点子树的特殊点的数量以及根节点,便于排序确定最少的特殊点的子树
int root1, root2;//记录lca
void dfs(int u, int fa){
    if(se.count(u))++num1[u];
    for(auto v : ar[u]){
        if(v == fa)continue;
        dfs(v, u);
        num1[u] += num1[v];
    }
    if(root1 == 0 && num1[u] == k){//找到root
        root1 = u;
        for(auto v : ar[u]){
            if(v == fa)continue;
            if(num1[v])aa.push_back(m_p(num1[v], v));
        }
    }
}

int num2[MAX];
void ddfs(int u, int fa){
    if(se.count(u))++num2[u];
    for(auto v : br[u]){
        if(v == fa)continue;
        ddfs(v, u);
        num2[u] += num2[v];
    }
    if(root2 == 0 && num2[u] == k){
        root2 = u;
        for(auto v : br[u]){
            if(v == fa)continue;
            if(num2[v])bb.push_back(m_p(num2[v], v));
        }
    }
}

void dddfs(int u, int fa){//用于第二次删点后跑图找k-1个点的lca
    if(se.count(u))++num1[u];
    for(auto v : ar[u]){
        if(v == fa)continue;
        dddfs(v, u);
        num1[u] += num1[v];
    }
    if(root1 == 0 && num1[u] == k - 1){
        root1 = u;
    }
}

void ddddfs(int u, int fa){
    if(se.count(u))++num2[u];
    for(auto v : br[u]){
        if(v == fa)continue;
        ddddfs(v, u);
        num2[u] += num2[v];
    }
    if(root2 == 0 && num2[u] == k - 1){
        root2 = u;
    }
}

int ansa[MAX];//记录上面说的"答案"
int ansb[MAX];
int ***a, ***b;//数量等于2的那个单独的点的节点下标
void dddddfs(int u, int fa){//去dfs找数量等于2的那个单独的点的节点下标
    if(se.count(u)){
        ***a = u;
        return;
    }
    for(auto v : ar[u]){
        if(v == fa)continue;
        dddddfs(v, u);
    }
}
void ddddddfs(int u, int fa){
    if(se.count(u)){
        ***b = u;
        return;
    }
    for(auto v : br[u]){
        if(v == fa)continue;
        ddddddfs(v, u);
    }
}

void geta(){//计算"答案"
    if((int)aa.size() == 1){//判断1
        for(auto x : kr){//先确定其他点的答案
            if(x != root1)ansa[x] = root1;
        }
        int id = root1;
        se.erase(id);//删点
        mem(num1, 0);
        root1 = 0;
        dddfs(1, -1);
        ansa[id] = root1;
        se.insert(id);//再塞回去
    }
    else if((int)aa.size() == 2){//判断2
        sort(aa.begin(), aa.end());
        auto [num, id] = aa.front();
        if(num == 1){
            dddddfs(id, ffa[id]);
            id = ***a;
            for(auto x : kr){
                if(x != id)ansa[x] = root1;
            }
            se.erase(id);
            root1 = 0;
            mem(num1, 0);
            dddfs(1, -1);
            ansa[id] = root1;
            se.insert(id);
        }
        else{
            for(auto x : kr)ansa[x] = root1;
        }
    }
    else{
        for(auto x : kr)ansa[x] = root1;
    }
}

void getb(){
    if((int)bb.size() == 1){
        for(auto x : kr){
            if(x != root2)ansb[x] = root2;
        }
        int id = root2;
        se.erase(id);
        mem(num2, 0);
        root2 = 0;
        ddddfs(1, -1);
        ansb[id] = root2;
        se.insert(id);
    }
    else if((int)bb.size() == 2){
        sort(bb.begin(), bb.end());
        auto [num, id] = bb.front();
        if(num == 1){
            ddddddfs(id, ffb[id]);
            id = ***b;
            for(auto x : kr){
                if(x != id)ansb[x] = root2;
            }
            se.erase(id);
            root2 = 0;
            mem(num2, 0);
            ddddfs(1, -1);
            ansb[id] = root2;
            se.insert(id);
        }
        else{
            for(auto x : kr)ansb[x] = root2;
        }
    }
    else{
        for(auto x : kr)ansb[x] = root2;
    }
}


void work(){
    cin >> n >> k;
    for(int i = 1; i <= k; ++i){
        cin >> x;
        se.insert(x);
        kr.push_back(x);
    }
    for(int i = 1; i <= n; ++i)cin >> va[i];
    for(int i = 2; i <= n; ++i){
        cin >> x;
        ffa[i] = x;
        ar[x].push_back(i);
        ar[i].push_back(x);
    }
    for(int i = 1; i <= n; ++i)cin >> vb[i];
    for(int i = 2; i <= n; ++i){
        cin >> x;
        ffb[i] = x;
        br[x].push_back(i);
        br[i].push_back(x);
    }
    if(k == 2){
        int ans = 0;
        if(va[kr[0]] > vb[kr[0]])++ans;
        if(va[kr[1]] > vb[kr[1]])++ans;
        cout << ans << endl;
        return;
    }
    dfs(1, -1);
    ddfs(1, -1);
    geta();
    getb();
    int ans = 0;
    for(auto x : kr){
        if(va[ansa[x]] > vb[ansb[x]])++ans;
    }
    cout << ans << endl;
}

int main(){
    io;
    work();
    return 0;
}