字符串相乘 - LeetCode
导入依赖
主要依赖的库有:
math:用来进行幂运算。random:用来生成随机测试用例
import math import random
拆分、填充
默认输入的格式为不固定长度的字符串,如
"123456"。需要对输入的字符串拆分成长度为
的数字类型列表,如
[1,2,3,4,5,6]。并对其进行填充,找到
的指数
,满足
,如:
时,有
,
满足
。
- 使用 
0对列表进行填充后的长度满足:
,如
[1,2,3,4,5,6,0,0,0,0,0,0,0,0,0,0]。 
def to_list(num1:str, num2:str) -> tuple:
    # 拆分为 list
    a = [int(i) for i in num1]
    b = [int(i) for i in num2]
    # 反转列表,将低阶项系数放在列表前面
    a.reverse()
    b.reverse()
    max_len = max(len(a),len(b))
    # 对齐使长度相等
    l = len(a)-len(b)
    zeros = [0] * abs(l) 
    if l < 0:
        a = a + zeros
    elif l > 0:
        b = b + zeros
    # 补充前导 0,使得长度为 2^n
    fill_count = int(2**math.ceil(math.log2(max_len*2)) - max_len)
    fill = [0] * fill_count
    return a+fill,b+fill 多项式表示
对于输入的两个数 、
 ,将其处理成两个多项式:
 
 
最终的目标是对多项式  进行求解。
傅里叶变换求解
将处理后的两个列表进行快速傅里叶变换(fft),得到
个点值对的取值
得到两个新的列表并将其按元素相乘,得到待求解的多项式
的值
再进行逆离散傅里叶变换(idft),将点值表示转换为
的系数;
傅里叶变换后的结果是虚数,其实部四舍五入后取整,便是结果多项式对应项的系数,将以
为底的多项式计算求和,得到乘法的结果。
def multiply(num1: str, num2: str) -> str:
    l = len(num1)
    a,b = to_list(num1, num2)
    # 傅里叶变换
    a_fft, b_fft = fft(a), fft(b)
    t = []
    # 对应项相乘
    for i in range(len(a_fft)):
        t.append(a_fft[i] * b_fft[i]) 
    # 逆傅里叶变换
    ans = idft(t)
    sum = 0
    # 计算多项式
    for i,r in enumerate(ans):
        # 实部四舍五入取整
        sum += int(r.real+0.5) * (10 ** i)
    return str(sum) 傅里叶变换实现
傅里叶变换与逆傅里叶变换的主要区别在于:逆傅里叶变换需要对计算的结果除以  (并不是在递归中进行),并且在计算的过程中 
。
def _ft(l:list, idft = False):
    """
    基础的变换方法,通过变量控制进行dft还是idft
    :param bool idft: 控制进行傅里叶变换还是逆傅里叶变换
    """
    n = len(l)
    if n == 1:
        return l
    # dft 与 idft 分别处理 $\omega$
    o_n_e = -2j if idft else 2j
    o = 1
    o_n = math.e ** (o_n_e * math.pi / n)
    # 拆分奇偶项
    even_index = l[::2]
    odd_index = l[1::2]
    y_even = _ft(even_index, idft)
    y_odd = _ft(odd_index, idft)
    y = [0]*n
    for i in range(n//2):
        y[i] = y_even[i] + o * y_odd[i]
        y[i+n//2] = y_even[i] - o * y_odd[i]
        o *= o_n
    return y
def fft(l:list):
    """
    傅里叶变换
    """
    output = _ft(l)
    return output
def idft(l:list):
    """
    逆傅里叶变换
    """
    n = len(l)
    output = _ft(l,True)
    # 将计算的结果除以 $N$
    output = [i/n for i in output]
    return output 测试
将multiply()方法输出的结果与自带的乘法计算结果进行比较,并输出测试结果。
def test(num1:str, num2:str):
    r = int(multiply(num1,num2))
    s = int(num1)*int(num2)
    t = 30
    print(f"{'-'*t} Test {'-'*t}")
    print(f"Test case: \n\t{num1} \n\t{num2}")
    print(f"Program output: \n\t{r}")
    print(f"Expected output: \n\t{s}")
    print(f"❌ FAILED" if r != s else "✔ OK")
    return r == s 编写测试用例
# 测试用例数
test_cases = 10
# 数据长度
INT_MAX = 1e100
for i in range(test_caese):
    num1 = str(random.randint(0, INT_MAX))
    num2 = str(random.randint(0, INT_MAX))
    test(num1, num2) LeetCode AC 代码
 class Solution:
    def to_list(self, num1:str, num2:str) -> tuple:
        a = [int(i) for i in num1]
        b = [int(i) for i in num2]
        a.reverse()
        b.reverse()
        l = len(a)-len(b)
        max_len = max(len(a),len(b))
        # 对齐使长度相等
        zeros = [0] * abs(l) 
        if l < 0:
            a = a + zeros
        elif l > 0:
            b = b + zeros
        # 补充前导 0,使得长度为 2^n
        fill_count = int(2**math.ceil(math.log2(max_len*2)) - max_len)
        fill = [0] * fill_count
        return a+fill,b+fill
    def multiply(self, num1: str, num2: str) -> str:
        l = len(num1)
        a,b = self.to_list(num1, num2)
        a_fft, b_fft = self.fft(a), self.fft(b)
        t = []
        _3 = []
        for i in range(len(a_fft)):
            t.append(a_fft[i] * b_fft[i]) 
        ans = self.idft(t)
        sum = 0
        for i,r in enumerate(ans):
            sum += int(r.real+0.5) * (10 ** i)
        return str(sum)
    def _ft(self, l:list, idft = False):
        n = len(l)
        if n == 1:
            return l
        o_n_e = -2j if idft else 2j
        even_index = l[::2]
        odd_index = l[1::2]
        o = 1
        o_n = math.e ** (o_n_e * math.pi / n)
        y_even = self._ft(even_index, idft)
        y_odd = self._ft(odd_index, idft)
        y = [0]*n
        for i in range(n//2):
            y[i] = y_even[i] + o * y_odd[i]
            y[i+n//2] = y_even[i] - o * y_odd[i]
            o *= o_n
        return y
    def fft(self, l:list):
        output = self._ft(l)
        return output
    def idft(self, l:list):
        n = len(l)
        output = self._ft(l,True)
        output = [i/n for i in output]
        return output 
京公网安备 11010502036488号