PEEK139 固定长度的回文子串计数

题目链接

PEEK139 固定长度的回文子串计数

题目描述

给定一个由小写字母组成的字符串 S,长度为 N,以及一个正整数 L。请计算 S 中有多少个长度为 L 的子串是回文串。

解题思路

本题要求我们统计一个给定字符串 S 中,所有长度为 L 的回文子串的数量。

一个朴素的解法是遍历所有长度为 L 的子串。字符串 S 中共有 N - L + 1 个这样的子串。对于每个子串,我们再花费 的时间来检查它是否为回文串。这种方法的总时间复杂度为 ,在 N 很大的情况下,效率较低,可能会超时。

为了优化回文判断的效率,我们可以采用字符串哈希。其核心思想是:一个字符串是回文串,当且仅当它的正序字符串与其反序字符串完全相同。利用哈希函数,我们可以将这个比较过程从 优化到

具体步骤如下:

  1. 预处理哈希值

    • 为了能在 时间内获取任意子串的哈希值,我们需要预处理原字符串 S正向哈希数组(h)和反向哈希数组(rh)。
    • 正向哈希数组 h[i] 存储 S 的前缀 S[0...i-1] 的哈希值。
    • 反向哈希数组 rh[i] 存储 S 的后缀 S[i...N-1] 的哈希值。这等价于 S 的反转字符串的前缀哈希。
    • 同时,我们需要预计算哈希基数(BASE)的幂,以便后续计算。
  2. O(1) 回文判断

    • 对于任意一个从索引 i 开始,长度为 L 的子串 S[i...i+L-1](0-indexed),我们可以 计算出它的正向哈希值。
    • 我们也可以 计算出这个子串的反向哈希值。
    • 如果这两个哈希值相等,我们就认为这个子串是回文串。
  3. 遍历与计数

    • 我们遍历所有可能的子串起始位置 i,从 0N - L
    • 在每次迭代中,我们执行上述的 回文判断。
    • 如果判断为真,则将计数器加一。
  4. 双哈希

    • 为了最大限度地避免哈希碰撞(即两个不同的字符串产生相同的哈希值),我们可以采用双哈希技术。这意味着我们选择两个不同的质数作为哈希基数(BASE1BASE2),并计算两套独立的哈希值。只有当一个子串的两套正向、反向哈希值都分别相等时,我们才认定它是回文串。

通过这种方法,我们将总时间复杂度从 优化到了 (预处理)+ (遍历计数),即 ,可以高效地解决问题。

代码

#include <iostream>
#include <string>
#include <vector>
#include <algorithm>

using namespace std;

typedef unsigned long long ull;

struct StringHasher {
    ull BASE1 = 131;
    ull BASE2 = 13331;
    vector<ull> h1, p1, rh1;
    vector<ull> h2, p2, rh2;
    int n;

    StringHasher(const string& s) {
        n = s.length();
        p1.resize(n + 1, 1);
        h1.resize(n + 1, 0);
        rh1.resize(n + 2, 0);
        p2.resize(n + 1, 1);
        h2.resize(n + 1, 0);
        rh2.resize(n + 2, 0);

        for (int i = 1; i <= n; ++i) {
            p1[i] = p1[i - 1] * BASE1;
            h1[i] = h1[i - 1] * BASE1 + s[i - 1];
            p2[i] = p2[i - 1] * BASE2;
            h2[i] = h2[i - 1] * BASE2 + s[i - 1];
        }
        for (int i = n; i >= 1; --i) {
            rh1[i] = rh1[i + 1] * BASE1 + s[i - 1];
            rh2[i] = rh2[i + 1] * BASE2 + s[i - 1];
        }
    }

    // 获取 S[l..r] (1-indexed) 的正向哈希值
    pair<ull, ull> get_fwd_hash(int l, int r) {
        ull hash1 = h1[r] - h1[l - 1] * p1[r - l + 1];
        ull hash2 = h2[r] - h2[l - 1] * p2[r - l + 1];
        return {hash1, hash2};
    }

    // 获取 S[l..r] (1-indexed) 的反向哈希值
    pair<ull, ull> get_rev_hash(int l, int r) {
        ull hash1 = rh1[l] - rh1[r + 1] * p1[r - l + 1];
        ull hash2 = rh2[l] - rh2[r + 1] * p2[r - l + 1];
        return {hash1, hash2};
    }
};

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n, l;
    cin >> n >> l;
    string s;
    cin >> s;

    if (l > n) {
        cout << 0 << '\n';
        return 0;
    }

    StringHasher hasher(s);
    int count = 0;
    for (int i = 1; i <= n - l + 1; ++i) {
        if (hasher.get_fwd_hash(i, i + l - 1) == hasher.get_rev_hash(i, i + l - 1)) {
            count++;
        }
    }

    cout << count << '\n';

    return 0;
}
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.IOException;

public class Main {
    static class StringHasher {
        long BASE1 = 131;
        long BASE2 = 13331;
        long[] h1, p1, rh1;
        long[] h2, p2, rh2;
        int n;

