题目描述

折叠的定义如下:

一个字符串可以看成它自身的折叠。记作S = S
X(S)是X(X>1)个S连接在一起的串的折叠。记作X(S) = SSSS…S(X个S)。
如果A = A’, B = B’,则AB = A’B’ 例如,因为3(A) = AAA, 2(B) = BB,所以3(A)C2(B) = AAACBB,而2(3(A)C)2(B) = AAACAAACBB
给一个字符串,求它的最短折叠。例如AAAAAAAAAABABABCCD的最短折叠为:9(A)3(AB)CCD。

输入格式

仅一行,即字符串S,长度保证不超过100。

输出格式

仅一行,即最短的折叠长度。

输入输出样例
输入 #1

NEERCYESYESYESNEERCYESYESYES

输出 #1

14

说明/提示

一个最短的折叠为:2(NEERC3(YES))

分析

一个区间的最短长度由两种情况得来。
①:由两个子区间合并而来
②:由某个循环节构成
于是我们用 f [ i ] [ j ] f[i][j] f[i][j] 表示区间 [ i , j ] [i,j] [i,j] 的最短长度, t t t 表示 循环节长度, s 1 [ i ] s1[i] s1[i] 表示循环节个数的位数

于是得到转移方程
f [ i ] [ j ] = m i n ( f [ i ] [ j ] , f [ i ] [ k ] + f [ k + 1 ] [ j ] ) <mtext>    </mtext> k [ i , j ) f[i][j] = min(f[i][j], f[i][k] + f[k+1][j]) ~~k\in [i,j) f[i][j]=min(f[i][j],f[i][k]+f[k+1][j])  k[i,j)
i f ( c h e c k ( i , j , t ) ) <mtext>     </mtext> f [ i ] [ j ] = m i n ( f [ i ] [ j ] , s 1 [ l / t ] + 2 + f [ i ] [ i + t 1 ] ) if(check(i, j, t))~~~ f[i][j] = min(f[i][j], s1[l / t] + 2 + f[i][i+t-1]) if(check(i,j,t))   f[i][j]=min(f[i][j],s1[l/t]+2+f[i][i+t1])

我是来说说复杂度的

咋一看复杂度是 O ( n 4 ) O(n^4) O(n4) ,其实不然。
我们记 τ ( i ) \tau(i) τ(i) i i i 的约数个数
那么复杂度是 O ( n 2 i = 1 n τ ( i ) ) O(n^2*\sum\limits_{i=1}^{n}\tau(i)) O(n2i=1nτ(i))
i = 1 n τ ( i ) \sum\limits_{i=1}^{n} \tau(i) i=1nτ(i),转化为考虑 i i i 作为约数对答案的贡献,也就是 i i i 的倍数
这样的话, i = 1 n τ ( i ) = i = 1 n n i &lt; i = 1 n n i &lt; n l n n &lt; n l o g n \sum\limits_{i=1}^{n} \tau(i)=\sum\limits_{i=1}^{n} \left\lfloor\dfrac{n}{i}\right\rfloor &lt; \sum\limits_{i=1}^{n}\dfrac{n}{i} &lt; nln{n}&lt;nlogn i=1nτ(i)=i=1nin<i=1nin<nlnn<nlogn
后面那个是调和级数,为啥是这样,因为 1 i \dfrac{1}{i} i1 的原函数是 l n <mtext>   </mtext> i ln~i ln i
因此复杂度是 O ( n 3 l o g n ) O(n^3logn) O(n3logn)

代码如下

#include <bits/stdc++.h>
using namespace std;
int s1[104], f[104][104];
char s[105];
int check(int l, int r, int len){
	for(int i = l; i <= r; i++){
		if(i + len <= r && s[i] != s[i + len]) return 0;
	}
	return 1;
}
int main(){
	int i, j, n, m, k, t, l;
	for(i = 2; i <= 9; i++) s1[i] = 1;
	for(i = 10; i <= 99; i++) s1[i] = 2;
	s1[100] = 3;
	scanf("%s", s);
	n = strlen(s);
	memset(f, 1, sizeof(f));
	for(i = 0; i < n; i++) f[i][i] = 1;
	for(i = n - 1; i >= 0; i--){
		for(j = i + 1; j < n; j++){
			l = j - i + 1;
			for(k = i; k < j; k++) f[i][j] = min(f[i][j], f[i][k] + f[k+1][j]);
			for(k = i; k < j; k++){
				t = k - i + 1;
				if(l % t != 0) continue;
				if(check(i, j, t)) f[i][j] = min(f[i][j], s1[l / t] + 2 + f[i][k]);
			}
		}
	}
	printf("%d", f[0][n-1]);
	return 0;
}