版本一:(python存在超时)

解题思路

这是一个区间DP问题,使用四维状态来解决:

  1. 状态定义:

    • 表示 串区间 串区间 能否组成回文串
    • 串的左右边界, 串的左右边界
  2. 状态转移:

    • 当长度为1时,
    • 对于 串内部:如果 ,则可以从 转移
    • 对于 串内部:如果 ,则可以从 转移
    • 对于 串之间:
      • 如果 ,则可以从 转移
      • 如果 ,则可以从 转移
  3. 答案:

    • 遍历所有可能的区间组合,找到能形成回文串的最大长度

代码

#include <iostream>
#include <cstring>
using namespace std;

int dp[52][52][52][52];

int solve(string& a, string& b) {
    memset(dp, 0, sizeof(dp));
    int s = a.size(), t = b.size();
    int ans = 0;
    
    // 枚举所有可能的区间组合
    for(int x = 0; x <= s; x++)
        for(int y = 0; y <= t; y++)
            for(int i = 1, j = x; j <= s; i++, j++)
                for(int k = 1, l = y; l <= t; k++, l++) {
                    // 长度为1的情况
                    if(x + y <= 1)
                        dp[i][j][k][l] = 1;
                    else {
                        // A串内部匹配
                        if(a[i-1] == a[j-1] && x > 1)
                            dp[i][j][k][l] |= dp[i+1][j-1][k][l];
                        // B串内部匹配
                        if(b[k-1] == b[l-1] && y > 1)
                            dp[i][j][k][l] |= dp[i][j][k+1][l-1];
                        // A、B串之间匹配
                        if(x && y) {
                            if(a[i-1] == b[l-1])
                                dp[i][j][k][l] |= dp[i+1][j][k][l-1];
                            if(a[j-1] == b[k-1])
                                dp[i][j][k][l] |= dp[i][j-1][k+1][l];
                        }
                    }
                    // 更新答案
                    if(dp[i][j][k][l])
                        ans = max(ans, x + y);
                }
    return ans;
}

int main() {
    int n;
    cin >> n;
    while(n--) {
        string a, b;
        cin >> a >> b;
        cout << solve(a, b) << endl;
    }
    return 0;
}
import java.io.*;

public class Main {
    static int[][] dp = new int[52][52];
    static int[][][][] memo = new int[52][52][52][52];
    
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int n = Integer.parseInt(br.readLine());
        
        while(n-- > 0) {
            String a = br.readLine();
            String b = br.readLine();
            System.out.println(solve(a, b));
        }
    }
    
    static int solve(String a, String b) {
        int s = a.length(), t = b.length();
        int ans = 0;
        
        // 初始化
        for(int i = 0; i < 52; i++)
            for(int j = 0; j < 52; j++)
                for(int k = 0; k < 52; k++)
                    for(int l = 0; l < 52; l++)
                        memo[i][j][k][l] = 0;
        
        // 枚举区间
        for(int x = 0; x <= s; x++) {
            for(int y = 0; y <= t; y++) {
                for(int i = 1, j = x; j <= s; i++, j++) {
                    for(int k = 1, l = y; l <= t; k++, l++) {
                        if(x + y <= 1) {
                            memo[i][j][k][l] = 1;
                        } else {
                            // A串内部匹配
                            if(x > 1 && a.charAt(i-1) == a.charAt(j-1))
                                memo[i][j][k][l] |= memo[i+1][j-1][k][l];
                            // B串内部匹配
                            if(y > 1 && b.charAt(k-1) == b.charAt(l-1))
                                memo[i][j][k][l] |= memo[i][j][k+1][l-1];
                            // 交叉匹配
                            if(x > 0 && y > 0) {
                                if(a.charAt(i-1) == b.charAt(l-1))
                                    memo[i][j][k][l] |= memo[i+1][j][k][l-1];
                                if(a.charAt(j-1) == b.charAt(k-1))
                                    memo[i][j][k][l] |= memo[i][j-1][k+1][l];
                            }
                        }
                        if(memo[i][j][k][l] == 1)
                            ans = Math.max(ans, x + y);
                    }
                }
            }
        }
        return ans;
    }
}
import sys
input = sys.stdin.buffer.readline

