题目链接

小红的回文子串

题目描述

给定一个长度为 的字符串,我们想知道,该字符串有多少个长度为 的连续子串是回文串?

解题思路

这个问题要求我们统计固定长度的回文子串数量。一个直接的方法是暴力枚举,但效率较低。更高效的解决方案是使用字符串哈希。

1. 暴力解法 (时间复杂度 )

我们可以遍历所有 个长度为 的子串。对于每个子串,我们花费 的时间来检查它是否是回文。总时间复杂度为 ,当 很大时,该方法会超时。

2. 字符串哈希 (时间复杂度 )

核心思想是:一个字符串是回文串,当且仅当它的哈希值与其反转串的哈希值相等。这使我们能将比较字符串的 操作优化为比较哈希值的 操作。

  • 哈希函数:我们使用多项式滚动哈希。对于一个字符串 ,其哈希值 ,其中 是一个质数基底, 是一个大质数模数。

  • 实现步骤

    1. 预计算

      • 预计算基底 的各次幂 ,以便快速计算。

      • 计算原字符串 的所有前缀的哈希值,存入数组 h_fwd

      • 将字符串 反转得到 ,计算其所有前缀的哈希值,存入数组 h_rev

    2. 子串哈希计算:利用前缀哈希数组,我们可以在 内计算出任意子串 的哈希值。

    3. 遍历与比较

      • 我们遍历所有子串的起始位置 (从 ) 。

      • 对于每个子串 ,我们计算它的正向哈希值

      • 这个子串的反转形式,对应于反转字符串 中的子串 。我们计算这个子串的反向哈希值

      • 如果两个哈希值相等,说明原子串是回文串,我们将计数器加一。

  • 避免哈希碰撞:为了将哈希碰撞的概率降到极低,我们采用双哈希方法,即使用两组不同的基底和模数计算哈希值。只有当两对哈希值都相等时,我们才认为子串是回文的。在C++中,使用 unsigned long long 可以利用其自动溢出的特性,相当于对 取模,是一种高效的单哈希实现方式,通常足以应对竞赛中的数据。

代码

#include <bits/stdc++.h>

using namespace std;
using ull = unsigned long long;

const int P = 131; // 质数基底

// h[i] 存储前 i 个字符的哈希值
ull h_fwd[1000005], h_rev[1000005]; 
ull p_pow[1000005];

// 获取字符串 s 的子串 s[l..r] 的哈希值
ull get_hash(ull h[], int l, int r) {
    return h[r] - h[l - 1] * p_pow[r - l + 1];
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n, k;
    cin >> n >> k;
    string s;
    cin >> s;
    string s_rev = s;
    reverse(s_rev.begin(), s_rev.end());

    p_pow[0] = 1;
    for (int i = 1; i <= n; ++i) {
        p_pow[i] = p_pow[i - 1] * P;
        h_fwd[i] = h_fwd[i - 1] * P + (s[i - 1] - 'a' + 1);
        h_rev[i] = h_rev[i - 1] * P + (s_rev[i - 1] - 'a' + 1);
    }

    int count = 0;
    for (int i = 1; i <= n - k + 1; ++i) {
        // 子串 s[i..i+k-1]
        int l1 = i, r1 = i + k - 1;
        // 对应的反转串中的子串
        int l2 = n - (i + k - 1) + 1, r2 = n - i + 1;
        
        if (get_hash(h_fwd, l1, r1) == get_hash(h_rev, l2, r2)) {
            count++;
        }
    }

    cout << count << endl;

    return 0;
}
import java.util.Scanner;

public class Main {
    private static final long P1 = 31, P2 = 37;
    private static final long M1 = 1_000_000_007, M2 = 1_000_000_009;

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int k = sc.nextInt();
        String s = sc.next();
        String sRev = new StringBuilder(s).reverse().toString();

        long[] p1Pow = new long[n + 1];
        long[] p2Pow = new long[n + 1];
        p1Pow[0] = 1;
        p2Pow[0] = 1;