        public StringHasher(String s) {
            n = s.length();
            p1 = new long[n + 1];
            h1 = new long[n + 1];
            rh1 = new long[n + 2];
            p2 = new long[n + 1];
            h2 = new long[n + 1];
            rh2 = new long[n + 2];

            p1[0] = 1;
            p2[0] = 1;

            for (int i = 1; i <= n; ++i) {
                p1[i] = p1[i - 1] * BASE1;
                h1[i] = h1[i - 1] * BASE1 + s.charAt(i - 1);
                p2[i] = p2[i - 1] * BASE2;
                h2[i] = h2[i - 1] * BASE2 + s.charAt(i - 1);
            }
            for (int i = n; i >= 1; --i) {
                rh1[i] = rh1[i + 1] * BASE1 + s.charAt(i - 1);
                rh2[i] = rh2[i + 1] * BASE2 + s.charAt(i - 1);
            }
        }

        public long getFwdHash1(int l, int r) {
            return h1[r] - h1[l - 1] * p1[r - l + 1];
        }
        public long getFwdHash2(int l, int r) {
            return h2[r] - h2[l - 1] * p2[r - l + 1];
        }

        public long getRevHash1(int l, int r) {
            return rh1[l] - rh1[r + 1] * p1[r - l + 1];
        }
        public long getRevHash2(int l, int r) {
            return rh2[l] - rh2[r + 1] * p2[r - l + 1];
        }
        
        public boolean isPalindrome(int l, int r) {
            return getFwdHash1(l, r) == getRevHash1(l, r) && getFwdHash2(l, r) == getRevHash2(l, r);
        }
    }

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        String[] parts = br.readLine().split(" ");
        int n = Integer.parseInt(parts[0]);
        int l = Integer.parseInt(parts[1]);
        String s = br.readLine();

        if (l > n) {
            System.out.println(0);
            return;
        }

        StringHasher hasher = new StringHasher(s);
        int count = 0;
        for (int i = 1; i <= n - l + 1; ++i) {
            if (hasher.isPalindrome(i, i + l - 1)) {
                count++;
            }
        }
        System.out.println(count);
    }
}
import sys

class StringHasher:
    def __init__(self, s):
        self.n = len(s)
        self.BASE1, self.MOD1 = 131, 10**9 + 7
        self.BASE2, self.MOD2 = 13331, 10**9 + 9

        self.p1 = [1] * (self.n + 1)
        self.h1 = [0] * (self.n + 1)
        self.rh1 = [0] * (self.n + 2)

        self.p2 = [1] * (self.n + 1)
        self.h2 = [0] * (self.n + 1)
        self.rh2 = [0] * (self.n + 2)

        for i in range(1, self.n + 1):
            self.p1[i] = (self.p1[i - 1] * self.BASE1) % self.MOD1
            self.h1[i] = (self.h1[i - 1] * self.BASE1 + ord(s[i - 1])) % self.MOD1
            self.p2[i] = (self.p2[i - 1] * self.BASE2) % self.MOD2
            self.h2[i] = (self.h2[i - 1] * self.BASE2 + ord(s[i - 1])) % self.MOD2

        for i in range(self.n, 0, -1):
            self.rh1[i] = (self.rh1[i + 1] * self.BASE1 + ord(s[i - 1])) % self.MOD1
            self.rh2[i] = (self.rh2[i + 1] * self.BASE2 + ord(s[i - 1])) % self.MOD2

    def get_fwd_hash(self, l, r):
        len_ = r - l + 1
        hash1 = (self.h1[r] - self.h1[l - 1] * self.p1[len_]) % self.MOD1
        hash2 = (self.h2[r] - self.h2[l - 1] * self.p2[len_]) % self.MOD2
        return (hash1 + self.MOD1) % self.MOD1, (hash2 + self.MOD2) % self.MOD2

    def get_rev_hash(self, l, r):
        len_ = r - l + 1
        hash1 = (self.rh1[l] - self.rh1[r + 1] * self.p1[len_]) % self.MOD1
        hash2 = (self.rh2[l] - self.rh2[r + 1] * self.p2[len_]) % self.MOD2
        return (hash1 + self.MOD1) % self.MOD1, (hash2 + self.MOD2) % self.MOD2

def main():
    try:
        n, l = map(int, sys.stdin.readline().split())
        s = sys.stdin.readline().strip()
    except (IOError, ValueError):
        return

    if l > n:
        print(0)
        return

    hasher = StringHasher(s)
    count = 0
    for i in range(1, n - l + 2):
        if hasher.get_fwd_hash(i, i + l - 1) == hasher.get_rev_hash(i, i + l - 1):
            count += 1
    
    print(count)

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:字符串哈希(双哈希)。
  • 时间复杂度:预处理哈希数组需要 的时间。之后,遍历所有 N - L + 1 个可能的子串,对每个子串的判断是 。因此,总时间复杂度为
  • 空间复杂度:需要 的空间来存储哈希数组和基数的幂。