比赛的时候想到了是dsu on tree,但是没想到如何高效地维护答案,赛后听了大佬的做法后恍然大悟。然而后面还是调了半天才调出来...看来我对树上启发式合并的理解还不够深刻。
新学了一个卡常小技巧:之前枚举子树时是用dfs递归遍历的,其实可以预处理整棵树的dfs序——节点x的子树恰好对应dfs序上的连续区间[dfn[x], dfn[x]+siz[x]-1],之后只要枚举这段dfs序就可以了。
// DSU on tree (keep the heavy child's data, re-insert light subtrees),
// implemented without recursion so that path-shaped trees with ~1e6 nodes
// cannot overflow the call stack.
//
// Input (stdin): n, k; then fa[2..n] (the parent of node i); then a[1..n].
// For every node x, ans[x] = number of distinct depths D such that subtree(x)
// contains two different nodes u, v with depth D and (a[u] ^ a[v]) == k.
// Output: (sum over i = 1..n of (i ^ (n - ans[i]))) mod 998244353.
//
// Subtrees are enumerated through a heavy-child-first DFS order (subtree(x)
// is the contiguous range [tin[x], tin[x] + siz[x])), and every
// (depth, value) pair is compressed to an integer id, so membership tests
// are plain counters. Each node is (re)inserted O(log n) times.
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <utility>
#include <vector>

typedef long long ll;

const ll mod = 998244353;

// Reads one (optionally negative) decimal integer from stdin.
// Returns 0 if EOF is reached before any digit instead of looping forever.
inline ll rd() {
    ll x = 0;
    bool positive = true;
    int ch = std::getchar();  // int, not char, so EOF is distinguishable
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') positive = false;
        ch = std::getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = std::getchar();
    }
    return positive ? x : -x;
}

