D、树上祖先询问题解

前置知识:dfs序,倍增法求lca

OI wiki上的倍增LCA模板, 看文字嫌累还有b站的讲解视频(个人觉得不错)

题目主要是让我们求A集合内全部点的公共祖先,对于求多个点的最近公共祖先,我们并不真的要对所有点两两都求一次。我们只用取这些点中dfs序最小和最大的两个点来求最近公共祖先就行。 证明可以看看去如何在 DAG 中找多个点的 LCA ? 看,这里不多做补充。

问题解析

知道了这两个前置知识后,这题就变的很简单了。

  1. 我们可以先对所有的点来一遍前序遍历,求得他们的dfs序。
  2. 用A集合中dfs序中最小的点和最大的点来求得A集合的最近公共祖先z
  3. 看A集合中是否有z,因为我们要的是所有点的节点,如果z出现在A集合,那我们再求z的父亲节点就行
  4. 看B集合中有没有点z出现,如果有,输出yes,如果没有,输出no

但是!第一次交并没有AC!

#include<iostream>
using namespace std;
#include<vector>
#include<algorithm>
#include<math.h>
#include<set>
#include <random>
#include<numeric>
#include<string>
#include<string.h>
#include<iterator>
#include<fstream>
#include<map>
#include<unordered_map>
#include<stack>
#include<list>
#include<queue>
#include<iomanip>
#include<bitset>

//#pragma GCC optimize(3)

#define endl '\n'
#define int ll
#define PI acos(-1)
#define INF 0x3f3f3f3f
typedef long long ll;
typedef unsigned long long ull;
typedef pair<ll, ll>PII;
const int N = 1e5 + 50, MOD = 1e11 + 3;

int qpow(int a, int b)
{
    int res = 1;
    while (b)
    {
        if (b & 1)res = (1LL) * res * a % MOD;
        b >>= 1;
        a = (1LL) * a * a % MOD;
    }
    return res;
}

int deep[N], fa[N][31], st[N];
vector<int>tree[N];

// dfs,用来为 lca 算法做准备。接受两个参数:dfs 起始节点和它的父亲节点。
void dfs(int root, int fno,int &cnt) {
    // 初始化:第 2^0 = 1 个祖先就是它的父亲节点,dep 也比父亲节点多 1。
    fa[root][0] = fno;
    st[root] = cnt++;
    deep[root] = deep[fa[root][0]] + 1;
    // 初始化:其他的祖先节点:第 2^i 的祖先节点是第 2^(i-1) 的祖先节点的第
    // 2^(i-1) 的祖先节点。
    for (int i = 1; i < 31; ++i) {
        fa[root][i] = fa[fa[root][i - 1]][i - 1];
    }
    // 遍历子节点来进行 dfs。
    int sz = tree[root].size();
    for (int i = 0; i < sz; ++i) {
        if (tree[root][i] == fno) continue;
        dfs(tree[root][i], root, cnt);
    }
}

// lca。用倍增算法算取 x 和 y 的 lca 节点。
int lca(int x, int y) {
    // 令 y 比 x 深。
    if (deep[x] > deep[y]) swap(x, y);
    // 令 y 和 x 在一个深度。
    int tmp = deep[y] - deep[x], ans = 0;
    for (int j = 0; tmp; ++j, tmp >>= 1)
        if (tmp & 1) y = fa[y][j];
    // 如果这个时候 y = x,那么 x,y 就都是它们自己的祖先。
    if (y == x) return x;
    // 不然的话,找到第一个不是它们祖先的两个点。
    for (int j = 30; j >= 0 && y != x; --j) {
        if (fa[x][j] != fa[y][j]) {
            x = fa[x][j];
            y = fa[y][j];
        }
    }
    // 返回结果。
    return fa[x][0];
}

bool cmp(int& a, int& b)
{
    return st[a] < st[b];
}

void solve()
{
    int n, q;
    cin >> n >> q;
    for (int i = 2; i <= n; i++)
    {
        int x;
        cin >> x;
        tree[i].push_back(x);
        tree[x].push_back(i);
    }
    int cnt = 1;
    dfs(1, 0, cnt);
    while (q--)
    {
        int a, b;
        cin >> a >> b;
        vector<int>v1(a), v2(b);
        for (int i = 0; i < a; i++)cin >> v1[i];
        for (int i = 0; i < b; i++)cin >> v2[i];
        //按照dfs序大小升序排序
        sort(v1.begin(), v1.end(), cmp);
        //获得dfs序最小的点和最大的点
        int x = v1[0], y = v1.back();
        int z = lca(x, y);
        bool flag = false;
        //判断z是否出现在集合A中
        for (auto i : v1)
        {
            if (z == i)
            {
                flag = true;
                break;
            }
        }
        if (flag)
        {
        	//如果出现了,求z点的父亲节点
            z = fa[z][0];
        }
        flag = false;
        //再看B集合中有没有出现过z即可
        for (auto i : v2)
        {
            if (i == z)
            {
                flag = true;
                break;
            }
        }
        if (flag)cout << "yes" << endl;
        else cout << "no" << endl;
    }
}

