零、前言

我们先用最暴力的方法进行解题,然后逐步优化成满足时间空间要求的解法。

一、传统递归法

1、核心流程

index和rest余数的定义:
举个例子:比如[1, 2, 5, 7, 9]数组,k=4;

  1. 依次从左到右遍历数组,指向1的时候,index=0, 如果要这个数,对k的余数rest=1;
  2. 然后index=1,指向2,如果要2,那么对k的余数rest = (1+2) % 4 = 3;
  3. index=2,指向5,如果要5,那么对k的余数rest = (3+5) % 4 = 0;
  4. index=3,指向7,如果要7,那么对k的余数rest = (0+7) % 4 = 3;
  5. index=4, 指向9,如果不要9,那么对k的余数rest = (3+0) % 4 = 3;
    ...

依次从左到右遍历数组,index表示当前数字的下标,对于这个数字,可以选择要,还是不要这个数字;
rest表示要或者不要这个数字后,已经决策过的数字中,这些数字对k的余数是多少。

依次递归,那么能取到的最大数,肯定是“要”或者“不要”这个数的两种情况下能取到的最大数中最大的那种情况,所以有以下伪代码:

def process_01(arr, N, K, index, rest):
    return max(process_01(arr, N, K, index+1, (rest + arr[index]) % K) + arr[index], # 要这个数的话,就累加上当前数arr[index]
               process_01(arr, N, K, index+1, rest))

2、增加边际条件,防止越界

def process_01(arr, N, K, index, rest):
    # 如果index到底了,余数rest=0,表示这条路行得通,rest!=0表示行不通,返回-1
    if index == N:
        if rest == 0:
            return 0
        if rest != 0:
            return -1

    return max(process_01(arr, N, K, index+1, (rest + arr[index]) % K) + arr[index],
               process_01(arr, N, K, index+1, rest))

3、增加状态判断

因为返回的-1是一个状态值,不能作为结果直接相加,所以要做一定的转换

def process_01(arr, N, K, index, rest):

    if index == N:
        if rest == 0:
            return 0
        if rest != 0:
            return -1

    # 单独把两种情况拎出来讨论
    p1 = process(arr, N, K, index+1, (rest+arr[index]) % K) # 要
    p2 = process(arr, N, K, index+1, rest) # 不要

    # 都是-1,两条路都走不通,也就是当前index无论取不取都没路走了,直接返回-1
    if p1 == -1 and p2 == -1:
        return -1
    else:
        # 只有p2走得通,就返回p2的值
        if p1 == -1 and p2 != -1:
            return p2
        # 只有p1走得通,因为p1是要取当前index的数的,所以加上arr[index]后返回
        elif p1 != -1 and p2 == -1:
            return p1 + arr[index]
        # 如果都走得通,就看哪个大就返回哪个
        elif p1 != -1 and p2 != -1:
            return max(p1 + arr[index], p2)
# 最后根据题意,修正一下返回值
def recur_answer_01(arr, N, K, index, rest):
    result = process_01(arr, N, K, index, rest)
    if result == 0:
        return -1
    else:
        return result

print(recur_answer_01(arr, N, K, 0, 0))

二、缓存递归的中间结果来加速

上面的结果因为存在很多重复计算,内存肯定不够;
因为有些结果,在某些步骤中已经算过了,为了加快速度,我们把已经算过的结果缓存在一个列表中,减少中间的一些重复计算;

因为process_01的输入参数arr, N, K, index, rest中,只有index和rest是可变参数,我们建立一个二维数组存储中间值。——表示当可变参数是index1和rest时,计算出来的最优值是多少。

# 建立dp记录中间值,-2表示当前位置还没算过
dp = [[-2 for _ in range(N+1)] for _ in range(K+1)] # K+1 * N+1: dp[rest][index]
# 加入记忆搜索dp的递归
# 将每个结果都记录在dp中,如果已经记录过了,就直接返回值,不再计算这个位置的数
def process(arr, N, K, index, rest, dp):
    if dp[rest][index] != -2:
        return dp[rest][index]

    if index == N:
        if rest == 0:
            dp[rest][index] = 0
            return dp[rest][index]
        if rest != 0:
            dp[rest][index] = -1
            return dp[rest][index]

    p1 = process(arr, N, K, index+1, (rest+arr[index]) % K, dp) # 要
    p2 = process(arr, N, K, index+1, rest, dp) # 不要

    # 将所有return换成dp[rest][index]来记录中间值,最后再一次性返回(当然你在记录后就返回也不是不行)
    if p1 == -1 and p2 == -1:
        dp[rest][index] = -1
    else:
        if p1 == -1 and p2 != -1:
            dp[rest][index] = p2
        elif p1 != -1 and p2 == -1:
            dp[rest][index] = p1 + arr[index]
        elif p1 != -1 and p2 != -1:
            dp[rest][index] = max(p1 + arr[index], p2)

    return dp[rest][index]

def recur_answer(arr, N, K, index, rest):
    dp = [[-2 for _ in range(N+1)] for _ in range(K+1)] # dp[rest][index]
    result = process(arr, N, K, index, rest, dp)
    if result == 0:
        return -1
    else:
        return result
print(recur_answer(arr, N, K, 0, 0)) # 实际上我们返回的是dp[0][0]的值

三、严格表结构找规律——动态规划

从(二)可以看到,只要我们能找到规律,把dp填满,再返回dp[0][0]的值就可以了,可以不理会题意了。——实际上不找规律,仍然有递归的话,在运行的时候栈会溢出
下面我们来找规律

1、base情况

注意到下面代码,表示的是第N列,除了第0行,其他都是-1,我们可以直接填上

def process(arr, N, K, index, rest, dp):
    ...
    if index == N:
        if rest == 0:
            dp[rest][index] = 0
            return dp[rest][index]
        if rest != 0:
            dp[rest][index] = -1
            return dp[rest][index]
    ...

图片说明

2、一般规律

注意到下面代码,当前dp[rest][index]只依赖于dp[(rest+arr[index]) % K][index+1]和dp[rest][index+1];
跟(二)的判断逻辑一样,p1和p2取较大的数就可以了

def process(arr, N, K, index, rest, dp):
    ... 
    p1 = process(arr, N, K, index+1, (rest+arr[index]) % K, dp) # 要
    p2 = process(arr, N, K, index+1, rest, dp) # 不要
    ...

图片说明

3、用循环把dp填满

这样,只要一个循环,我们把dp填满就可以了,下面我们按列取数进行填充

def dy_proaram(arr, N, K):
    dp = [[-2 for _ in range(N+1)] for _ in range(K+1)] # dp[rest][index]

    for index in range(N, -1, -1): # index
        for rest in range(K+1): # rest
            # 最后一列要么0,要么-1
            if index == N:
                if rest == 0:
                    dp[rest][index] = 0
                if rest != 0:
                    dp[rest][index] = -1
            else:
                # 把(二)中的p1和p2取数换一下,代码不用更改
                p1 = dp[(rest+arr[index]) % K][index+1]
                p2 = dp[rest][index+1]

                if p1 == -1 and p2 == -1:
                    dp[rest][index] = -1
                else:
                    if p1 == -1 and p2 != -1:
                        dp[rest][index] = p2
                    elif p1 != -1 and p2 == -1:
                        dp[rest][index] = p1 + arr[index]
                    elif p1 != -1 and p2 != -1:
                        dp[rest][index] = max(p1 + arr[index], p2)

    return dp[0][0]
def recur_dy_prog(arr, N, K):
    result = dy_proaram(arr, N, K)
    if result == 0:
        return -1
    else:
        return result
print(recur_dy_prog(arr, N, K))