使用字典树TrieTree将数列A中的每一个元素都按bit位存入,查找时从最高位的bit位与m的比特位进行比较,模拟异或的操作计算出子节点应该是哪一个,计算出的子节点索引比m的对应位大则统计,比m的对应bit位小则忽略,如果相等则递归查询下一个比特位。(参考了一个Java版本的题解)
#include <iostream>
#include <vector>
using namespace std;
class TrieTree
{
public:
int count;
class TrieTree* next[2]; //子节点只有0/1;
TrieTree()
{
count = 1;
next[0] = NULL;
next[1] = NULL;
}
};
//通过数列A构建字典树;
TrieTree* buildTrieTree(const vector<int>& m_vec)
{
TrieTree* tree = new TrieTree();
for(const auto& a:m_vec)
{
TrieTree* current = tree;
for(int i=31;i>=0;i--)
{
int digit = (a>>i) & 1;
if(current->next[digit] == NULL)
{
current->next[digit] = new TrieTree();
}else{
current->next[digit]->count++;
}
current = current->next[digit];
}
}
return tree;
}
//查找tree中数a的异或结果;
long search(TrieTree* tree, int a, int index, int m)
{
if(tree==NULL)
return 0;
TrieTree* current = tree;
for(int i=index; i>=0;i--)
{
int a_bit = (a>>i) & 1;
int m_bit = (m>>i) & 1;
if(a_bit == 1 && m_bit == 1)
{
//子节点1:a_bit ^ 1 = 0 < m_bit; 忽略;
//子节点0:a_bit ^ 0 = 1 == m_bit; 继续比较;
if(current->next[0] == NULL)
{
return 0;
}
current = current->next[0];
}else if(a_bit == 0 && m_bit == 1)
{
//子节点1:a_bit ^ 1 = 1 == m_bit; 继续比较;
if(current->next[1] == NULL)
{
return 0;
}
current = current->next[1];
//子节点0:a_bit ^ 0 = 0 < m_bit; 忽略;
}else if(a_bit == 1 && m_bit == 0)
{
//子节点1:a_bit ^ 1 = 0 == m_bit; 继续比较;
long p = search(current->next[1], a, i-1, m);
//子节点0:a_bit ^ 0 = 1 > m_bit; 统计;
long q = current->next[0] == NULL ? 0 : current->next[0]->count;
return p+q;
}else if(a_bit == 0 && m_bit == 0)
{
//子节点1:a_bit ^ 1 = 1 > m_bit; 统计;
long q = current->next[1] == NULL ? 0 : current->next[1]->count;
//子节点0:a_bit ^ 0 = 0 == m_bit; 继续比较;
long p = search(current->next[0], a, i-1, m);
return p+q;
}
}
return 0;
}
//统计tree中数列A的异或结果;
long solve(const vector<int>& m_vec, int m)
{
long ret = 0;
TrieTree* tree = buildTrieTree(m_vec);
for(const auto& a:m_vec)
{
//从最高位开始查找,最高位的索引为31;
ret += search(tree,a,31,m);
}
//遍历m_vec时前后两个元素的异或都进行了统计,但异或的结果只有一个;
return ret / 2;
}
int main()
{
int n,m;
vector<int> m_vec;
cin >> n >> m;
while(n--)
{
int a;
cin >> a;
m_vec.push_back(a);
}
cout << solve(m_vec, m) << endl;
return 0;
}
京公网安备 11010502036488号