read it on blog

P5903 【模板】树上 k 级祖先

题意

给定一棵树,我们要求 x 的上 k 级祖先。

思路

由于该题为长链剖分的模板题,因此重要介绍一下,长链剖分。

树

对于一棵树,我们的长链剖分和重链剖分,我们对每个链取最长的那条链作为我们的重链。如上图所示,我们可以把这棵树拆成一个一个链。为了方便画图,我把这颗树位置变换了一下。

剖分后的树

如上图所示,我们重构后的树就如其所示,我们不难发现这棵树有一个特点。对于所有的链而言,如果我将重链继承下来的话,那么我们通过不断合并,那么我们可以做到 O(N)O(N) 的时间复杂度将我们的答案统计完,因此,对于所有可以顺序继承的题目而言,这是非常好的性质。

还有一个性质,我们的任意一个链上跳的次数不会超过 N\sqrt N 次,我们可以通过这个性质完成一些复杂度要求较低的题目。

最后,对于我们需要完成的这个题目而言,还有一个性质,任何一个点的 kk 级祖先,其所在的重链长度一定大于等于 kk

那么了解完这些性质,我们来看我们这个题目该如何完成。

首先,我们预处理每个点重链的根节点。

然后我们对于 kk 次祖先,我们取该点 2h2^h 的祖先,其中 hh 满足 2hk<2h+12^h \le k < 2^{h+1}

我们可以通过预处理得到 [1,n][1 , n] 中的所有 hh

随后,我们发现,我们所在的点里目的地只需要 k2hk - 2^h 步了。

由于上述的最后一个性质,我们可以发现其 2h 2 ^h 次祖先,其链的长度一定大于等于 2h2^h 因此,我们只需要记录,以该重链的根为中心,其往重链所包含所有结点,以及其往上走和重链长度相等个数的父节点,我们即可求出该点。

#include <iostream>
#include <vector>

using namespace std;

#define ui unsigned int
ui s;

inline ui get(ui x) {
	x ^= x << 13;
	x ^= x >> 17;
	x ^= x << 5;
	return s = x;
}

const int N = 5e5 + 10;
vector<int> E[N];

int fa[N][22];
int son[N];
int buff[N * 4];
int * pos = buff;
int * F[N];
int _2[N];

int dep[N];
int len[N];
int ffa[N];

void get_son(int u,int deep = 1) {
    dep[u] = deep;
    for (auto v : E[u]) {
        get_son(v , deep + 1);
        if(len[v] > len[u]) {
            len[u] = len[v];
            son[u] = v;
        }
    }
    len[u] ++ ;
}

void get_F(int u) {
    if(ffa[u] == 0) {
        ffa[u] = u;
        pos += len[u] + 1;
        F[u] = pos;
        pos += len[u] + 1;
        int now = u;
        for (int i = 0 ; i < len[u] ; i ++ ) {
            F[u][-i] = now;
            now = fa[now][0];
            if(now == 0) break;
        }
    }

    F[ffa[u]][dep[u] - dep[ffa[u]]] = u;

    if(son[u]) {
        ffa[son[u]] = ffa[u];
        get_F(son[u]);
    }
    for (auto v : E[u]) {
        if(v == son[u]) continue;
        get_F(v);
    }
}

int querry(int x,int k) {
    if(k == 0) return x;
    int y = fa[x][_2[k]];
    y = ffa[y];
    return F[y][dep[x] - dep[y] - k];
}

int main() {
    int n, q; cin >> n >> q >> s;
    int root;
    for (int i = 1 ; i <= n ; i ++ ) {
        int x ; cin >> x;
        if(x == 0) root = i;
        else E[x].push_back(i) , fa[i][0] = x;
    }

    for (int i = 1 , kki = 0 ; i <= n ; i ++ ) {
        while( (1 << (kki + 1)) <= i) kki ++ ;
        _2[i] = kki;
    }

    for (int j = 1 ; j < 22 ; j ++ ) {
        for (int i = 1 ; i <= n ; i ++ ) {
            fa[i][j] = fa[fa[i][j - 1]][j - 1];
        }
    } 

    get_son(root);
    get_F(root);
    long long ans = 0;
    long long pre = 0;
    int i = 0;
    for (int i = 1 ; i <= q ; i ++ ) {
        int x = (get(s) ^ pre) % n + 1;
        int k = (get(s) ^ pre) % dep[x];
        pre = querry(x , k);
#ifdef DEBUG
        cerr << "X = " << x <<  " K = " << k 
             << " ans = " << pre << '\n';
#endif
        ans ^= i * pre;
    }
    
    cout << ans << '\n';

}