题目链接

子串匹配

题目描述

给定一个文本串 和一个模式串 。任务是:

  1. 按升序输出 中的所有出现位置(1-indexed)。
  2. 输出模式串 的前缀函数(题目中称为 border 数组)。

解题思路

本题是 Knuth-Morris-Pratt (KMP) 字符串匹配算法的模板题。KMP 算法的核心思想是利用模式串自身的性质,在发生不匹配时,尽可能地向右“滑动”模式串,从而避免文本串指针的回溯,达到线性时间复杂度。整个算法分为两步。

1. 预处理模式串:计算前缀函数(Border 数组)

首先,我们需要为模式串 计算一个前缀函数,通常记为 数组。对于模式串 的每个前缀 的值是该前缀的最长真前缀(proper prefix)与真后缀(proper suffix)的相等长度。这正是题目要求的第二部分。

计算方法是一个动态规划过程:

  • 总是
  • 为了计算 ,我们利用已经计算出的 。令
    • 如果 ,说明长度为 的 border 可以扩展一位,因此
    • 如果 ,说明无法扩展。我们需要寻找一个更短的 border。下一个可能的 border 长度就是 的 border 长度,即 。我们令 并重复比较,直到 或找到匹配。

2. 在文本串中匹配

有了 数组后,我们就可以在文本串 中进行高效匹配。我们使用两个指针: 用于遍历文本串 用于遍历模式串 同时也代表当前已匹配的模式串前缀的长度。

  • 我们从 开始,比较
  • : 两个指针都向前移动一位,
  • :
    • 如果 ,说明之前有部分匹配。我们不想完全重新开始,而是利用 数组。令 ,这相当于将模式串向右滑动,使得一个更短的前缀对齐到当前位置,然后继续比较 和新的 。注意,文本串指针 不动。
    • 如果 ,说明模式串的第一个字符就不匹配,我们只能将文本串指针向前移动一位,
  • : 说明我们找到了一个完整的匹配。记录下起始位置(,注意题目要求1-indexed)。然后,为了继续寻找后续可能的匹配,我们令 ,继续匹配过程。

这个过程一直持续到 遍历完整个文本串

代码

#include <iostream>
#include <vector>
#include <string>
#include <numeric>

using namespace std;

// 计算前缀函数
vector<int> compute_pi(const string& p) {
    int m = p.length();
    vector<int> pi(m, 0);
    for (int i = 1; i < m; i++) {
        int j = pi[i - 1];
        while (j > 0 && p[i] != p[j]) {
            j = pi[j - 1];
        }
        if (p[i] == p[j]) {
            j++;
        }
        pi[i] = j;
    }
    return pi;
}

int main() {
    string t, p;
    cin >> t >> p;

    int n = t.length();
    int m = p.length();

    if (m == 0) return 0;

    vector<int> pi = compute_pi(p);
    vector<int> occurrences;
    
    int j = 0; // 模式串的指针
    for (int i = 0; i < n; i++) { // 文本串的指针
        while (j > 0 && t[i] != p[j]) {
            j = pi[j - 1];
        }
        if (t[i] == p[j]) {
            j++;
        }
        if (j == m) {
            occurrences.push_back(i - m + 2); // +1 for 0-indexed to 1-indexed, +1 for start position
            j = pi[j - 1];
        }
    }

    for (int pos : occurrences) {
        cout << pos << endl;
    }

    for (int i = 0; i < m; i++) {
        cout << pi[i] << (i == m - 1 ? "" : " ");
    }
    cout << endl;

    return 0;
}
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;

public class Main {
    // 计算前缀函数
    private static int[] computePi(String p) {
        int m = p.length();
        int[] pi = new int[m];
        for (int i = 1; i < m; i++) {
            int j = pi[i - 1];
            while (j > 0 && p.charAt(i) != p.charAt(j)) {
                j = pi[j - 1];
            }
            if (p.charAt(i) == p.charAt(j)) {
                j++;
            }
            pi[i] = j;
        }
        return pi;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        String t = sc.next();
        String p = sc.next();

        int n = t.length();
        int m = p.length();

        if (m == 0) return;

        int[] pi = computePi(p);
        List<Integer> occurrences = new ArrayList<>();
        
        int j = 0; // 模式串的指针
        for (int i = 0; i < n; i++) { // 文本串的指针
            while (j > 0 && t.charAt(i) != p.charAt(j)) {
                j = pi[j - 1];
            }
            if (t.charAt(i) == p.charAt(j)) {
                j++;
            }
            if (j == m) {
                occurrences.add(i - m + 2); // +1 for 0-indexed to 1-indexed, +1 for start position
                j = pi[j - 1];
            }
        }

        for (int pos : occurrences) {
            System.out.println(pos);
        }

        for (int i = 0; i < m; i++) {
            System.out.print(pi[i] + (i == m - 1 ? "" : " "));
        }
        System.out.println();
    }
}
def compute_pi(p):
    m = len(p)
    pi = [0] * m
    for i in range(1, m):
        j = pi[i - 1]
        while j > 0 and p[i] != p[j]:
            j = pi[j - 1]
        if p[i] == p[j]:
            j += 1
        pi[i] = j
    return pi

def solve():
    t = input()
    p = input()
    
    n = len(t)
    m = len(p)

    if m == 0:
        return

    pi = compute_pi(p)
    occurrences = []
    
    j = 0 # 模式串的指针
    for i in range(n): # 文本串的指针
        while j > 0 and t[i] != p[j]:
            j = pi[j - 1]
        if t[i] == p[j]:
            j += 1
        if j == m:
            occurrences.append(i - m + 2) # +1 for 0-indexed to 1-indexed, +1 for start position
            j = pi[j - 1]
            
    for pos in occurrences:
        print(pos)
        
    print(*pi)

solve()

算法及复杂度

  • 算法:KMP 字符串匹配算法
  • 时间复杂度:
    • 计算模式串 的前缀函数需要 的时间。
    • 在文本串 中进行匹配需要 的时间。
    • 总时间复杂度是两者之和。
  • 空间复杂度:
    • 需要额外空间存储模式串 的前缀函数 数组。