题目链接
题目描述
给定一个长度为 的字符串,我们想知道,该字符串有多少个长度为
的连续子串是回文串?
解题思路
这个问题要求我们统计固定长度的回文子串数量。一个直接的方法是暴力枚举,但效率较低。更高效的解决方案是使用字符串哈希。
1. 暴力解法 (时间复杂度 )
我们可以遍历所有 个长度为
的子串。对于每个子串,我们花费
的时间来检查它是否是回文。总时间复杂度为
,当
和
很大时,该方法会超时。
2. 字符串哈希 (时间复杂度 )
核心思想是:一个字符串是回文串,当且仅当它的哈希值与其反转串的哈希值相等。这使我们能将比较字符串的 操作优化为比较哈希值的
操作。
-
哈希函数:我们使用多项式滚动哈希。对于一个字符串
,其哈希值
,其中
是一个质数基底,
是一个大质数模数。
-
实现步骤:
-
预计算:
-
预计算基底
的各次幂
,以便快速计算。
-
计算原字符串
的所有前缀的哈希值,存入数组
h_fwd
。 -
将字符串
反转得到
,计算其所有前缀的哈希值,存入数组
h_rev
。
-
-
子串哈希计算:利用前缀哈希数组,我们可以在
内计算出任意子串
的哈希值。
-
遍历与比较:
-
我们遍历所有子串的起始位置
(从
到
) 。
-
对于每个子串
,我们计算它的正向哈希值。
-
这个子串的反转形式,对应于反转字符串
中的子串
。我们计算这个子串的反向哈希值。
-
如果两个哈希值相等,说明原子串是回文串,我们将计数器加一。
-
-
-
避免哈希碰撞:为了将哈希碰撞的概率降到极低,我们采用双哈希方法,即使用两组不同的基底和模数计算哈希值。只有当两对哈希值都相等时,我们才认为子串是回文的。在C++中,使用
unsigned long long
可以利用其自动溢出的特性,相当于对取模,是一种高效的单哈希实现方式,通常足以应对竞赛中的数据。
代码
#include <bits/stdc++.h>
using namespace std;
using ull = unsigned long long;
const int P = 131; // 质数基底
// h[i] 存储前 i 个字符的哈希值
ull h_fwd[1000005], h_rev[1000005];
ull p_pow[1000005];
// 获取字符串 s 的子串 s[l..r] 的哈希值
ull get_hash(ull h[], int l, int r) {
return h[r] - h[l - 1] * p_pow[r - l + 1];
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n, k;
cin >> n >> k;
string s;
cin >> s;
string s_rev = s;
reverse(s_rev.begin(), s_rev.end());
p_pow[0] = 1;
for (int i = 1; i <= n; ++i) {
p_pow[i] = p_pow[i - 1] * P;
h_fwd[i] = h_fwd[i - 1] * P + (s[i - 1] - 'a' + 1);
h_rev[i] = h_rev[i - 1] * P + (s_rev[i - 1] - 'a' + 1);
}
int count = 0;
for (int i = 1; i <= n - k + 1; ++i) {
// 子串 s[i..i+k-1]
int l1 = i, r1 = i + k - 1;
// 对应的反转串中的子串
int l2 = n - (i + k - 1) + 1, r2 = n - i + 1;
if (get_hash(h_fwd, l1, r1) == get_hash(h_rev, l2, r2)) {
count++;
}
}
cout << count << endl;
return 0;
}
import java.util.Scanner;
public class Main {
private static final long P1 = 31, P2 = 37;
private static final long M1 = 1_000_000_007, M2 = 1_000_000_009;
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int k = sc.nextInt();
String s = sc.next();
String sRev = new StringBuilder(s).reverse().toString();
long[] p1Pow = new long[n + 1];
long[] p2Pow = new long[n + 1];
p1Pow[0] = 1;
p2Pow[0] = 1;
long[][] hFwd = new long[n + 1][2];
long[][] hRev = new long[n + 1][2];
for (int i = 1; i <= n; i++) {
p1Pow[i] = (p1Pow[i - 1] * P1) % M1;
p2Pow[i] = (p2Pow[i - 1] * P2) % M2;
hFwd[i][0] = (hFwd[i - 1][0] * P1 + (s.charAt(i - 1) - 'a' + 1)) % M1;
hFwd[i][1] = (hFwd[i - 1][1] * P2 + (s.charAt(i - 1) - 'a' + 1)) % M2;
hRev[i][0] = (hRev[i - 1][0] * P1 + (sRev.charAt(i - 1) - 'a' + 1)) % M1;
hRev[i][1] = (hRev[i - 1][1] * P2 + (sRev.charAt(i - 1) - 'a' + 1)) % M2;
}
int count = 0;
for (int i = 0; i <= n - k; i++) {
long[] hashFwd = getHash(hFwd, p1Pow, p2Pow, i, i + k - 1, k);
int revI = n - (i + k - 1) - 1;
long[] hashRev = getHash(hRev, p1Pow, p2Pow, revI, revI + k - 1, k);
if (hashFwd[0] == hashRev[0] && hashFwd[1] == hashRev[1]) {
count++;
}
}
System.out.println(count);
}
private static long[] getHash(long[][] h, long[] p1Pow, long[] p2Pow, int l, int r, int k) {
long hash1 = (h[r + 1][0] - (h[l][0] * p1Pow[k]) % M1 + M1) % M1;
long hash2 = (h[r + 1][1] - (h[l][1] * p2Pow[k]) % M2 + M2) % M2;
return new long[]{hash1, hash2};
}
}
import sys
def solve():
n, k = map(int, sys.stdin.readline().split())
s = sys.stdin.readline().strip()
s_rev = s[::-1]
P1, M1 = 31, 10**9 + 7
P2, M2 = 37, 10**9 + 9
p1_pow = [1] * (n + 1)
p2_pow = [1] * (n + 1)
for i in range(1, n + 1):
p1_pow[i] = (p1_pow[i - 1] * P1) % M1
p2_pow[i] = (p2_pow[i - 1] * P2) % M2
h_fwd = [(0, 0)] * (n + 1)
h_rev = [(0, 0)] * (n + 1)
for i in range(n):
val = ord(s[i]) - ord('a') + 1
h_fwd[i+1] = (
(h_fwd[i][0] * P1 + val) % M1,
(h_fwd[i][1] * P2 + val) % M2
)
val_rev = ord(s_rev[i]) - ord('a') + 1
h_rev[i+1] = (
(h_rev[i][0] * P1 + val_rev) % M1,
(h_rev[i][1] * P2 + val_rev) % M2
)
def get_hash(h, l, r):
len_sub = r - l + 1
h1 = (h[r+1][0] - (h[l][0] * p1_pow[len_sub]) % M1 + M1) % M1
h2 = (h[r+1][1] - (h[l][1] * p2_pow[len_sub]) % M2 + M2) % M2
return h1, h2
count = 0
for i in range(n - k + 1):
hash_fwd = get_hash(h_fwd, i, i + k - 1)
rev_i = n - (i + k - 1) - 1
hash_rev = get_hash(h_rev, rev_i, rev_i + k - 1)
if hash_fwd == hash_rev:
count += 1
print(count)
solve()
算法及复杂度
-
算法:字符串双哈希(滚动哈希)
-
时间复杂度:
。预计算哈希数组和幂的数组需要
,之后遍历所有子串并进行
的哈希比较,总共需要
。
-
空间复杂度:
,用于存储预计算的哈希值和幂。