        long[][] hFwd = new long[n + 1][2];
        long[][] hRev = new long[n + 1][2];

        for (int i = 1; i <= n; i++) {
            p1Pow[i] = (p1Pow[i - 1] * P1) % M1;
            p2Pow[i] = (p2Pow[i - 1] * P2) % M2;
            
            hFwd[i][0] = (hFwd[i - 1][0] * P1 + (s.charAt(i - 1) - 'a' + 1)) % M1;
            hFwd[i][1] = (hFwd[i - 1][1] * P2 + (s.charAt(i - 1) - 'a' + 1)) % M2;
            
            hRev[i][0] = (hRev[i - 1][0] * P1 + (sRev.charAt(i - 1) - 'a' + 1)) % M1;
            hRev[i][1] = (hRev[i - 1][1] * P2 + (sRev.charAt(i - 1) - 'a' + 1)) % M2;
        }

        int count = 0;
        for (int i = 0; i <= n - k; i++) {
            long[] hashFwd = getHash(hFwd, p1Pow, p2Pow, i, i + k - 1, k);
            
            int revI = n - (i + k - 1) - 1;
            long[] hashRev = getHash(hRev, p1Pow, p2Pow, revI, revI + k - 1, k);
            
            if (hashFwd[0] == hashRev[0] && hashFwd[1] == hashRev[1]) {
                count++;
            }
        }
        System.out.println(count);
    }

    private static long[] getHash(long[][] h, long[] p1Pow, long[] p2Pow, int l, int r, int k) {
        long hash1 = (h[r + 1][0] - (h[l][0] * p1Pow[k]) % M1 + M1) % M1;
        long hash2 = (h[r + 1][1] - (h[l][1] * p2Pow[k]) % M2 + M2) % M2;
        return new long[]{hash1, hash2};
    }
}
import sys

def solve():
    n, k = map(int, sys.stdin.readline().split())
    s = sys.stdin.readline().strip()
    s_rev = s[::-1]

    P1, M1 = 31, 10**9 + 7
    P2, M2 = 37, 10**9 + 9

    p1_pow = [1] * (n + 1)
    p2_pow = [1] * (n + 1)
    for i in range(1, n + 1):
        p1_pow[i] = (p1_pow[i - 1] * P1) % M1
        p2_pow[i] = (p2_pow[i - 1] * P2) % M2

    h_fwd = [(0, 0)] * (n + 1)
    h_rev = [(0, 0)] * (n + 1)
    for i in range(n):
        val = ord(s[i]) - ord('a') + 1
        h_fwd[i+1] = (
            (h_fwd[i][0] * P1 + val) % M1,
            (h_fwd[i][1] * P2 + val) % M2
        )
        val_rev = ord(s_rev[i]) - ord('a') + 1
        h_rev[i+1] = (
            (h_rev[i][0] * P1 + val_rev) % M1,
            (h_rev[i][1] * P2 + val_rev) % M2
        )

    def get_hash(h, l, r):
        len_sub = r - l + 1
        h1 = (h[r+1][0] - (h[l][0] * p1_pow[len_sub]) % M1 + M1) % M1
        h2 = (h[r+1][1] - (h[l][1] * p2_pow[len_sub]) % M2 + M2) % M2
        return h1, h2

    count = 0
    for i in range(n - k + 1):
        hash_fwd = get_hash(h_fwd, i, i + k - 1)
        
        rev_i = n - (i + k - 1) - 1
        hash_rev = get_hash(h_rev, rev_i, rev_i + k - 1)
        
        if hash_fwd == hash_rev:
            count += 1
            
    print(count)

solve()

算法及复杂度

  • 算法:字符串双哈希(滚动哈希)

  • 时间复杂度:。预计算哈希数组和幂的数组需要 ,之后遍历所有子串并进行 的哈希比较,总共需要

  • 空间复杂度:,用于存储预计算的哈希值和幂。