字符串相乘 - LeetCode
导入依赖
主要依赖的库有:
math
:用来进行幂运算。random
:用来生成随机测试用例
import math import random
拆分、填充
默认输入的格式为不固定长度的字符串,如
"123456"
。需要对输入的字符串拆分成长度为
的数字类型列表,如
[1,2,3,4,5,6]
。并对其进行填充,找到
的指数
,满足
,如:
时,有
,
满足
。
- 使用
0
对列表进行填充后的长度满足:
,如
[1,2,3,4,5,6,0,0,0,0,0,0,0,0,0,0]
。
def to_list(num1:str, num2:str) -> tuple: # 拆分为 list a = [int(i) for i in num1] b = [int(i) for i in num2] # 反转列表,将低阶项系数放在列表前面 a.reverse() b.reverse() max_len = max(len(a),len(b)) # 对齐使长度相等 l = len(a)-len(b) zeros = [0] * abs(l) if l < 0: a = a + zeros elif l > 0: b = b + zeros # 补充前导 0,使得长度为 2^n fill_count = int(2**math.ceil(math.log2(max_len*2)) - max_len) fill = [0] * fill_count return a+fill,b+fill
多项式表示
对于输入的两个数 、
,将其处理成两个多项式:
最终的目标是对多项式 进行求解。
傅里叶变换求解
将处理后的两个列表进行快速傅里叶变换(fft),得到
个点值对的取值
得到两个新的列表并将其按元素相乘,得到待求解的多项式
的值
再进行逆离散傅里叶变换(idft),将点值表示转换为
的系数;
傅里叶变换后的结果是虚数,其实部四舍五入后取整,便是结果多项式对应项的系数,将以
为底的多项式计算求和,得到乘法的结果。
def multiply(num1: str, num2: str) -> str: l = len(num1) a,b = to_list(num1, num2) # 傅里叶变换 a_fft, b_fft = fft(a), fft(b) t = [] # 对应项相乘 for i in range(len(a_fft)): t.append(a_fft[i] * b_fft[i]) # 逆傅里叶变换 ans = idft(t) sum = 0 # 计算多项式 for i,r in enumerate(ans): # 实部四舍五入取整 sum += int(r.real+0.5) * (10 ** i) return str(sum)
傅里叶变换实现
傅里叶变换与逆傅里叶变换的主要区别在于:逆傅里叶变换需要对计算的结果除以 (并不是在递归中进行),并且在计算的过程中
。
def _ft(l:list, idft = False): """ 基础的变换方法,通过变量控制进行dft还是idft :param bool idft: 控制进行傅里叶变换还是逆傅里叶变换 """ n = len(l) if n == 1: return l # dft 与 idft 分别处理 $\omega$ o_n_e = -2j if idft else 2j o = 1 o_n = math.e ** (o_n_e * math.pi / n) # 拆分奇偶项 even_index = l[::2] odd_index = l[1::2] y_even = _ft(even_index, idft) y_odd = _ft(odd_index, idft) y = [0]*n for i in range(n//2): y[i] = y_even[i] + o * y_odd[i] y[i+n//2] = y_even[i] - o * y_odd[i] o *= o_n return y def fft(l:list): """ 傅里叶变换 """ output = _ft(l) return output def idft(l:list): """ 逆傅里叶变换 """ n = len(l) output = _ft(l,True) # 将计算的结果除以 $N$ output = [i/n for i in output] return output
测试
将multiply()
方法输出的结果与自带的乘法计算结果进行比较,并输出测试结果。
def test(num1:str, num2:str): r = int(multiply(num1,num2)) s = int(num1)*int(num2) t = 30 print(f"{'-'*t} Test {'-'*t}") print(f"Test case: \n\t{num1} \n\t{num2}") print(f"Program output: \n\t{r}") print(f"Expected output: \n\t{s}") print(f"❌ FAILED" if r != s else "✔ OK") return r == s
编写测试用例
# 测试用例数 test_cases = 10 # 数据长度 INT_MAX = 1e100 for i in range(test_caese): num1 = str(random.randint(0, INT_MAX)) num2 = str(random.randint(0, INT_MAX)) test(num1, num2)
LeetCode
AC 代码
class Solution: def to_list(self, num1:str, num2:str) -> tuple: a = [int(i) for i in num1] b = [int(i) for i in num2] a.reverse() b.reverse() l = len(a)-len(b) max_len = max(len(a),len(b)) # 对齐使长度相等 zeros = [0] * abs(l) if l < 0: a = a + zeros elif l > 0: b = b + zeros # 补充前导 0,使得长度为 2^n fill_count = int(2**math.ceil(math.log2(max_len*2)) - max_len) fill = [0] * fill_count return a+fill,b+fill def multiply(self, num1: str, num2: str) -> str: l = len(num1) a,b = self.to_list(num1, num2) a_fft, b_fft = self.fft(a), self.fft(b) t = [] _3 = [] for i in range(len(a_fft)): t.append(a_fft[i] * b_fft[i]) ans = self.idft(t) sum = 0 for i,r in enumerate(ans): sum += int(r.real+0.5) * (10 ** i) return str(sum) def _ft(self, l:list, idft = False): n = len(l) if n == 1: return l o_n_e = -2j if idft else 2j even_index = l[::2] odd_index = l[1::2] o = 1 o_n = math.e ** (o_n_e * math.pi / n) y_even = self._ft(even_index, idft) y_odd = self._ft(odd_index, idft) y = [0]*n for i in range(n//2): y[i] = y_even[i] + o * y_odd[i] y[i+n//2] = y_even[i] - o * y_odd[i] o *= o_n return y def fft(self, l:list): output = self._ft(l) return output def idft(self, l:list): n = len(l) output = self._ft(l,True) output = [i/n for i in output] return output