题目链接

PEEK153 渝

题目描述

给定一个由数字 组成的三角形,一共有 行,第 行有 个数字。我们使用 表示第 行第 列的数字。

从顶点 出发,每次可以移动到左下方 或右下方 ,直到到达第 行。这样形成一条长度为 的路径,路径经过的数字拼接成一个字符串。

请求出有多少条路径,其对应的字符串是回文串。答案需要对 取模。

解题思路

本题要求我们统计满足回文条件的路径数量。回文串的性质是 s[i] == s[n-i+1],这意味着路径上第 行的数字必须与第 行的数字相同。这个首尾对应的要求,是使用“双端动态规划”或“相遇 DP”(Meet-in-the-middle DP)的典型特征。

1. 核心思想

我们可以想象有两条路径同时“生长”:一条从顶端 向下,另一条从底端(第 行)向上。当它们在中间相遇时,如果满足所有回文条件,就构成了一条合法的路径。

更具体地说,我们同步处理第 行和第 行。我们定义一个 DP 状态来记录满足回文条件的、“上半部分”和“下半部分”路径片段的数量。

2. 状态定义

我们定义 为:

  • 上半部分路径从 走了 步到达第 行的第 列。
  • 下半部分路径(反向)从第 行的某个位置出发,走了 步到达第 行的第 列。
  • 并且,对于所有 ,都满足上半部分路径在第 行的数字等于下半部分路径在第 行的数字。
  • 的值就是满足上述所有条件的路径片段对的数量。

3. 状态转移

我们从 开始,逐步计算到 。 要计算 ,我们需要考虑 的状态。

  • 回文条件: 首先,必须满足当前层的回文条件,即 。如果不满足,则
  • 路径转移:
    • 上半部分路径要到达 ,它必须从 走来。所以 可以是
    • 下半部分路径(反向)要到达 ,它必须从 走来。所以 可以是

综合起来,状态转移方程为:

(所有计算需在边界范围内并对 取模)

4. 初始状态

: 上半部分在 ,下半部分在

如果

如果

对于所有

5. 最终答案

计算进行到 步后,我们需要根据 的奇偶性来合并路径。

  • 如果 是偶数 (n = 2H):

    上半部分路径到达第 行,下半部分路径到达第 行。为了连接成一条完整路径,上半路径在 的下一步必须是下半路径的起点

    这要求

    总方案数 =

  • 如果 是奇数 (n = 2H+1):

    上半部分路径到达第 行,下半部分路径到达第 行。它们需要通过中间的第 行连接。 路径形如:

    对于一个已知的 对,连接方式有:

    • 如果 :只有 1 种连接方式。
    • 如果 :有 2 种连接方式(经过 )。
    • 其他情况:0 种。

    总方案数 =

代码

#include <iostream>
#include <vector>
#include <numeric>

using namespace std;

const int MOD = 1e9 + 7;

