O(n)伤不起:模拟二进制运算,绝对不涉及乘法,复杂度为O(logN)

2进制乘法原理与10进制类似,但是2进制更加简单,因为2进制非0即1,所以乘数m的一个二进制位与被乘数n相乘的结果要么是0要么是n本身,只要在实际计算过程中根据m的某个进制位所在的位置对n进行移位就可以了。

当然,这里讨论的都是正数乘法,负数乘法涉及补码,可以转为正数后再运算,不过本题不涉及负数;

所以整个过程只涉及位操作(位与、位移)和加法运算。

图片说明

  1. 在编写代码时,首先需要取得乘数m的某一个进制位:
    假设有变量 bitMask = 1,那么要取得 m 的第 k 位(从低位开始)二进制位的表达式就是: m & (bitMask << k);
    当然还有其他写法,这里使用的是:(m >> k) & 1,即将该位位移到最低位,然后和1相与,屏蔽掉高位。
  2. 其次需要根据该进制位的值对结果进行累加,如果值为0,则加0,如果值为1,则加上 (n << k):
    因为题目不允许使用条件判断,所以这里还是通过位与运算来实现:num & 0x00000000 == 0,num & 0xFFFFFFFF == num
    然后使用映射的思想,构建一个数组 mask,把0、1分别映射为0x00000000、0xFFFFFFFF
    所以累加的表达式就为:result += (n << k) & mask[(m >> k) & 1];

改进:
一开始并没有想到可以使用“短路求值原理”来做为递归的结束条件,所以傻傻地选择了顺序写上16次相似的语句来替换循环,后来看到了用户Bine的代码,才改写成了最终的递归形式,这时移位操作就可以写在参数里了,而不用写在表达式里,详细代码如下:

复杂度分析:
虽然用到了递归,但是递归的执行次数最多为一个数二进制形式的长度,显然,整数n的二进制长度为log(n)

public class Solution {
    int[] mask = {0x00000000, 0xFFFFFFFF};

    public int Sum_Solution(int n) {
        return positiveProduction(n+1, n) >> 1;
    }

    // 注意,本函数仅适用于正数的乘法,负数需额外处理
    int positiveProduction(int m, int n) {
        int result = 0;
        boolean isStop = (m != 0) && 
               (result = (n & mask[m & 1]) + positiveProduction(m >> 1, n << 1)) != 0;
        return result;
    }
}

下面是最初的版本,看起来有点傻:

public class Solution {
    public int Sum_Solution(int n) {
        int m = n + 1;
        int[] mask = {0x00000000, 0xFFFFFFFF};
        int answer = 0;
        answer += (n << 0) & mask[(m >> 0) & 1];
        answer += (n << 1) & mask[(m >> 1) & 1];
        answer += (n << 2) & mask[(m >> 2) & 1];
        answer += (n << 3) & mask[(m >> 3) & 1];
        answer += (n << 4) & mask[(m >> 4) & 1];
        answer += (n << 5) & mask[(m >> 5) & 1];
        answer += (n << 6) & mask[(m >> 6) & 1];
        answer += (n << 7) & mask[(m >> 7) & 1];
        answer += (n << 8) & mask[(m >> 8) & 1];
        answer += (n << 9) & mask[(m >> 9) & 1];
        answer += (n << 10) & mask[(m >> 10) & 1];
        answer += (n << 11) & mask[(m >> 11) & 1];
        answer += (n << 12) & mask[(m >> 12) & 1];
        answer += (n << 13) & mask[(m >> 13) & 1];
        answer += (n << 14) & mask[(m >> 14) & 1];
        answer += (n << 15) & mask[(m >> 15) & 1];
        answer += (n << 16) & mask[(m >> 16) & 1];
        return answer >> 1;
    }
}