题目链接
题目描述
给定一个仅由小写字母 x
和 y
组成的字符串。
一次操作可以将字符串中的一个子串 xy
替换为 yyx
。
求至少需要多少次替换,才能让字符串中不存在子串 xy
?结果需要对 10^9 + 7
取模。
解题思路 (修正版)
这是一个巧妙的递推问题。关键在于选择正确的遍历方向。从左到右处理会导致复杂的依赖关系,而从右到左处理则可以建立一个清晰的递推模型。
1. 核心思想:从右向左的动态规划
最终的目标字符串形态必然是所有 y
都在所有 x
的左边(形如 yy...yxx...x
)。这意味着每一个 x
都必须移动到每一个 y
的右边。
我们从右向左遍历字符串,并维护两个状态变量:
ans
:到目前位置(从右到左)累计需要的总操作次数。y_count
:在当前位置右边,有效y
的数量(包括原始的y
和由操作新生成的y
)。
2. 状态转移
当我们从右向左处理到第 i
个字符时:
-
如果
s[i] == 'y'
: 我们在字符串中遇到了一个原始的y
。它为我们右侧的y
群体贡献了一个成员。所以,我们只需将y_count
加 1。y_count = (y_count + 1) % MOD
-
如果
s[i] == 'x'
: 这个x
必须移动到它右侧所有y_count
个y
的后面。- 计算操作数:
x
每经过一个y
(通过一次xy -> yyx
操作),就需要1
次操作。因此,要经过右侧全部y_count
个y
,就需要y_count
次操作。所以,我们将ans
加上y_count
。ans = (ans + y_count) % MOD
- 更新
y
的数量:当这个x
经过了右侧所有的y
之后,根据规则xy -> yyx
,每一个原始的y
都变成了yy
。这意味着右侧y
的有效数量翻了一倍。所以,我们将y_count
乘以 2。y_count = (y_count * 2) % MOD
- 计算操作数:
3. 算法步骤
- 初始化
ans = 0
(总替换次数),y_count = 0
(右侧y
的数量)。 - 初始化模数
MOD = 10^9 + 7
。 - 从右到左遍历字符串
s
:- 如果
s[i] == 'y'
,则y_count = (y_count + 1) % MOD
。 - 如果
s[i] == 'x'
,则ans = (ans + y_count) % MOD
,然后y_count = (y_count * 2) % MOD
。
- 如果
- 遍历结束后,
ans
即为最终答案。
这个算法只需要一次线性扫描,非常高效。
代码
#include <iostream>
#include <string>
#include <vector>
#include <algorithm>
using namespace std;
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
string s;
cin >> s;
long long ans = 0;
long long y_count = 0;
long long mod = 1e9 + 7;
for (int i = s.length() - 1; i >= 0; --i) {
if (s[i] == 'y') {
y_count++;
} else { // s[i] == 'x'
ans = (ans + y_count) % mod;
y_count = (y_count * 2) % mod;
}
}
cout << ans << endl;
return 0;
}
import java.util.Scanner;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
String s = sc.next();
long ans = 0;
long yCount = 0;
long mod = 1_000_000_007;
for (int i = s.length() - 1; i >= 0; i--) {
char c = s.charAt(i);
if (c == 'y') {
yCount++;
} else { // c == 'x'
ans = (ans + yCount) % mod;
yCount = (yCount * 2) % mod;
}
}
System.out.println(ans);
}
}
import sys
def solve():
try:
s = sys.stdin.readline().strip()
if not s:
return
ans = 0
y_count = 0
mod = 10**9 + 7
for char in reversed(s):
if char == 'y':
y_count += 1
else: # char == 'x'
ans = (ans + y_count) % mod
y_count = (y_count * 2) % mod
print(ans)
except (IOError, ValueError):
return
solve()
算法及复杂度
-
算法:动态规划 / 递推
-
时间复杂度:
,其中
N
是输入字符串的长度。我们只需要对字符串进行一次线性扫描。 -
空间复杂度:
。我们只使用了常数个额外的变量。