J - And And And

题目链接

题目大意

给一棵树,然后算
公式的含义: 任意两个点之间的子路径异或和为0的数量和。

题解

自己做的时候比较呆,啥都看不懂。
公式的意思是先两个for枚举两个点,然后再两个for枚举两个子路径,如果子路径异或和是0 那就答案加一。
题解:先找子路径异或和是0的,然后两边节点乘一下就好。
因为是异或,可以处理一下每个点到根节点的异或。这样的话,如果两个点的值相同,那么这两个点的路径的异或和为0.
1、一个点是另一个儿子的时候,祖先节点的除了这条路上的儿子的数量乘儿子的子节点的数量。
2、两个节点跟lca拐了一下,然后就是 两个子节点的数量乘一下。
然后这样的话map存一下异或值的数量。
好难~

#include <algorithm>
#include <cstdio>
#include <iostream>
#include <vector>
#include <stack>
#include <queue>
#include <map>
#include <cmath>
#include <set>
#include <cstring>
#include <string>
#include <bitset>
#include <stdlib.h>
#include <time.h>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
typedef unsigned long long ull;
typedef set<int>::iterator sit;
#define st first
#define sd second
#define mkp make_pair
#define pb push_back
void wenjian(){
   freopen("concatenation.in","r",stdin);freopen("concatenation.out","w",stdout);}
void tempwj(){
   freopen("hash.in","r",stdin);freopen("hash.out","w",stdout);}
ll gcd(ll a,ll b){
   return b == 0 ? a : gcd(b,a % b);}
ll qpow(ll a,ll b,ll mod){
   a %= mod;ll ans = 1;while(b){
   if(b & 1)ans = ans * a;a = a * a;b >>= 1;}return ans;}
ll qpowm(ll a,ll b){
   ll ans = 1;while(b){
   if(b & 1)ans = ans * a;a = a * a;b >>= 1;}return ans;}

struct cmp{
   bool operator()(const pii & a, const pii & b){
   return a.st < b.st;}};
int lb(int x){
   return  x & -x;}
/* 结构体里写。。。 friend bool operator < (Node a,Node b) { return a.val > b.val; */
const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
const ll mod = 1000000007;
const int maxn = 1e5 + 100;
const int M = 1e5+5;
map<ll,ll> mm;
ll num[maxn];
ll a[maxn];
ll ans = 0;
int n;
vector<pll> vv[maxn];
void dfs(int x,int fa)
{
   
    num[x] = 1;
    for (int i = 0; i < vv[x].size(); i ++ )
    {
   
        int v = vv[x][i].st;
        if(v == fa)
            continue;
        dfs(v,x);
        num[x] += num[v];
    }
}
void dfs1(int x,int fa, ll dep)
{
   
    ans = (ans + mm[dep] * num[x] % mod) % mod;
    for(int i =0 ; i < vv[x].size(); i ++ )
    {
   
        int v = vv[x][i].st;
        if(v == fa)
            continue;
        mm[dep] += n - num[v];
        dfs1(v,x,dep ^ vv[x][i].sd);
        mm[dep] -= n - num[v];
    }
}
void dfs2(int x,int fa,ll dep)
{
   
    ans = (ans + mm[dep] * num[x] % mod) % mod;
    for(int i = 0; i < vv[x].size(); i ++ )
    {
   
        int v = vv[x][i].st;
        if(v == fa)
            continue;
        dfs2(v,x,dep ^ vv[x][i].sd);
    }
    mm[dep] = (mm[dep] + num[x]) % mod;
}

int main()
{
   
    scanf("%d",&n);
    for (int i = 2; i <= n; i ++ )
    {
   
        int x;
        ll v;
        scanf("%d%lld",&x,&v);
        vv[x].pb(mkp(i,v));
        vv[i].pb(mkp(x,v));
    }
    dfs(1,0);
    dfs1(1,0,0);
    mm.clear();
    dfs2(1,0,0);
    printf("%lld\n",ans);

}