牛客7502I - Subsequence Pair

题意

  • 给出两个字符串 SSTTS2000|S|\leq2000T2000|T|\leq2000)。
  • 需要从 SSTT 中分别选出一个子序列 xxyy
  • 要求:xx 的字典序 << yy 的字典序,( 以下记为 lex(x)<lex(y)\text{lex}(x)<\text{lex}(y) )。
  • 求最大的 x+y|x|+|y|

思路

发现性质

  • 我们能发现,如果 lex(x)=lex(y)\text{lex}(x)=\text{lex}(y),那么接下来可能有 lex(x)<lex(y)\text{lex}(x)<\text{lex}(y)
  • 而如果 lex(x)<lex(y)\text{lex}(x)<\text{lex}(y),那么接下来只可能 lex(x)<lex(y)\text{lex}(x)<\text{lex}(y)

错误思路

  • dp[i][j][0/1]dp[i][j][0/1] 代表 SS 匹配到第 ii 位,TT 匹配到第 jj 位,选出来的子序列字典序 x=yx=y 还是 x<yx<ydpdp 值代表最大的 x+y|x|+|y|
  • 这样,转移有:
    • dp[i][j][0]dp[p][q][0]dp[i][j][0]\leftarrow dp[p][q][0]
    • dp[i][j][1]dp[p][q][0]dp[i][j][1]\leftarrow dp[p][q][0]
    • dp[i][j][1]dp[p][q][1]dp[i][j][1]\leftarrow dp[p][q][1]

错误原因

  • 注意看转移:dp[i][j][1]dp[p][q][1]dp[i][j][1]\leftarrow dp[p][q][1]
  • 考虑这种情况:
    • S=S=abcdeT=T=abcde
    • 某状态下 S5=T5S_5=T_5 ,并且 dp[p][q][1]dp[p][q][1] 代表的字符串 xxabyyabcd,满足 lex(x)<lex(y)\text{lex}(x)<\text{lex}(y) 的性质。
    • 如果追加e,那么 xxabeyyabcde,显然不再满足 lex(x)<lex(y)\text{lex}(x)<\text{lex}(y) 的性质。
  • 错误的原因:这样的状态不足以表示字符串信息。

正确思路

  • 先正向做一遍最长公共子序列,存入 dpdp 数组。
  • 再反向做一遍DP,记 f[i][j]f[i][j] 为如果在 iijj 处出现 Si<TjS_i<T_j,但之前字典序相等,那么在这之后能匹配的最长的长度和。
  • Si=TjS_i=T_j 时,ans=max(ans,dp[i][j]+f[i+1][j+1])ans=\max(ans,dp[i][j]+f[i+1][j+1])

代码

#include <cstdio>
#include <iostream>
#include <cstring>
const int N		= 2010;
const int INF	= 1e9;
using namespace std;

int f[N][N],f2[N][N];
int dp[N][N];
char S[N], T[N];
int n, m;

void Sol()
{
	for (int i=0; i<=n+5; i++)
	{
		for (int j=0; j<=m+5; j++)
		{
			dp[i][j] = f[i][j] = 0;
			f2[i][j] = -1;
		}
	}
	
	
	for (int i=n; i>=1; i--)
	{
		for (int j=m; j>=1; j--)
		{
			if(S[i]==T[j])
				f2[i][j] = f2[i+1][j+1];
			else if(S[i]<T[j])
				f2[i][j] = 1;
			else if(S[i]>T[j])
				f2[i][j] = 2;
			
			if(f2[i][j]==0 || f2[i][j]==1)
			{
				f[i][j] = n-i+1 + m-j+1;
			}
		}
	}
	
	for (int i=n+1; i>=1; i--)
	{
		for (int j=m; j>=1; j--)
		{
			f[i][j] = max(f[i][j], max(f[i+1][j], f[i][j+1]));
			f[i][j] = max(f[i][j], m-j+1);
		}
	}
	
	int ans = max(f[1][1],m);
	for (int i=1; i<=n; i++)
	{
		for (int j=1; j<=m; j++)
		{
			if(S[i] == T[j])
			{
				dp[i][j] = max(dp[i][j], dp[i-1][j-1]+1);
				ans = max(ans, dp[i][j]*2 + f[i+1][j+1]);
			}
			else
			{
				dp[i][j] = max(dp[i-1][j], dp[i][j-1]);
			}
		}
	}
	
	printf("%d\n",ans);
}

int main()
{
	while (scanf("%s %s",S+1, T+1)!=EOF)
	{
		n = strlen(S+1);
		m = strlen(T+1);
		Sol();
	}
	
	return 0;
}