题目链接

字符串替换

题目描述

给定一个仅由小写字母 xy 组成的字符串。

一次操作可以将字符串中的一个子串 xy 替换为 yyx

求至少需要多少次替换,才能让字符串中不存在子串 xy?结果需要对 10^9 + 7 取模。

解题思路 (修正版)

这是一个巧妙的递推问题。关键在于选择正确的遍历方向。从左到右处理会导致复杂的依赖关系,而从右到左处理则可以建立一个清晰的递推模型。

1. 核心思想:从右向左的动态规划

最终的目标字符串形态必然是所有 y 都在所有 x 的左边(形如 yy...yxx...x)。这意味着每一个 x 都必须移动到每一个 y 的右边。

我们从右向左遍历字符串,并维护两个状态变量:

  • ans:到目前位置(从右到左)累计需要的总操作次数。
  • y_count:在当前位置右边,有效 y 的数量(包括原始的 y 和由操作新生成的 y)。

2. 状态转移

当我们从右向左处理到第 i 个字符时:

  • 如果 s[i] == 'y': 我们在字符串中遇到了一个原始的 y。它为我们右侧的 y 群体贡献了一个成员。所以,我们只需将 y_count 加 1。 y_count = (y_count + 1) % MOD

  • 如果 s[i] == 'x': 这个 x 必须移动到它右侧所有 y_county 的后面。

    • 计算操作数x 每经过一个 y(通过一次 xy -> yyx 操作),就需要 1 次操作。因此,要经过右侧全部 y_county,就需要 y_count 次操作。所以,我们将 ans 加上 y_countans = (ans + y_count) % MOD
    • 更新 y 的数量:当这个 x 经过了右侧所有的 y 之后,根据规则 xy -> yyx,每一个原始的 y 都变成了 yy。这意味着右侧 y 的有效数量翻了一倍。所以,我们将 y_count 乘以 2。 y_count = (y_count * 2) % MOD

3. 算法步骤

  1. 初始化 ans = 0 (总替换次数),y_count = 0 (右侧 y 的数量)。
  2. 初始化模数 MOD = 10^9 + 7
  3. 从右到左遍历字符串 s
    • 如果 s[i] == 'y',则 y_count = (y_count + 1) % MOD
    • 如果 s[i] == 'x',则 ans = (ans + y_count) % MOD,然后 y_count = (y_count * 2) % MOD
  4. 遍历结束后,ans 即为最终答案。

这个算法只需要一次线性扫描,非常高效。

代码

#include <iostream>
#include <string>
#include <vector>
#include <algorithm>

using namespace std;

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    string s;
    cin >> s;

    long long ans = 0;
    long long y_count = 0;
    long long mod = 1e9 + 7;

    for (int i = s.length() - 1; i >= 0; --i) {
        if (s[i] == 'y') {
            y_count++;
        } else { // s[i] == 'x'
            ans = (ans + y_count) % mod;
            y_count = (y_count * 2) % mod;
        }
    }

    cout << ans << endl;

    return 0;
}
import java.util.Scanner;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        String s = sc.next();

        long ans = 0;
        long yCount = 0;
        long mod = 1_000_000_007;

        for (int i = s.length() - 1; i >= 0; i--) {
            char c = s.charAt(i);
            if (c == 'y') {
                yCount++;
            } else { // c == 'x'
                ans = (ans + yCount) % mod;
                yCount = (yCount * 2) % mod;
            }
        }

        System.out.println(ans);
    }
}
import sys

def solve():
    try:
        s = sys.stdin.readline().strip()
        if not s:
            return
            
        ans = 0
        y_count = 0
        mod = 10**9 + 7
        
        for char in reversed(s):
            if char == 'y':
                y_count += 1
            else: # char == 'x'
                ans = (ans + y_count) % mod
                y_count = (y_count * 2) % mod
                
        print(ans)

    except (IOError, ValueError):
        return

solve()

算法及复杂度

  • 算法:动态规划 / 递推

  • 时间复杂度: ,其中 N 是输入字符串的长度。我们只需要对字符串进行一次线性扫描。

  • 空间复杂度: 。我们只使用了常数个额外的变量。