题解:BISHI90 【模板】记忆化搜索

题目链接

递归函数记忆化

题目描述

定义函数

  • ,则
  • ,则
  • 否则,

给定若干组 (均不超过 ),输出 取模的结果。

解题思路

  • 该递归具有大量重叠子问题,且输入上界为 ,可用动态规划自底向上求解。
  • 两种安全实现:
    • C++ 可直接预处理
    • Java/Python 为节省内存,用按 的两层滚动数组:当前层用到同层的 与上一层的 ,可按顺序填表,并在到达某个 时回答对应查询。
  • 实现要点:
    • 预处理为
    • 按题意套用分支公式,所有加减在模 下进行。

代码

#include <bits/stdc++.h>
using namespace std;
static const int MOD = 1000000007;

inline int addmod(long long x, long long y) { x += y; x %= MOD; return (int)x; }
inline int submod(int x, int y) { x -= y; if (x < 0) x += MOD; return x; }

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    const int MAXA = 100, MAXB = 100, MAXC = 100;
    static int dp[MAXA + 1][MAXB + 1][MAXC + 1];
    for (int a = 0; a <= MAXA; ++a) {
        for (int b = 0; b <= MAXB; ++b) {
            for (int c = 0; c <= MAXC; ++c) {
                if (a == 0 || b == 0 || c == 0) { dp[a][b][c] = 1; }
                else if (a < b && b < c) {
                    int v = addmod(dp[a][b][c - 1], dp[a][b - 1][c - 1]);
                    dp[a][b][c] = submod(v, dp[a][b - 1][c]);
                } else {
                    int v1 = addmod(dp[a - 1][b][c], dp[a - 1][b - 1][c]);
                    int v2 = submod(dp[a - 1][b][c - 1], dp[a - 1][b - 1][c - 1]);
                    dp[a][b][c] = addmod((long long)v1, (long long)v2);
                }
            }
        }
    }

    int T; if (!(cin >> T)) return 0;
    while (T--) {
        int a, b, c; cin >> a >> b >> c;
        cout << dp[a][b][c] << '\n';
    }
    return 0;
}
import java.io.*;

public class Main {
    static final int MOD = 1000000007;

    static class FastScanner {
        private final InputStream in; private final byte[] buf = new byte[1<<16];
        private int p=0, l=0; FastScanner(InputStream is){in=is;}
        private int read() throws IOException { if (p>=l){ l=in.read(buf); p=0; if (l<=0) return -1;} return buf[p++]; }
        long nextLong() throws IOException { int c; long s=1,x=0; do{c=read();}while(c<=32); if(c=='-'){s=-1;c=read();} while(c>32){x=x*10+(c-'0'); c=read();} return x*s; }
        int nextInt() throws IOException { return (int)nextLong(); }
    }

    static int add(int x,int y){ x+=y; if(x>=MOD) x-=MOD; return x; }
    static int sub(int x,int y){ x-=y; if(x<0) x+=MOD; return x; }

    public static void main(String[] args) throws Exception {
        FastScanner fs = new FastScanner(System.in);
        int T = fs.nextInt();
        int maxA = 100, maxB = 100, maxC = 100;
        // 预先分配滚动层
        int[][] prev = new int[maxB+1][maxC+1];
        int[][] cur = new int[maxB+1][maxC+1];
        // dp[a=0][*][*] = 1
        for (int b=0;b<=maxB;b++) for (int c=0;c<=maxC;c++) prev[b][c]=1;

        // 读入全部查询
        int[] aa = new int[T], bb = new int[T], cc = new int[T];
        int needA = 0;
        for (int i=0;i<T;i++){ aa[i]=fs.nextInt(); bb[i]=fs.nextInt(); cc[i]=fs.nextInt(); needA=Math.max(needA, aa[i]); }

        // answers
        int[] ans = new int[T];

        // 为每个 a 层预计算并回答该层的查询
        for (int a=1;a<=needA;a++){
            for (int b=0;b<=maxB;b++){
                for (int c=0;c<=maxC;c++){
                    if (b==0 || c==0) { cur[b][c]=1; continue; }
                    if (a < b && b < c) {
                        int v = add(cur[b][c-1], cur[b-1][c-1]);
                        cur[b][c] = sub(v, cur[b-1][c]);
                    } else {
                        int v1 = add(prev[b][c], prev[b-1][c]);
                        int v2 = sub(prev[b][c-1], prev[b-1][c-1]);
                        cur[b][c] = add(v1, v2);
                    }
                }
            }
            // 回答 a 层的查询
            for (int i=0;i<T;i++) if (aa[i]==a) ans[i]=cur[bb[i]][cc[i]];
            // 滚动
            int[][] tmp = prev; prev = cur; cur = tmp;
        }
        // 处理 a==0 的情况(本题 a>=1 不会用到),以及输出
        StringBuilder out = new StringBuilder();
        for (int i=0;i<T;i++) out.append(ans[i]).append('\n');
        System.out.print(out.toString());
    }
}
import sys

MOD = 10**9 + 7

data = sys.stdin.buffer.read().split()
it = iter(data)
T = int(next(it))
qs = []
max_a = 0
for _ in range(T):
    a = int(next(it)); b = int(next(it)); c = int(next(it))
    qs.append((a, b, c))
    if a > max_a: max_a = a

# 滚动按层计算
MAX = 100
prev = [[1]*(MAX+1) for _ in range(MAX+1)]  # a=0 层
cur = [[0]*(MAX+1) for _ in range(MAX+1)]

answers = [0]*T
for a in range(1, max_a+1):
    for b in range(0, MAX+1):
        for c in range(0, MAX+1):
            if b == 0 or c == 0:
                cur[b][c] = 1
            elif a < b and b < c:
                cur[b][c] = (cur[b][c-1] + cur[b-1][c-1] - cur[b-1][c]) % MOD
            else:
                cur[b][c] = (prev[b][c] + prev[b-1][c] + prev[b][c-1] - prev[b-1][c-1]) % MOD
    # 回答本层查询
    for i, (qa, qb, qc) in enumerate(qs):
        if qa == a:
            answers[i] = cur[qb][qc]
    prev, cur = cur, prev

sys.stdout.write("\n".join(str(x % MOD) for x in answers))

算法及复杂度

  • 算法:记忆化搜索(状态截断至 ),直接套用分支公式并取模
  • 时间复杂度:预处理/访问单个状态近似 ,整体
  • 空间复杂度:(常数大小的 状态)