signed main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    int t = 1;
    //cin >> t;
    while (t--)
    {
        solve();
    }
    return 0;
}

原来这题不光是求“最近”的公共祖先,而是只要是祖先就行。哪怕这个点是z的爷爷或太爷爷,那也是满足条件的。

那我们就先用哈希表记录B集合全部的点,算出来A集合的LCA后,我们看z点的祖先节点是否出现在哈希表里即可。

AC代码

#include<iostream>
using namespace std;
#include<vector>
#include<algorithm>
#include<math.h>
#include<set>
#include <random>
#include<numeric>
#include<string>
#include<string.h>
#include<iterator>
#include<fstream>
#include<map>
#include<unordered_map>
#include<stack>
#include<list>
#include<queue>
#include<iomanip>
#include<bitset>

//#pragma GCC optimize(3)

#define endl '\n'
#define int ll
#define PI acos(-1)
#define INF 0x3f3f3f3f
typedef long long ll;
typedef unsigned long long ull;
typedef pair<ll, ll>PII;
const int N = 1e5 + 50, MOD = 1e11 + 3;

int qpow(int a, int b)
{
    int res = 1;
    while (b)
    {
        if (b & 1)res = (1LL) * res * a % MOD;
        b >>= 1;
        a = (1LL) * a * a % MOD;
    }
    return res;
}

int deep[N], fa[N][31], st[N];
vector<int>tree[N];

// dfs,用来为 lca 算法做准备。接受两个参数:dfs 起始节点和它的父亲节点。
void dfs(int root, int fno,int &cnt) {
    // 初始化:第 2^0 = 1 个祖先就是它的父亲节点,dep 也比父亲节点多 1。
    fa[root][0] = fno;
    st[root] = cnt++;
    deep[root] = deep[fa[root][0]] + 1;
    // 初始化:其他的祖先节点:第 2^i 的祖先节点是第 2^(i-1) 的祖先节点的第
    // 2^(i-1) 的祖先节点。
    for (int i = 1; i < 31; ++i) {
        fa[root][i] = fa[fa[root][i - 1]][i - 1];
    }
    // 遍历子节点来进行 dfs。
    int sz = tree[root].size();
    for (int i = 0; i < sz; ++i) {
        if (tree[root][i] == fno) continue;
        dfs(tree[root][i], root, cnt);
    }
}

// lca。用倍增算法算取 x 和 y 的 lca 节点。
int lca(int x, int y) {
    // 令 y 比 x 深。
    if (deep[x] > deep[y]) swap(x, y);
    // 令 y 和 x 在一个深度。
    int tmp = deep[y] - deep[x], ans = 0;
    for (int j = 0; tmp; ++j, tmp >>= 1)
        if (tmp & 1) y = fa[y][j];
    // 如果这个时候 y = x,那么 x,y 就都是它们自己的祖先。
    if (y == x) return x;
    // 不然的话,找到第一个不是它们祖先的两个点。
    for (int j = 30; j >= 0 && y != x; --j) {
        if (fa[x][j] != fa[y][j]) {
            x = fa[x][j];
            y = fa[y][j];
        }
    }
    // 返回结果。
    return fa[x][0];
}

bool cmp(int& a, int& b)
{
    return st[a] < st[b];
}

void solve()
{
    int n, q;
    cin >> n >> q;
    for (int i = 2; i <= n; i++)
    {
        int x;
        cin >> x;
        tree[i].push_back(x);
        tree[x].push_back(i);
    }
    int cnt = 1;
    dfs(1, 0, cnt);
    while (q--)
    {
        int a, b;
        cin >> a >> b;
        vector<int>v1(a), v2(b);
        for (int i = 0; i < a; i++)cin >> v1[i];
        for (int i = 0; i < b; i++)cin >> v2[i];
        sort(v1.begin(), v1.end(), cmp);
        unordered_map<int,int>mymap;
        //哈希表记录集合B的点
        for(auto i:v2)mymap[i]=1;
        int x = v1[0], y = v1.back();
        int z = lca(x, y);
        bool flag = false;
        for (auto i : v1)
        {
            if (z == i)
            {
                flag = true;
                break;
            }
        }
        if (flag)
        {
            z = fa[z][0];
        }
        if(z==0)
        {
            cout << "no" << endl;
            continue;
        }
        flag = false;
        while(z!=0)
        {
        	//判断z点是否出现在B集合
            if(mymap.count(z))
            {
                flag=true;
                break;
            }
            //把z点逐步向上移动
            z=fa[z][0];
        }
        if (flag)cout << "yes" << endl;
        else cout << "no" << endl;
    }
}

signed main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    int t = 1;
    //cin >> t;
    while (t--)
    {
        solve();
    }
    return 0;
}