PEEK139 固定长度的回文子串计数
题目链接
题目描述
给定一个由小写字母组成的字符串 S
,长度为 N
,以及一个正整数 L
。请计算 S
中有多少个长度为 L
的子串是回文串。
解题思路
本题要求我们统计一个给定字符串 S
中,所有长度为 L
的回文子串的数量。
一个朴素的解法是遍历所有长度为 L
的子串。字符串 S
中共有 N - L + 1
个这样的子串。对于每个子串,我们再花费 的时间来检查它是否为回文串。这种方法的总时间复杂度为
,在
N
很大的情况下,效率较低,可能会超时。
为了优化回文判断的效率,我们可以采用字符串哈希。其核心思想是:一个字符串是回文串,当且仅当它的正序字符串与其反序字符串完全相同。利用哈希函数,我们可以将这个比较过程从 优化到
。
具体步骤如下:
-
预处理哈希值:
- 为了能在
时间内获取任意子串的哈希值,我们需要预处理原字符串
S
的正向哈希数组(h
)和反向哈希数组(rh
)。 - 正向哈希数组
h[i]
存储S
的前缀S[0...i-1]
的哈希值。 - 反向哈希数组
rh[i]
存储S
的后缀S[i...N-1]
的哈希值。这等价于S
的反转字符串的前缀哈希。 - 同时,我们需要预计算哈希基数(
BASE
)的幂,以便后续计算。
- 为了能在
-
O(1) 回文判断:
- 对于任意一个从索引
i
开始,长度为L
的子串S[i...i+L-1]
(0-indexed),我们可以计算出它的正向哈希值。
- 我们也可以
计算出这个子串的反向哈希值。
- 如果这两个哈希值相等,我们就认为这个子串是回文串。
- 对于任意一个从索引
-
遍历与计数:
- 我们遍历所有可能的子串起始位置
i
,从0
到N - L
。 - 在每次迭代中,我们执行上述的
回文判断。
- 如果判断为真,则将计数器加一。
- 我们遍历所有可能的子串起始位置
-
双哈希:
- 为了最大限度地避免哈希碰撞(即两个不同的字符串产生相同的哈希值),我们可以采用双哈希技术。这意味着我们选择两个不同的质数作为哈希基数(
BASE1
和BASE2
),并计算两套独立的哈希值。只有当一个子串的两套正向、反向哈希值都分别相等时,我们才认定它是回文串。
- 为了最大限度地避免哈希碰撞(即两个不同的字符串产生相同的哈希值),我们可以采用双哈希技术。这意味着我们选择两个不同的质数作为哈希基数(
通过这种方法,我们将总时间复杂度从 优化到了
(预处理)+
(遍历计数),即
,可以高效地解决问题。
代码
#include <iostream>
#include <string>
#include <vector>
#include <algorithm>
using namespace std;
typedef unsigned long long ull;
struct StringHasher {
ull BASE1 = 131;
ull BASE2 = 13331;
vector<ull> h1, p1, rh1;
vector<ull> h2, p2, rh2;
int n;
StringHasher(const string& s) {
n = s.length();
p1.resize(n + 1, 1);
h1.resize(n + 1, 0);
rh1.resize(n + 2, 0);
p2.resize(n + 1, 1);
h2.resize(n + 1, 0);
rh2.resize(n + 2, 0);
for (int i = 1; i <= n; ++i) {
p1[i] = p1[i - 1] * BASE1;
h1[i] = h1[i - 1] * BASE1 + s[i - 1];
p2[i] = p2[i - 1] * BASE2;
h2[i] = h2[i - 1] * BASE2 + s[i - 1];
}
for (int i = n; i >= 1; --i) {
rh1[i] = rh1[i + 1] * BASE1 + s[i - 1];
rh2[i] = rh2[i + 1] * BASE2 + s[i - 1];
}
}
// 获取 S[l..r] (1-indexed) 的正向哈希值
pair<ull, ull> get_fwd_hash(int l, int r) {
ull hash1 = h1[r] - h1[l - 1] * p1[r - l + 1];
ull hash2 = h2[r] - h2[l - 1] * p2[r - l + 1];
return {hash1, hash2};
}
// 获取 S[l..r] (1-indexed) 的反向哈希值
pair<ull, ull> get_rev_hash(int l, int r) {
ull hash1 = rh1[l] - rh1[r + 1] * p1[r - l + 1];
ull hash2 = rh2[l] - rh2[r + 1] * p2[r - l + 1];
return {hash1, hash2};
}
};
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n, l;
cin >> n >> l;
string s;
cin >> s;
if (l > n) {
cout << 0 << '\n';
return 0;
}
StringHasher hasher(s);
int count = 0;
for (int i = 1; i <= n - l + 1; ++i) {
if (hasher.get_fwd_hash(i, i + l - 1) == hasher.get_rev_hash(i, i + l - 1)) {
count++;
}
}
cout << count << '\n';
return 0;
}
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.IOException;
public class Main {
static class StringHasher {
long BASE1 = 131;
long BASE2 = 13331;
long[] h1, p1, rh1;
long[] h2, p2, rh2;
int n;
public StringHasher(String s) {
n = s.length();
p1 = new long[n + 1];
h1 = new long[n + 1];
rh1 = new long[n + 2];
p2 = new long[n + 1];
h2 = new long[n + 1];
rh2 = new long[n + 2];
p1[0] = 1;
p2[0] = 1;
for (int i = 1; i <= n; ++i) {
p1[i] = p1[i - 1] * BASE1;
h1[i] = h1[i - 1] * BASE1 + s.charAt(i - 1);
p2[i] = p2[i - 1] * BASE2;
h2[i] = h2[i - 1] * BASE2 + s.charAt(i - 1);
}
for (int i = n; i >= 1; --i) {
rh1[i] = rh1[i + 1] * BASE1 + s.charAt(i - 1);
rh2[i] = rh2[i + 1] * BASE2 + s.charAt(i - 1);
}
}
public long getFwdHash1(int l, int r) {
return h1[r] - h1[l - 1] * p1[r - l + 1];
}
public long getFwdHash2(int l, int r) {
return h2[r] - h2[l - 1] * p2[r - l + 1];
}
public long getRevHash1(int l, int r) {
return rh1[l] - rh1[r + 1] * p1[r - l + 1];
}
public long getRevHash2(int l, int r) {
return rh2[l] - rh2[r + 1] * p2[r - l + 1];
}
public boolean isPalindrome(int l, int r) {
return getFwdHash1(l, r) == getRevHash1(l, r) && getFwdHash2(l, r) == getRevHash2(l, r);
}
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
String[] parts = br.readLine().split(" ");
int n = Integer.parseInt(parts[0]);
int l = Integer.parseInt(parts[1]);
String s = br.readLine();
if (l > n) {
System.out.println(0);
return;
}
StringHasher hasher = new StringHasher(s);
int count = 0;
for (int i = 1; i <= n - l + 1; ++i) {
if (hasher.isPalindrome(i, i + l - 1)) {
count++;
}
}
System.out.println(count);
}
}
import sys
class StringHasher:
def __init__(self, s):
self.n = len(s)
self.BASE1, self.MOD1 = 131, 10**9 + 7
self.BASE2, self.MOD2 = 13331, 10**9 + 9
self.p1 = [1] * (self.n + 1)
self.h1 = [0] * (self.n + 1)
self.rh1 = [0] * (self.n + 2)
self.p2 = [1] * (self.n + 1)
self.h2 = [0] * (self.n + 1)
self.rh2 = [0] * (self.n + 2)
for i in range(1, self.n + 1):
self.p1[i] = (self.p1[i - 1] * self.BASE1) % self.MOD1
self.h1[i] = (self.h1[i - 1] * self.BASE1 + ord(s[i - 1])) % self.MOD1
self.p2[i] = (self.p2[i - 1] * self.BASE2) % self.MOD2
self.h2[i] = (self.h2[i - 1] * self.BASE2 + ord(s[i - 1])) % self.MOD2
for i in range(self.n, 0, -1):
self.rh1[i] = (self.rh1[i + 1] * self.BASE1 + ord(s[i - 1])) % self.MOD1
self.rh2[i] = (self.rh2[i + 1] * self.BASE2 + ord(s[i - 1])) % self.MOD2
def get_fwd_hash(self, l, r):
len_ = r - l + 1
hash1 = (self.h1[r] - self.h1[l - 1] * self.p1[len_]) % self.MOD1
hash2 = (self.h2[r] - self.h2[l - 1] * self.p2[len_]) % self.MOD2
return (hash1 + self.MOD1) % self.MOD1, (hash2 + self.MOD2) % self.MOD2
def get_rev_hash(self, l, r):
len_ = r - l + 1
hash1 = (self.rh1[l] - self.rh1[r + 1] * self.p1[len_]) % self.MOD1
hash2 = (self.rh2[l] - self.rh2[r + 1] * self.p2[len_]) % self.MOD2
return (hash1 + self.MOD1) % self.MOD1, (hash2 + self.MOD2) % self.MOD2
def main():
try:
n, l = map(int, sys.stdin.readline().split())
s = sys.stdin.readline().strip()
except (IOError, ValueError):
return
if l > n:
print(0)
return
hasher = StringHasher(s)
count = 0
for i in range(1, n - l + 2):
if hasher.get_fwd_hash(i, i + l - 1) == hasher.get_rev_hash(i, i + l - 1):
count += 1
print(count)
if __name__ == "__main__":
main()
算法及复杂度
- 算法:字符串哈希(双哈希)。
- 时间复杂度:预处理哈希数组需要
的时间。之后,遍历所有
N - L + 1
个可能的子串,对每个子串的判断是。因此,总时间复杂度为
。
- 空间复杂度:需要
的空间来存储哈希数组和基数的幂。