解题思路

  1. 状态定义:

    • 表示当前剩余区间 能获得的最大价值
    • 当前是第 次取数
  2. 转移方程:

  3. 边界条件:

    • 时,

代码

#include <bits/stdc++.h>
using namespace std;

const int MAXN = 1005;
int A[MAXN], B[MAXN];
long long dp[MAXN][MAXN];

int main() {
    int n;
    cin >> n;
    
    // 读入数列A和B
    for(int i = 0; i < n; i++) cin >> A[i];
    for(int i = 0; i < n; i++) cin >> B[i];
    
    // 初始化dp数组
    memset(dp, 0, sizeof(dp));
    
    // 处理长度为1的情况
    for(int i = 0; i < n; i++) {
        dp[i][i] = 1LL * A[i] * B[n-1];
    }
    
    // 区间dp
    for(int len = 2; len <= n; len++) {
        for(int i = 0; i + len - 1 < n; i++) {
            int j = i + len - 1;
            int turn = n - len;  // 当前是第几次取数
            // 取左端
            dp[i][j] = max(dp[i][j], dp[i+1][j] + 1LL * A[i] * B[turn]);
            // 取右端
            dp[i][j] = max(dp[i][j], dp[i][j-1] + 1LL * A[j] * B[turn]);
        }
    }
    
    cout << dp[0][n-1] << endl;
    return 0;
}

import java.util.*;

public class Main {
    static final int MAXN = 1005;
    static int[] A = new int[MAXN];
    static int[] B = new int[MAXN];
    static long[][] dp = new long[MAXN][MAXN];
    
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        
        // 读入数列A和B
        for(int i = 0; i < n; i++) A[i] = sc.nextInt();
        for(int i = 0; i < n; i++) B[i] = sc.nextInt();
        
        // 处理长度为1的情况
        for(int i = 0; i < n; i++) {
            dp[i][i] = 1L * A[i] * B[n-1];
        }
        
        // 区间dp
        for(int len = 2; len <= n; len++) {
            for(int i = 0; i + len - 1 < n; i++) {
                int j = i + len - 1;
                int turn = n - len;  // 当前是第几次取数
                // 取左端
                dp[i][j] = Math.max(dp[i][j], 
                                  dp[i+1][j] + 1L * A[i] * B[turn]);
                // 取右端
                dp[i][j] = Math.max(dp[i][j], 
                                  dp[i][j-1] + 1L * A[j] * B[turn]);
            }
        }
        
        System.out.println(dp[0][n-1]);
        sc.close();
    }
}
def solve(n: int, A: list, B: list) -> int:
    # 初始化dp数组
    dp = [[0] * n for _ in range(n)]
    
    # 处理长度为1的情况
    for i in range(n):
        dp[i][i] = A[i] * B[n-1]
    
    # 区间dp
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            turn = n - length  # 当前是第几次取数
            # 取左端
            dp[i][j] = max(dp[i][j], 
                          dp[i+1][j] + A[i] * B[turn])
            # 取右端
            dp[i][j] = max(dp[i][j], 
                          dp[i][j-1] + A[j] * B[turn])
    
    return dp[0][n-1]

def main():
    n = int(input())
    A = list(map(int, input().split()))
    B = list(map(int, input().split()))
    print(solve(n, A, B))

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:区间动态规划
  • 时间复杂度:,其中 是数列长度
  • 空间复杂度:,用于存储 数组