使用字典树TrieTree将数列A中的每一个元素都按bit位存入,查找时从最高位的bit位与m的比特位进行比较,模拟异或的操作计算出子节点应该是哪一个,计算出的子节点索引比m的对应位大则统计,比m的对应bit位小则忽略,如果相等则递归查询下一个比特位。(参考了一个Java版本的题解)

#include <iostream>
#include <vector>
using namespace std;

class TrieTree
{
public:
    int count;
    class TrieTree* next[2];                //子节点只有0/1;
    TrieTree()
    {
        count = 1;
        next[0] = NULL;
        next[1] = NULL;
    }
};

//通过数列A构建字典树;
TrieTree* buildTrieTree(const vector<int>& m_vec)
{
    TrieTree* tree = new TrieTree();
    for(const auto& a:m_vec)
    {
        TrieTree* current = tree;
        for(int i=31;i>=0;i--)
        {
            int digit = (a>>i) & 1;
            if(current->next[digit] == NULL)
            {
                current->next[digit] = new TrieTree();
            }else{
                current->next[digit]->count++;
            }
            current = current->next[digit];
        }
    }
    return tree;
}

//查找tree中数a的异或结果;
long search(TrieTree* tree, int a, int index, int m)
{
    if(tree==NULL)
        return 0;

    TrieTree* current = tree;
    for(int i=index; i>=0;i--)
    {
        int a_bit = (a>>i) & 1;
        int m_bit = (m>>i) & 1;
        if(a_bit == 1 && m_bit == 1)
        {
            //子节点1:a_bit ^ 1 = 0 < m_bit; 忽略;
            //子节点0:a_bit ^ 0 = 1 == m_bit; 继续比较;
            if(current->next[0] == NULL)
            {
                return 0;
            }
            current = current->next[0];
        }else if(a_bit == 0 && m_bit == 1)
        {
            //子节点1:a_bit ^ 1 = 1 == m_bit; 继续比较;
            if(current->next[1] == NULL)
            {
                return 0;
            }
            current = current->next[1];
            //子节点0:a_bit ^ 0 = 0 < m_bit; 忽略;
        }else if(a_bit == 1 && m_bit == 0)
        {
            //子节点1:a_bit ^ 1 = 0 == m_bit; 继续比较;
            long p = search(current->next[1], a, i-1, m);
            //子节点0:a_bit ^ 0 = 1 > m_bit; 统计;
            long q = current->next[0] == NULL ? 0 : current->next[0]->count;
            return p+q;
        }else if(a_bit == 0 && m_bit == 0)
        {
            //子节点1:a_bit ^ 1 = 1 > m_bit; 统计;
            long q = current->next[1] == NULL ? 0 : current->next[1]->count;
            //子节点0:a_bit ^ 0 = 0 == m_bit; 继续比较;
            long p = search(current->next[0], a, i-1, m);
            return p+q;
        }
    }
    return 0;
}

//统计tree中数列A的异或结果;
long solve(const vector<int>& m_vec, int m)
{
    long ret = 0;
    TrieTree* tree = buildTrieTree(m_vec);
    for(const auto& a:m_vec)
    {
        //从最高位开始查找,最高位的索引为31;
        ret += search(tree,a,31,m);
    }

    //遍历m_vec时前后两个元素的异或都进行了统计,但异或的结果只有一个;
    return ret / 2;
}

int main()
{
    int n,m;
    vector<int> m_vec;
    cin >> n >> m;
    while(n--)
    {
        int a;
        cin >> a;
        m_vec.push_back(a);
    }

    cout << solve(m_vec, m) << endl;

    return 0;
}