def solve(a, b):
    s, t = len(a), len(b)
    N = 52
    # 使用一维数组,通过计算偏移量来模拟四维
    dp = [0] * (N * N * N * N)
    ans = 0
    
    def idx(i, j, k, l):
        return ((i * N + j) * N + k) * N + l
    
    # 枚举所有可能的区间组合
    for x in range(s + 1):
        for y in range(t + 1):
            i, j = 1, x
            while j <= s:
                k, l = 1, y
                while l <= t:
                    pos = idx(i, j, k, l)
                    if x + y <= 1:
                        dp[pos] = 1
                    else:
                        # A串内部匹配
                        if x > 1 and a[i-1] == a[j-1]:
                            dp[pos] |= dp[idx(i+1, j-1, k, l)]
                        # B串内部匹配
                        if y > 1 and b[k-1] == b[l-1]:
                            dp[pos] |= dp[idx(i, j, k+1, l-1)]
                        # A、B串之间匹配
                        if x and y:
                            if a[i-1] == b[l-1]:
                                dp[pos] |= dp[idx(i+1, j, k, l-1)]
                            if a[j-1] == b[k-1]:
                                dp[pos] |= dp[idx(i, j-1, k+1, l)]
                    
                    if dp[pos]:
                        ans = max(ans, x + y)
                    k += 1
                    l += 1
                i += 1
                j += 1
    return ans

# 主程序
for _ in range(int(input())):
    a = input().strip().decode()
    b = input().strip().decode()
    print(solve(a, b))

算法及复杂度

  • 算法:区间动态规划
  • 时间复杂度: - 分别是两个字符串的长度
  • 空间复杂度: - 四维 数组的大小

版本二:

解题思路

使用状态压缩 DP 解决:

  1. 表示 串区间 串前 个字符能构成的所有可能的回文串状态
  2. 使用二进制位来表示 串中每个位置是否被使用
  3. 通过位运算优化状态转移

代码

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    
    int T, nA, nB, res;
    cin >> T;
    string A, B;
    ll f[52][52][52], gA[52], gB[52], t;
    
    while (T--) {
        cin >> A >> B;
        nA = A.size(), nB = B.size();
        
        // 初始化基础状态
        for (int lA = 0; lA <= nA; ++lA)
            for (int rA = 0; rA <= nA; ++rA)
                for (int lB = 0; lB <= nB; ++lB)
                    f[lA][rA][lB] = rA - lA < 2 ? 
                        (1ll << lB | (lA == rA && lB < nB ? 1ll << lB + 1 : 0)) : 0;
        
        // 预处理A串中每个字符在B串中的出现位置
        for (int i = 1; i <= nA; ++i)
            for (int j = gA[i] = 0; j <= nB - 1; ++j)
                if (A[i - 1] == B[j]) 
                    gA[i] |= 1ll << j;
        
        // 预处理B串中每个字符在B串中的出现位置
        for (int i = 1; i <= nB; ++i)
            for (int j = gB[i] = 0; j <= nB - 1; ++j)
                if (B[i - 1] == B[j]) 
                    gB[i] |= 1ll << j;
        
        res = 1;
        // 状态转移
        for (int lA = nA; lA >= 0; --lA)
            for (int rA = lA; rA <= nA; ++rA)
                for (int lB = nB; lB >= 0; --lB)
                    if (t = f[lA][rA][lB]) {
                        if (lA) {
                            f[lA - 1][rA][lB] |= (t & gA[lA]) << 1;
                            if (A[lA - 1] == A[rA]) 
                                f[lA - 1][rA + 1][lB] |= t;
                        }
                        if (lB) {
                            f[lA][rA][lB - 1] |= (t & gB[lB]) << 1;
                            if (B[lB - 1] == A[rA]) 
                                f[lA][rA + 1][lB - 1] |= t;
                        }
                        res = max(res, rA - lA + 63 - __builtin_clzll(t) - lB);
                    }
        cout << res << endl;
    }
    return 0;
}
import java.io.*;
import java.util.*;

public class Main {
    static long[][] f = new long[52][52];
    static long[] gA = new long[52];
    static long[] gB = new long[52];
    
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int T = Integer.parseInt(br.readLine());
        
