A、Ancestor题解
前置知识:dfs序,倍增法求lca
OI wiki上的倍增LCA模板, 看文字嫌累还有b站的讲解视频(个人觉得不错)
题目主要是让我们求长度为k-1的点集的最近公共祖先(一个点被删了),对于求多个点的最近公共祖先,我们并不真的要对所有点两两都求一次。我们只用取这些点中dfs序最小和最大的两个点来求最近公共祖先就行。 证明可以看看去如何在 DAG 中找多个点的 LCA ? 看,这里不多做补充。
问题解析
知道了这两个前置知识后,这题就变的很简单了。
-
我们可以先对所有的点来一遍前序遍历,求得他们的dfs序。
-
根据dfs序对k集合的点进行升序排序(两个树的dfs序不一定一样的,所以我们两边都要求嗷,排序也是)。
-
然后我们枚举第1~第k个点作为被删除的点,对于这两个树,我们取他们k集合中dfs序最大和最小的两个点来分别求lca。如果它们中有点正好是这次被删除的点,那我们就取第二大(第二小)的点(哪个被删改哪个,没被删就不理)。
-
之后比较两边求出来的点,看A树的祖先的权值是否大于B树的祖先,如果大于,则计数器++。
样例解释
输入:
5 3
5 4 3
6 6 3 4 6
1 2 2 4
7 4 5 7 7
1 1 3 2
输出:
1
黑色是点的编号,红色是每个点的权值,蓝色是他们的dfs序
-
树A的k集合排序后为:3 4 5.
-
树B的k集合排序后为:5 3 4.
以A的k集合为准,我们枚举删除的点:
-
删除点3——A的最大dfs序点为5,最小为4;B的最大点为4,最小为5
求得A的祖先为点4,权值为4,B为1,权值为7,不满足,cnt=0;
-
删除点4——A的最大dfs序点为5,最小为3;B的最大点为3,最小为5
求得A的祖先为点2,权值为6,B为1,权值为7,不满足,cnt=0;
-
删除点5——A的最大dfs序点为4,最小为3;B的最大点为4,最小为3
求得A的祖先为点2,权值为6,B为3,权值为4,满足,cnt=1;
复杂度分析
倍增lca预处理复杂度为:n*logn.
对k集合点排序:k*logk.
处理单次询问:logn.
AC代码
(代码很丑很长,主要是学艺不精重复的地方有点多不然可以很简洁的,但重要的是思路)
#include<iostream>
using namespace std;
#include<vector>
#include<algorithm>
#include<math.h>
#include<set>
#include<numeric>
#include<string>
#include<string.h>
#include<iterator>
#include<queue>
#define endl '\n'
#define int ll
#define PI acos(-1)
typedef long long ll;
typedef unsigned long long ull;
typedef pair<ll, ll>PII;
const int N = 2e5 + 50, MOD = 1e9 + 7;
//deep数组记录的是树上每个节点的深度,fa数组记录的是每个点的父节点,这两个是倍增lca需要的
int deepA[N], faA[N][31], deepB[N], faB[N][31];
//st数组记录每个点的dfs序,w数组记录每个点的权值
int stA[N], stB[N], wA[N], wB[N];
//树数组
vector<int>treeA[N], treeB[N];
bool cmpA(int a, int b)
{
return stA[a] < stA[b];
}
bool cmpB(int a, int b)
{
return stB[a] < stB[b];
}
// dfs,用来为 lca 算法做准备。接受两个参数:dfs 起始节点和它的父亲节点。
void dfsA(int root, int fno, int& cnt) {
// 初始化:第 2^0 = 1 个祖先就是它的父亲节点,dep 也比父亲节点多 1。
faA[root][0] = fno;
//记录当前点的dfs序
stA[root] = cnt++;
deepA[root] = deepA[faA[root][0]] + 1;
// 初始化:其他的祖先节点:第 2^i 的祖先节点是第 2^(i-1) 的祖先节点的第
// 2^(i-1) 的祖先节点。
for (int i = 1; i < 31; ++i) {
faA[root][i] = faA[faA[root][i - 1]][i - 1];
}
// 遍历子节点来进行 dfs。
int sz = treeA[root].size();
for (int i = 0; i < sz; ++i) {
if (treeA[root][i] == fno) continue;
dfsA(treeA[root][i], root, cnt);
}
}
void dfsB(int root, int fno, int& cnt) {
faB[root][0] = fno;
stB[root] = cnt++;
deepB[root] = deepB[faB[root][0]] + 1;
for (int i = 1; i < 31; ++i) {
faB[root][i] = faB[faB[root][i - 1]][i - 1];
}
int sz = treeB[root].size();
for (int i = 0; i < sz; ++i) {
if (treeB[root][i] == fno) continue;
dfsB(treeB[root][i], root, cnt);
}
}
// lca。用倍增算法算取 x 和 y 的 lca 节点。
int lcaA(int x, int y) {
// 令 y 比 x 深。
if (deepA[x] > deepA[y]) swap(x, y);
// 令 y 和 x 在一个深度。
int tmp = deepA[y] - deepA[x], ans = 0;
for (int j = 0; tmp; ++j, tmp >>= 1)
if (tmp & 1) y = faA[y][j];
// 如果这个时候 y = x,那么 x,y 就都是它们自己的祖先。
if (y == x) return x;
// 不然的话,找到第一个不是它们祖先的两个点。
for (int j = 30; j >= 0 && y != x; --j) {
if (faA[x][j] != faA[y][j]) {
x = faA[x][j];
y = faA[y][j];
}
}
// 返回结果。
return faA[x][0];
}
int lcaB(int x, int y) {
if (deepB[x] > deepB[y]) swap(x, y);
int tmp = deepB[y] - deepB[x], ans = 0;
for (int j = 0; tmp; ++j, tmp >>= 1)
if (tmp & 1) y = faB[y][j];
if (y == x) return x;
for (int j = 30; j >= 0 && y != x; --j) {
if (faB[x][j] != faB[y][j]) {
x = faB[x][j];
y = faB[y][j];
}
}
return faB[x][0];
}
void solve()
{
int n, m, x;
cin >> n >> m;
vector<int>ansA(m), ansB(m);
for (int i = 0; i < m; i++)cin >> ansA[i], ansB[i] = ansA[i];
for (int i = 1; i <= n; i++)cin >> wA[i];
for (int i = 2; i <= n; i++)
{
cin >> x;
treeA[x].push_back(i);
}
int cnt = 1;
//预处理lca,并且求得每个点的dfs序
dfsA(1, 0, cnt);
for (int i = 1; i <= n; i++)cin >> wB[i];
for (int i = 2; i <= n; i++)
{
cin >> x;
treeB[x].push_back(i);
}
cnt = 1;
dfsB(1, 0, cnt);
//按照dfs序对k集合的点进行升序排序
sort(ansA.begin(), ansA.end(), cmpA);
sort(ansB.begin(), ansB.end(), cmpB);
cnt = 0;
//我们以A树的k集合为标准枚举所有的点
for (int i = 0; i < m; i++)
{
//x记录k集合中dfs序最小的点,y记录最大的点
int xA = ansA[0], yA = ansA[m - 1], xB = ansB[0], yB = ansB[m - 1];
//如果最大的点被删除,我们取第二大的
if (i == 0)xA = ansA[1];
//如果最小的点被删除,我们取第二小的
else if (i == m - 1)yA = ansA[m - 2];
if (xB == ansA[i])xB = ansB[1];
else if (yB == ansA[i])yB = ansB[m - 2];
//求的两个树的最近公共祖先
int xlca = lcaA(xA, yA), ylca = lcaB(xB, yB);
if (wA[xlca] > wB[ylca])cnt++;
}
cout << cnt << endl;
}
signed main()
{
ios_base::sync_with_stdio(false);
cin.tie(nullptr);
cout.tie(nullptr);
int t = 1;
//cin >> t;
while (t--)
{
solve();
}
return 0;
}