一、前言
按大佬TOJOLINE的思路解题的,按输入index、本轮选择的数字num、剩余可递减次数rest、前index个数字的和total进行解题:
题解 | #美丽序列#_牛客博客 https://blog.nowcoder.net/n/6b0ed9b7374b4548ae814b8c499e6abc
最开始自己的递归思路是输入index、前index个数的和total、前一个数num1、前两个数num2,时间和空间很容易超.
另外值得指出的是,用pypy3会比python3快
二、暴力递归
def process(arr, N, K, index, num, rest, total):
# 如果能递归到index==N,那么算一次成功的取值(因为中途如果有不能走下去的,会直接返回res=0)
if index == N:
return 1
res = 0 # 用res记最终结果,并返回
# 基本思路是:process(arr, N, K, index, num, rest, total) = process(arr, N, K, index+1, num(从0-40), rest, total), rest和total跟着num变化
# 下面代码是细化rest和total跟着num变化做计算
if arr[index] == -1:
if index == 0:
# 如果第0个数是-1,num可以取0-40
for n in range(K+1):
res += process(arr, N, K, index+1, n, 2, total+n) # rest必是2
else:
# 如果是第1-N个数-1的话,num可以取【不超过avg = int(total / index)且不会超过递减次数的数】
for n in range(int(total / index) + 1):
rest_temp = rest
# 根据本轮即将取值n和上轮取值num的大小比较,决定rest_temp大小,只有rest_temp是1或2才能算一次成功的结果
if n < num:
rest_temp -= 1
else:
rest_temp = 2
if rest_temp != 0:
res += process(arr, N, K, index+1, n, rest_temp, total+n)
elif arr[index] != -1:
# 相比上面的arr[index] == -1,因为可以直接取arr[index],所以都少了for循环,其他代码不变
if index == 0:
res += process(arr, N, K, index+1, arr[index], 2, total+arr[index])
else:
rest_temp = rest
if arr[index] < num:
rest_temp -= 1
else:
rest_temp = 2
if rest_temp != 0:
if arr[index] <= total / index:
res += process(arr, N, K, index+1, arr[index], rest_temp, total+arr[index])
return res
# 这段代码通用的,后续代码将不再重复这段
def ans(arr):
arr = list(map(int, arr.split()))
N = len(arr)
K = 40 # arr[index]的取值范围
res = process(arr, N, K, 0, -2, 2, 0)
return res
n = int(input())
arr = input()
res = ans(arr)
print(res % int(1e9+7))
三、加入dp表加快速度
-
- dp表可以记录已经计算过的数,使到对一些情况不用重复计算,所以可以加速。
-
- dp的维度定义是这样的:dp[index][num][rest][total] 表示process(arr, N, K, index, num, rest, total)的值。
-
- 所以把(二)中的所有process(arr, N, K, index, num, rest, total, dp)先用dp[index][num][rest][total]记录,再返回
def process(arr, N, K, index, num, rest, total, dp):
if index == N:
return 1
# 注意这里加入了dp表,对于已经算过的维度组合,直接返回值,所以会加快
if dp[index][num][rest][total] != -2:
return dp[index][num][rest][total]
res = 0
if arr[index] == -1:
if index == 0:
for n in range(K+1):
res += process(arr, N, K, index+1, n, 2, total+n, dp)
dp[index][num][rest][total] = res
else:
for n in range(int(total / index) + 1):
rest_temp = rest
if n < num:
rest_temp -= 1
else:
rest_temp = 2
if rest_temp != 0:
res += process(arr, N, K, index+1, n, rest_temp, total+n, dp)
dp[index][num][rest][total] = res
elif arr[index] != -1:
if index == 0:
res += process(arr, N, K, index+1, arr[index], 2, total+arr[index], dp)
dp[index][num][rest][total] = res
else:
rest_temp = rest
if arr[index] < num:
rest_temp -= 1
else:
rest_temp = 2
if rest_temp != 0:
if arr[index] <= total / index:
res += process(arr, N, K, index+1, arr[index], rest_temp, total+arr[index], dp)
dp[index][num][rest][total] = res
return dp[index][num][rest][total]
def ans(arr):
arr = list(map(int, arr.split()))
N = len(arr)
K = 40 # arr[index]的取值范围
dp = [[[[-2 for _ in range(N * K + 1)] for _ in range(3)] for _ in range(K+1)] for _ in range(N+1)]
res = process(arr, N, K, 0, -2, 2, 0, dp)
return res
四、严格表结构——动态规划
去掉递归,直接用for循环填满dp表,最后取dp[0][0][2][0]就是结果。
这时候已经跟题意解耦,不需要再理会题意,只要找规律填表就可以。
4.1 直接去掉递归
- 先把核心代码复制
- 加入4层for循环,把上面(1)复制的代码贴入;
- 再直接用dp[index][num][rest][total]换掉所有process(arr, N, K, index, num, rest, total)就可以了;
- 再做一些小的bug修补。
def process(arr, N, K, dp):
# 4层for循环
for index in range(N, -1, -1):
for num in range(K+1):
for rest in range(1, 3):
for total in range(index*K+1):
# 跟递归一样,直接赋值,下面其他代码也是这样,可以跟递归的代码对比着看
if index == N:
dp[index][num][rest][total] = 1
continue
res = 0
if arr[index] == -1:
if index == 0:
# 内层这些for循环在下一步把他优化掉,这里很多可以不必重复计算
for n in range(K+1):
res += dp[index+1][n][2][total+n]
dp[index][num][rest][total] = res
else:
for n in range(int(total / index) + 1):
rest_temp = rest
if n < num:
rest_temp -= 1
else:
rest_temp = 2
if rest_temp != 0:
# print(index+1, n, rest_temp, total+n)
res += dp[index+1][n][rest_temp][total+n]
dp[index][num][rest][total] = res
elif arr[index] != -1:
if index == 0:
res += dp[index+1][arr[index]][2][total + arr[index]]
dp[index][num][rest][total] = res
else:
rest_temp = rest
if arr[index] < num:
rest_temp -= 1
else:
rest_temp = 2
if rest_temp != 0:
if arr[index] <= total / index:
res += dp[index+1][arr[index]][rest_temp][total+arr[index]]
dp[index][num][rest][total] = res
return dp[0][0][2][0]
这时候还是不够快,而最内层的一些for循环其实是没有必要重复算的,例如:
if arr[index] == -1:
if index == 0:
# 可以看到dp[index+1][n][2][total+n]只跟index和total有关,只要维度的index、total固定,不管num和rest怎样变,都相等
for n in range(K+1):
res += dp[index+1][n][2][total+n]
dp[index][num][rest][total] = res
else:
所以可以优化成:
if arr[index] == -1:
if index == 0:
if total == 0:
for n in range(K+1):
res += dp[index+1][n][2][total+n]
else:
# 只算index==0那次,其他的直接取值就行,取dp[index+1][0][1][total]、dp[index+1][0][2][total]、dp[index+1][1][1][total]...都可以
res = dp[index+1][0][1][total]
dp[index][num][rest][total] = res
其他的最内层也按这个思路优化:
def process(arr, N, K, dp):
for index in range(N, -1, -1):
for num in range(K+1):
for rest in range(1, 3):
for total in range(index*K+1):
if index == N:
dp[index][num][rest][total] = 1
continue
res = 0
if arr[index] == -1:
if index == 0:
if total == 0:
for n in range(K+1):
res += dp[index+1][n][2][total+n]
else:
res = dp[index+1][0][1][total]
dp[index][num][rest][total] = res
else:
if rest == 2:
if num <= int(total / index):
for n in range(0, num):
res += dp[index+1][n][1][total+n]
for n in range(num, int(total / index) + 1):
res += dp[index+1][n][2][total+n]
elif num > int(total / index):
if num == int(total / index)+1:
for n in range(int(total / index) + 1):
res += dp[index+1][n][1][total+n]
else:
res = dp[index][int(total / index)+1][rest][total]
# for n in range(int(total / index) + 1):
# res += dp[index+1][n][1][total+n]
# 优化了这儿的for循环
elif rest == 1:
if num == 0:
for n in range(num, int(total / index) + 1):
res += dp[index+1][n][2][total+n]
elif num <= int(total / index):
res = dp[index][num-1][rest][total] - dp[index+1][num-1][2][total+num-1]
# if num <= int(total / index):
# for n in range(num, int(total / index) + 1):
# res += dp[index+1][n][2][total+n]
dp[index][num][rest][total] = res
elif arr[index] != -1:
if index == 0:
res += dp[index+1][arr[index]][2][total + arr[index]]
dp[index][num][rest][total] = res
else:
rest_temp = rest
if arr[index] < num:
rest_temp -= 1
else:
rest_temp = 2
if rest_temp != 0:
if arr[index] <= total / index:
res += dp[index+1][arr[index]][rest_temp][total+arr[index]]
dp[index][num][rest][total] = res
return dp[0][0][2][0]
再进一步优化掉最后一个for循环:
def process(arr, N, K, dp):
for index in range(N, -1, -1):
for num in range(K+1):
for rest in range(1, 3):
for total in range(index*K+1):
if index == N:
dp[index][num][rest][total] = 1
continue
res = 0
if arr[index] == -1:
if index == 0:
if total == 0:
for n in range(K+1):
res += dp[index+1][n][2][total+n]
else:
res = dp[index+1][0][1][total]
dp[index][num][rest][total] = res
else:
if rest == 2:
# 优化了这儿
if num <= int(total / index):
if num == 0:
for n in range(0, num):
res += dp[index+1][n][1][total+n]
for n in range(num, int(total / index) + 1):
res += dp[index+1][n][2][total+n]
else:
res = dp[index][num-1][rest][total] + dp[index+1][num-1][1][total+num-1] - dp[index+1][num-1][2][total+num-1]
elif num > int(total / index):
if num == int(total / index)+1:
for n in range(int(total / index) + 1):
res += dp[index+1][n][1][total+n]
else:
res = dp[index][int(total / index)+1][rest][total]
elif rest == 1:
if num == 0:
for n in range(num, int(total / index) + 1):
res += dp[index+1][n][2][total+n]
elif num <= int(total / index):
res = dp[index][num-1][rest][total] - dp[index+1][num-1][2][total+num-1]
dp[index][num][rest][total] = res
elif arr[index] != -1:
if index == 0:
res += dp[index+1][arr[index]][2][total + arr[index]]
dp[index][num][rest][total] = res
else:
rest_temp = rest
if arr[index] < num:
rest_temp -= 1
else:
rest_temp = 2
if rest_temp != 0:
if arr[index] <= total / index:
res += dp[index+1][arr[index]][rest_temp][total+arr[index]]
dp[index][num][rest][total] = res
return dp[0][0][2][0]