J - And And And
题目大意
给一棵树,然后算
公式的含义: 任意两个点之间的子路径异或和为0的数量和。
题解
自己做的时候比较呆,啥都看不懂。
公式的意思是先两个for枚举两个点,然后再两个for枚举两个子路径,如果子路径异或和是0 那就答案加一。
题解:先找子路径异或和是0的,然后两边节点乘一下就好。
因为是异或,可以处理一下每个点到根节点的异或。这样的话,如果两个点的值相同,那么这两个点的路径的异或和为0.
1、一个点是另一个儿子的时候,祖先节点的除了这条路上的儿子的数量乘儿子的子节点的数量。
2、两个节点跟lca拐了一下,然后就是 两个子节点的数量乘一下。
然后这样的话map存一下异或值的数量。
好难~
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <vector>
#include <stack>
#include <queue>
#include <map>
#include <cmath>
#include <set>
#include <cstring>
#include <string>
#include <bitset>
#include <stdlib.h>
#include <time.h>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
typedef unsigned long long ull;
typedef set<int>::iterator sit;
#define st first
#define sd second
#define mkp make_pair
#define pb push_back
void wenjian(){
freopen("concatenation.in","r",stdin);freopen("concatenation.out","w",stdout);}
void tempwj(){
freopen("hash.in","r",stdin);freopen("hash.out","w",stdout);}
ll gcd(ll a,ll b){
return b == 0 ? a : gcd(b,a % b);}
ll qpow(ll a,ll b,ll mod){
a %= mod;ll ans = 1;while(b){
if(b & 1)ans = ans * a;a = a * a;b >>= 1;}return ans;}
ll qpowm(ll a,ll b){
ll ans = 1;while(b){
if(b & 1)ans = ans * a;a = a * a;b >>= 1;}return ans;}
struct cmp{
bool operator()(const pii & a, const pii & b){
return a.st < b.st;}};
int lb(int x){
return x & -x;}
/* 结构体里写。。。 friend bool operator < (Node a,Node b) { return a.val > b.val; */
const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
const ll mod = 1000000007;
const int maxn = 1e5 + 100;
const int M = 1e5+5;
map<ll,ll> mm;
ll num[maxn];
ll a[maxn];
ll ans = 0;
int n;
vector<pll> vv[maxn];
void dfs(int x,int fa)
{
num[x] = 1;
for (int i = 0; i < vv[x].size(); i ++ )
{
int v = vv[x][i].st;
if(v == fa)
continue;
dfs(v,x);
num[x] += num[v];
}
}
void dfs1(int x,int fa, ll dep)
{
ans = (ans + mm[dep] * num[x] % mod) % mod;
for(int i =0 ; i < vv[x].size(); i ++ )
{
int v = vv[x][i].st;
if(v == fa)
continue;
mm[dep] += n - num[v];
dfs1(v,x,dep ^ vv[x][i].sd);
mm[dep] -= n - num[v];
}
}
void dfs2(int x,int fa,ll dep)
{
ans = (ans + mm[dep] * num[x] % mod) % mod;
for(int i = 0; i < vv[x].size(); i ++ )
{
int v = vv[x][i].st;
if(v == fa)
continue;
dfs2(v,x,dep ^ vv[x][i].sd);
}
mm[dep] = (mm[dep] + num[x]) % mod;
}
int main()
{
scanf("%d",&n);
for (int i = 2; i <= n; i ++ )
{
int x;
ll v;
scanf("%d%lld",&x,&v);
vv[x].pb(mkp(i,v));
vv[i].pb(mkp(x,v));
}
dfs(1,0);
dfs1(1,0,0);
mm.clear();
dfs2(1,0,0);
printf("%lld\n",ans);
}