字符串相乘 - LeetCode

导入依赖

主要依赖的库有:

  • math:用来进行幂运算。
  • random:用来生成随机测试用例
import math
import random

拆分、填充

  • 默认输入的格式为不固定长度的字符串,如 "123456"

  • 需要对输入的字符串拆分成长度为 的数字类型列表,如 [1,2,3,4,5,6]

  • 并对其进行填充,找到 的指数 ,满足 ,如:

    时,有

    满足

  • 使用 0 对列表进行填充后的长度满足: ,如 [1,2,3,4,5,6,0,0,0,0,0,0,0,0,0,0]
def to_list(num1:str, num2:str) -> tuple:
    # 拆分为 list
    a = [int(i) for i in num1]
    b = [int(i) for i in num2]

    # 反转列表,将低阶项系数放在列表前面
    a.reverse()
    b.reverse()
    max_len = max(len(a),len(b))

    # 对齐使长度相等
    l = len(a)-len(b)
    zeros = [0] * abs(l) 
    if l < 0:
        a = a + zeros
    elif l > 0:
        b = b + zeros

    # 补充前导 0,使得长度为 2^n
    fill_count = int(2**math.ceil(math.log2(max_len*2)) - max_len)
    fill = [0] * fill_count
    return a+fill,b+fill

多项式表示

对于输入的两个数 ,将其处理成两个多项式:

最终的目标是对多项式 进行求解。

傅里叶变换求解

  • 将处理后的两个列表进行快速傅里叶变换(fft),得到 个点值对的取值

  • 得到两个新的列表并将其按元素相乘,得到待求解的多项式 的值

  • 再进行逆离散傅里叶变换(idft),将点值表示转换为 的系数;

  • 傅里叶变换后的结果是虚数,其实部四舍五入后取整,便是结果多项式对应项的系数,将以 为底的多项式计算求和,得到乘法的结果。

def multiply(num1: str, num2: str) -> str:
    l = len(num1)
    a,b = to_list(num1, num2)
    # 傅里叶变换
    a_fft, b_fft = fft(a), fft(b)
    t = []
    # 对应项相乘
    for i in range(len(a_fft)):
        t.append(a_fft[i] * b_fft[i]) 
    # 逆傅里叶变换
    ans = idft(t)
    sum = 0
    # 计算多项式
    for i,r in enumerate(ans):
        # 实部四舍五入取整
        sum += int(r.real+0.5) * (10 ** i)
    return str(sum)

傅里叶变换实现

傅里叶变换与逆傅里叶变换的主要区别在于:逆傅里叶变换需要对计算的结果除以 (并不是在递归中进行),并且在计算的过程中

def _ft(l:list, idft = False):
    """
    基础的变换方法,通过变量控制进行dft还是idft

    :param bool idft: 控制进行傅里叶变换还是逆傅里叶变换
    """

    n = len(l)
    if n == 1:
        return l

    # dft 与 idft 分别处理 $\omega$
    o_n_e = -2j if idft else 2j
    o = 1
    o_n = math.e ** (o_n_e * math.pi / n)

    # 拆分奇偶项
    even_index = l[::2]
    odd_index = l[1::2]

    y_even = _ft(even_index, idft)
    y_odd = _ft(odd_index, idft)

    y = [0]*n
    for i in range(n//2):
        y[i] = y_even[i] + o * y_odd[i]
        y[i+n//2] = y_even[i] - o * y_odd[i]
        o *= o_n
    return y

def fft(l:list):
    """
    傅里叶变换
    """
    output = _ft(l)
    return output

def idft(l:list):
    """
    逆傅里叶变换
    """
    n = len(l)
    output = _ft(l,True)
    # 将计算的结果除以 $N$
    output = [i/n for i in output]
    return output

测试

multiply()方法输出的结果与自带的乘法计算结果进行比较,并输出测试结果。

def test(num1:str, num2:str):
    r = int(multiply(num1,num2))
    s = int(num1)*int(num2)
    t = 30
    print(f"{'-'*t} Test {'-'*t}")
    print(f"Test case: \n\t{num1} \n\t{num2}")
    print(f"Program output: \n\t{r}")
    print(f"Expected output: \n\t{s}")
    print(f"❌ FAILED" if r != s else "✔ OK")
    return r == s

编写测试用例

# 测试用例数
test_cases = 10
# 数据长度
INT_MAX = 1e100
for i in range(test_caese):
    num1 = str(random.randint(0, INT_MAX))
    num2 = str(random.randint(0, INT_MAX))
    test(num1, num2)

LeetCode AC 代码

class Solution:
    def to_list(self, num1:str, num2:str) -> tuple:
        a = [int(i) for i in num1]
        b = [int(i) for i in num2]
        a.reverse()
        b.reverse()
        l = len(a)-len(b)
        max_len = max(len(a),len(b))
        # 对齐使长度相等
        zeros = [0] * abs(l) 
        if l < 0:
            a = a + zeros
        elif l > 0:
            b = b + zeros
        # 补充前导 0,使得长度为 2^n
        fill_count = int(2**math.ceil(math.log2(max_len*2)) - max_len)
        fill = [0] * fill_count
        return a+fill,b+fill

    def multiply(self, num1: str, num2: str) -> str:
        l = len(num1)
        a,b = self.to_list(num1, num2)
        a_fft, b_fft = self.fft(a), self.fft(b)


        t = []
        _3 = []
        for i in range(len(a_fft)):
            t.append(a_fft[i] * b_fft[i]) 
        ans = self.idft(t)

        sum = 0
        for i,r in enumerate(ans):
            sum += int(r.real+0.5) * (10 ** i)
        return str(sum)



    def _ft(self, l:list, idft = False):
        n = len(l)
        if n == 1:
            return l
        o_n_e = -2j if idft else 2j
        even_index = l[::2]
        odd_index = l[1::2]
        o = 1
        o_n = math.e ** (o_n_e * math.pi / n)

        y_even = self._ft(even_index, idft)
        y_odd = self._ft(odd_index, idft)

        y = [0]*n
        for i in range(n//2):
            y[i] = y_even[i] + o * y_odd[i]
            y[i+n//2] = y_even[i] - o * y_odd[i]
            o *= o_n
        return y

    def fft(self, l:list):
        output = self._ft(l)
        return output

    def idft(self, l:list):
        n = len(l)
        output = self._ft(l,True)
        output = [i/n for i in output]
        return output