使用字典树TrieTree将数列A中的每一个元素都按bit位存入,查找时从最高位的bit位与m的比特位进行比较,模拟异或的操作计算出子节点应该是哪一个,计算出的子节点索引比m的对应位大则统计,比m的对应bit位小则忽略,如果相等则递归查询下一个比特位。(参考了一个Java版本的题解)
#include <iostream> #include <vector> using namespace std; class TrieTree { public: int count; class TrieTree* next[2]; //子节点只有0/1; TrieTree() { count = 1; next[0] = NULL; next[1] = NULL; } }; //通过数列A构建字典树; TrieTree* buildTrieTree(const vector<int>& m_vec) { TrieTree* tree = new TrieTree(); for(const auto& a:m_vec) { TrieTree* current = tree; for(int i=31;i>=0;i--) { int digit = (a>>i) & 1; if(current->next[digit] == NULL) { current->next[digit] = new TrieTree(); }else{ current->next[digit]->count++; } current = current->next[digit]; } } return tree; } //查找tree中数a的异或结果; long search(TrieTree* tree, int a, int index, int m) { if(tree==NULL) return 0; TrieTree* current = tree; for(int i=index; i>=0;i--) { int a_bit = (a>>i) & 1; int m_bit = (m>>i) & 1; if(a_bit == 1 && m_bit == 1) { //子节点1:a_bit ^ 1 = 0 < m_bit; 忽略; //子节点0:a_bit ^ 0 = 1 == m_bit; 继续比较; if(current->next[0] == NULL) { return 0; } current = current->next[0]; }else if(a_bit == 0 && m_bit == 1) { //子节点1:a_bit ^ 1 = 1 == m_bit; 继续比较; if(current->next[1] == NULL) { return 0; } current = current->next[1]; //子节点0:a_bit ^ 0 = 0 < m_bit; 忽略; }else if(a_bit == 1 && m_bit == 0) { //子节点1:a_bit ^ 1 = 0 == m_bit; 继续比较; long p = search(current->next[1], a, i-1, m); //子节点0:a_bit ^ 0 = 1 > m_bit; 统计; long q = current->next[0] == NULL ? 0 : current->next[0]->count; return p+q; }else if(a_bit == 0 && m_bit == 0) { //子节点1:a_bit ^ 1 = 1 > m_bit; 统计; long q = current->next[1] == NULL ? 0 : current->next[1]->count; //子节点0:a_bit ^ 0 = 0 == m_bit; 继续比较; long p = search(current->next[0], a, i-1, m); return p+q; } } return 0; } //统计tree中数列A的异或结果; long solve(const vector<int>& m_vec, int m) { long ret = 0; TrieTree* tree = buildTrieTree(m_vec); for(const auto& a:m_vec) { //从最高位开始查找,最高位的索引为31; ret += search(tree,a,31,m); } //遍历m_vec时前后两个元素的异或都进行了统计,但异或的结果只有一个; return ret / 2; } int main() { int n,m; vector<int> m_vec; cin >> n >> m; while(n--) { int a; cin >> a; m_vec.push_back(a); } cout << solve(m_vec, m) << endl; return 0; }