        while (T-- > 0) {
            String A = br.readLine();
            String B = br.readLine();
            int nA = A.length(), nB = B.length();
            long[][][] f = new long[52][52][52];
            
            // 初始化基础状态
            for (int lA = 0; lA <= nA; ++lA)
                for (int rA = 0; rA <= nA; ++rA)
                    for (int lB = 0; lB <= nB; ++lB)
                        f[lA][rA][lB] = rA - lA < 2 ? 
                            (1L << lB | (lA == rA && lB < nB ? 1L << (lB + 1) : 0)) : 0;
            
            // 预处理A串中每个字符在B串中的出现位置
            for (int i = 1; i <= nA; ++i) {
                gA[i] = 0;
                for (int j = 0; j <= nB - 1; ++j)
                    if (A.charAt(i - 1) == B.charAt(j)) 
                        gA[i] |= 1L << j;
            }
            
            // 预处理B串中每个字符在B串中的出现位置
            for (int i = 1; i <= nB; ++i) {
                gB[i] = 0;
                for (int j = 0; j <= nB - 1; ++j)
                    if (B.charAt(i - 1) == B.charAt(j)) 
                        gB[i] |= 1L << j;
            }
            
            int res = 1;
            // 状态转移
            for (int lA = nA; lA >= 0; --lA)
                for (int rA = lA; rA <= nA; ++rA)
                    for (int lB = nB; lB >= 0; --lB) {
                        long t = f[lA][rA][lB];
                        if (t != 0) {
                            if (lA > 0) {
                                f[lA - 1][rA][lB] |= (t & gA[lA]) << 1;
                                if (lA > 0 && rA < nA && A.charAt(lA - 1) == A.charAt(rA))
                                    f[lA - 1][rA + 1][lB] |= t;
                            }
                            if (lB > 0) {
                                f[lA][rA][lB - 1] |= (t & gB[lB]) << 1;
                                if (rA < nA && B.charAt(lB - 1) == A.charAt(rA))
                                    f[lA][rA + 1][lB - 1] |= t;
                            }
                            res = Math.max(res, rA - lA + 63 - Long.numberOfLeadingZeros(t) - lB);
                        }
                    }
            System.out.println(res);
        }
    }
}
import sys
input = sys.stdin.buffer.readline

def solve(A, B):
    nA, nB = len(A), len(B)
    # 创建固定大小的数组
    f = [[[0]*52 for _ in range(52)] for _ in range(52)]
    gA = [0]*52
    gB = [0]*52
    
    # 初始化基础状态
    for lA in range(nA + 1):
        for rA in range(nA + 1):
            for lB in range(nB + 1):
                if rA - lA < 2:
                    f[lA][rA][lB] = (1 << lB) | ((1 << (lB + 1)) if (lA == rA and lB < nB) else 0)
                else:
                    f[lA][rA][lB] = 0
    
    # 预处理A串中每个字符在B串中的出现位置
    for i in range(1, nA + 1):
        gA[i] = 0
        for j in range(nB):
            if A[i - 1] == B[j]:
                gA[i] |= 1 << j
    
    # 预处理B串中每个字符在B串中的出现位置
    for i in range(1, nB + 1):
        gB[i] = 0
        for j in range(nB):
            if B[i - 1] == B[j]:
                gB[i] |= 1 << j
    
    res = 1
    # 状态转移
    for lA in range(nA, -1, -1):
        for rA in range(lA, nA + 1):
            for lB in range(nB, -1, -1):
                t = f[lA][rA][lB]
                if t:
                    if lA > 0:  # 修改条件
                        f[lA - 1][rA][lB] |= (t & gA[lA]) << 1
                        if rA < nA and A[lA - 1] == A[rA]:  # 添加边界检查
                            f[lA - 1][rA + 1][lB] |= t
                    if lB > 0:  # 修改条件
                        f[lA][rA][lB - 1] |= (t & gB[lB]) << 1
                        if rA < nA and B[lB - 1] == A[rA]:  # 添加边界检查
                            f[lA][rA + 1][lB - 1] |= t
                    
                    # 计算前导零
                    if t:
                        leading_zeros = 0
                        temp = t
                        while temp and (temp & (1 << 63)) == 0:
                            leading_zeros += 1
                            temp <<= 1
                        res = max(res, rA - lA + 63 - leading_zeros - lB)
    
    return res

# 主程序
T = int(input())
for _ in range(T):
    A = input().strip().decode()
    B = input().strip().decode()
    print(solve(A, B))

算法及复杂度

  • 算法:状态压缩动态规划
  • 时间复杂度:
  • 空间复杂度:

注意: 也可以进一步优化成一维的空间。