int main() {
    int n;
    cin >> n;

    vector<vector<int>> a(n + 1, vector<int>(n + 1));
    for (int i = 1; i <= n; ++i) {
        for (int j = 1; j <= i; ++j) {
            cin >> a[i][j];
        }
    }

    if (n == 1) {
        cout << 1 << endl;
        return 0;
    }

    vector<vector<long long>> dp(n + 2, vector<long long>(n + 2, 0));

    // Base case: i = 1
    for (int c2 = 1; c2 <= n; ++c2) {
        if (a[1][1] == a[n][c2]) {
            dp[1][c2] = 1;
        }
    }

    int half = n / 2;
    for (int i = 2; i <= half; ++i) {
        vector<vector<long long>> next_dp(n + 2, vector<long long>(n + 2, 0));
        for (int c1 = 1; c1 <= i; ++c1) {
            for (int c2 = 1; c2 <= n - i + 1; ++c2) {
                if (a[i][c1] == a[n - i + 1][c2]) {
                    long long count = 0;
                    // From (i-1, c1-1)
                    count = (count + dp[c1 - 1][c2]) % MOD;
                    count = (count + dp[c1 - 1][c2 + 1]) % MOD;
                    // From (i-1, c1)
                    count = (count + dp[c1][c2]) % MOD;
                    count = (count + dp[c1][c2 + 1]) % MOD;
                    next_dp[c1][c2] = count;
                }
            }
        }
        dp = next_dp;
    }

    long long ans = 0;
    if (n % 2 == 0) { // n is even
        for (int c1 = 1; c1 <= half; ++c1) {
            ans = (ans + dp[c1][c1]) % MOD;
            ans = (ans + dp[c1][c1 + 1]) % MOD;
        }
    } else { // n is odd
        vector<vector<long long>> next_dp(n + 2, vector<long long>(n + 2, 0));
        int mid_row = half + 1;
        for (int c1 = 1; c1 <= mid_row; ++c1) {
            for (int c2 = 1; c2 <= n - mid_row + 1; ++c2) {
                 if (a[mid_row][c1] == a[n - mid_row + 1][c2]) { // This check is actually not needed for odd n
                    long long count = 0;
                    count = (count + dp[c1 - 1][c2]) % MOD;
                    count = (count + dp[c1 - 1][c2 + 1]) % MOD;
                    count = (count + dp[c1][c2]) % MOD;
                    count = (count + dp[c1][c2 + 1]) % MOD;
                    if (c1 == c2) {
                       ans = (ans + count) % MOD;
                    }
                }
            }
        }
    }
    cout << ans << endl;

    return 0;
}
import java.util.Scanner;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int[][] a = new int[n + 1][n + 1];
        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= i; j++) {
                a[i][j] = sc.nextInt();
            }
        }

        if (n == 1) {
            System.out.println(1);
            return;
        }

        long[][] dp = new long[n + 2][n + 2];
        final int MOD = 1_000_000_007;

        // Base case: i = 1
        for (int c2 = 1; c2 <= n; c2++) {
            if (a[1][1] == a[n][c2]) {
                dp[1][c2] = 1;
            }
        }

        int half = n / 2;
        for (int i = 2; i <= half; i++) {
            long[][] nextDp = new long[n + 2][n + 2];
            for (int c1 = 1; c1 <= i; c1++) {
                for (int c2 = 1; c2 <= n - i + 1; c2++) {
                    if (a[i][c1] == a[n - i + 1][c2]) {
                        long count = 0;
                        count = (count + dp[c1 - 1][c2]) % MOD;
                        count = (count + dp[c1 - 1][c2 + 1]) % MOD;
                        count = (count + dp[c1][c2]) % MOD;
                        count = (count + dp[c1][c2 + 1]) % MOD;
                        nextDp[c1][c2] = count;
                    }
                }
            }
            dp = nextDp;
        }

        long ans = 0;
        if (n % 2 == 0) { // n is even
            for (int c1 = 1; c1 <= half; c1++) {
                ans = (ans + dp[c1][c1]) % MOD;
                ans = (ans + dp[c1][c1 + 1]) % MOD;
            }
        } else { // n is odd
            int midRow = half + 1;
            for (int c1 = 1; c1 <= half; c1++) {
                 for (int c2 = 1; c2 <= n - half + 1; c2++) {
                    if (dp[c1][c2] > 0) {
                        // Connect through middle row
                        // (half, c1) -> (midRow, c_mid) -> (midRow + 1, c2)
                        // Connections via (midRow, c1)
                        if (c2 == c1 || c2 == c1 + 1) {
                             ans = (ans + dp[c1][c2]) % MOD;
                        }
                        // Connections via (midRow, c1 + 1)
                        if (c2 == c1 + 1 || c2 == c1 + 2) {
                             ans = (ans + dp[c1][c2]) % MOD;
                        }
                    }
                 }
            }
        }
        System.out.println(ans);
    }
}
import sys

def solve():
    n = int(sys.stdin.readline())
    a = [[0] * (n + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        row = list(map(int, sys.stdin.readline().split()))
        for j in range(1, i + 1):
            a[i][j] = row[j - 1]

    if n == 1:
        print(1)
        return

    MOD = 10**9 + 7
    
    # dp[c1][c2]
    dp = [[0] * (n + 2) for _ in range(n + 2)]

    # Base case: i = 1
    for c2 in range(1, n + 1):
        if a[1][1] == a[n][c2]:
            dp[1][c2] = 1

    half = n // 2
    for i in range(2, half + 1):
        next_dp = [[0] * (n + 2) for _ in range(n + 2)]
        for c1 in range(1, i + 1):
            for c2 in range(1, n - i + 2):
                if a[i][c1] == a[n - i + 1][c2]:
                    count = 0
                    count = (count + dp[c1 - 1][c2]) % MOD
                    count = (count + dp[c1 - 1][c2 + 1]) % MOD
                    count = (count + dp[c1][c2]) % MOD
                    count = (count + dp[c1][c2 + 1]) % MOD
                    next_dp[c1][c2] = count
        dp = next_dp

    ans = 0
    if n % 2 == 0:  # n is even
        for c1 in range(1, half + 1):
            ans = (ans + dp[c1][c1]) % MOD
            ans = (ans + dp[c1][c1 + 1]) % MOD
    else:  # n is odd
        for c1 in range(1, half + 1):
            for c2 in range(1, n - half + 2):
                if dp[c1][c2] > 0:
                    # Connections
                    if c2 == c1 or c2 == c1 + 2:
                        ans = (ans + dp[c1][c2]) % MOD
                    elif c2 == c1 + 1:
                        ans = (ans + dp[c1][c2] * 2) % MOD
    
    print(ans)

solve()

算法及复杂度

  • 算法: 双端动态规划 (Meet-in-the-middle DP)
  • 时间复杂度: 我们需要迭代 层。在每一层,我们需要填充一个 的 DP 表,每个状态的计算是常数时间。因此,总时间复杂度为
  • 空间复杂度: 我们需要存储 DP 表。通过使用滚动数组(用一个 next_dp 数组来计算当前层),我们可以将空间复杂度优化到