void solve() {
    const int n = static_cast<int>(rd());
    const int k = static_cast<int>(rd());
    if (n < 1) {  // empty tree: the sum is empty
        std::cout << 0 << '\n';
        return;
    }

    // ---- parents -> children lists in CSR layout ----
    // Children of p are children[childBegin[p] .. childBegin[p + 1]).
    std::vector<int> fa(n + 1, 0);
    std::vector<int> childBegin(n + 2, 0);
    for (int i = 2; i <= n; ++i) {
        const ll p = rd();
        if (p >= 1 && p <= n) {  // ignore malformed parents instead of writing out of bounds
            fa[i] = static_cast<int>(p);
            ++childBegin[fa[i]];
        }
    }
    // After the prefix sum childBegin[p] is the END of p's range; the fill
    // loop below decrements it back to the START.
    for (int p = 1; p <= n + 1; ++p) childBegin[p] += childBegin[p - 1];
    std::vector<int> children(childBegin[n + 1]);
    for (int i = n; i >= 2; --i)
        if (fa[i]) children[--childBegin[fa[i]]] = i;

    std::vector<int> a(n + 1, 0);
    for (int i = 1; i <= n; ++i) a[i] = static_cast<int>(rd());

    // ---- iterative preorder from root 1: depths ----
    std::vector<int> dep(n + 1, 0);
    std::vector<int> order;
    order.reserve(n);
    std::vector<int> stk;
    stk.reserve(n);
    stk.push_back(1);
    while (!stk.empty()) {
        const int x = stk.back();
        stk.pop_back();
        order.push_back(x);
        for (int j = childBegin[x]; j < childBegin[x + 1]; ++j) {
            dep[children[j]] = dep[x] + 1;
            stk.push_back(children[j]);
        }
    }

    // ---- subtree sizes and heavy children (children before parents) ----
    std::vector<int> siz(n + 1, 0), heavy(n + 1, 0);  // siz[0] == 0 stands for "no heavy child"
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int x = *it;
        siz[x] = 1;
        for (int j = childBegin[x]; j < childBegin[x + 1]; ++j) {
            const int y = children[j];
            siz[x] += siz[y];
            if (!heavy[x] || siz[heavy[x]] < siz[y]) heavy[x] = y;
        }
    }

    // ---- heavy-first DFS order: ver[tin[x]] == x ----
    std::vector<int> tin(n + 1, 0), ver;
    ver.reserve(order.size());
    stk.push_back(1);
    while (!stk.empty()) {
        const int x = stk.back();
        stk.pop_back();
        tin[x] = static_cast<int>(ver.size());
        ver.push_back(x);
        for (int j = childBegin[x]; j < childBegin[x + 1]; ++j)
            if (children[j] != heavy[x]) stk.push_back(children[j]);
        if (heavy[x]) stk.push_back(heavy[x]);  // popped next, so tin[heavy] == tin[x] + 1
    }

    // ---- compress (depth, value) pairs; precompute each node's partner key ----
    std::vector<std::pair<int, int>> keys;
    keys.reserve(ver.size());
    for (const int x : ver) keys.emplace_back(dep[x], a[x]);
    std::sort(keys.begin(), keys.end());
    keys.erase(std::unique(keys.begin(), keys.end()), keys.end());
    std::vector<int> keyId(n + 1, -1), matchId(n + 1, -1);
    for (const int x : ver) {
        keyId[x] = static_cast<int>(
            std::lower_bound(keys.begin(), keys.end(), std::make_pair(dep[x], a[x])) - keys.begin());
        const std::pair<int, int> want(dep[x], a[x] ^ k);
        const auto it = std::lower_bound(keys.begin(), keys.end(), want);
        if (it != keys.end() && *it == want) matchId[x] = static_cast<int>(it - keys.begin());
    }

    // ---- DSU-on-tree state ----
    std::vector<int> cnt(keys.size(), 0);  // multiplicity of each inserted (depth, value)
    std::vector<char> depthHit(n, 0);      // depth already counted in the current answer set
    std::vector<int> hitDepths;            // depths with depthHit set, for cheap reset

    // Check BEFORE inserting, so with k == 0 a node never matches itself.
    auto insertChecked = [&](int u) {
        const int m = matchId[u];
        if (m >= 0 && cnt[m] > 0 && !depthHit[dep[u]]) {
            depthHit[dep[u]] = 1;
            hitDepths.push_back(dep[u]);
        }
        ++cnt[keyId[u]];
    };

    // Explicit-stack simulation of:
    //   dsu(x, keep): dsu(light, false) for each light child; dsu(heavy, true);
    //                 insert light subtrees with checks; ans[x] = |hit depths|;
    //                 insert x; if (!keep) roll everything back.
    std::vector<int> ans(n + 1, 0);
    std::vector<int> cursor(childBegin.begin(), childBegin.end() - 1);  // next child to try
    std::vector<char> heavyDone(n + 1, 0), keep(n + 1, 0);
    keep[1] = 1;  // nothing runs after the root, so its data need not be rolled back
    stk.push_back(1);
    while (!stk.empty()) {
        const int x = stk.back();

        // Phase 1: light children, each computed and then fully discarded.
        int &cur = cursor[x];
        while (cur < childBegin[x + 1] && children[cur] == heavy[x]) ++cur;
        if (cur < childBegin[x + 1]) {
            const int y = children[cur++];
            keep[y] = 0;
            stk.push_back(y);
            continue;
        }

        // Phase 2: the heavy child, whose inserted data stays for x.
        if (!heavyDone[x]) {
            heavyDone[x] = 1;
            if (heavy[x]) {
                keep[heavy[x]] = 1;
                stk.push_back(heavy[x]);
                continue;
            }
        }

        // Phase 3: merge light subtrees (they follow the heavy subtree in ver).
        stk.pop_back();
        const int lightBegin = tin[x] + 1 + siz[heavy[x]];
        const int subtreeEnd = tin[x] + siz[x];
        for (int p = lightBegin; p < subtreeEnd; ++p) insertChecked(ver[p]);
        // x is the only node at its depth inside subtree(x), so it cannot form a pair here.
        ans[x] = static_cast<int>(hitDepths.size());
        ++cnt[keyId[x]];
        if (!keep[x]) {
            for (int p = tin[x]; p < subtreeEnd; ++p) --cnt[keyId[ver[p]]];
            for (const int dd : hitDepths) depthHit[dd] = 0;
            hitDepths.clear();
        }
    }

    ll total = 0;
    for (int i = 1; i <= n; ++i) {
        total += i ^ (n - ans[i]);
        total %= mod;
    }
    std::cout << total << '\n';
}

int main() {
    int testCases = 1;
    // testCases = static_cast<int>(rd());  // enable for multi-test input
    while (testCases--) solve();
    return 0;
}