题目链接

不是烤串故事

题目描述

给定两个长度为 的字符串 。对于每一个 ,我们通过翻转 的前 个字符得到一个新的字符串 。任务是找到所有 中,使得 的最长公共前缀 (LCP) 最长的那个,并输出这个最长的 LCP 长度以及达到该长度的最小的

解题思路

这是一个可以通过字符串哈希 (String Hashing)二分查找 (Binary Search) 高效解决的问题。

对于每个翻转长度 (从 ),我们都会生成一个新的字符串 ,并需要计算它与字符串 的最长公共前缀 (LCP)。

  • 的构成是: 的前 个字符翻转后的结果,拼接上 剩下的 个字符。
  • (使用1-based索引)。

一个一个地生成 再去比较 LCP 会导致 的复杂度,对于 的范围来说太慢了。我们需要一个更快的方法来计算或比较前缀。

字符串哈希提供了一个 的方法(在 预处理后)来比较任意两个子串是否相等。我们可以利用这一点。

对于每个 ,我们想找到最大的 使得 。这个 是单调的(如果长度为 的前缀匹配,那么所有长度小于 的前缀也都匹配),因此我们可以对 进行二分查找

算法步骤

  1. 预处理 ():

    • 为了处理字符串哈希,我们需要选择一个质数 base 和一个大质数 mod
    • 计算字符串 、字符串 、以及 的反转串 reverse(S) 的前缀哈希值数组。同时预计算 base 的各次幂。
    • 这样,我们就可以在 时间内查询任意子串的哈希值。
  2. 主循环 ():

    • 遍历所有可能的翻转长度
    • 对于每个 ,二分查找 的长度 ,范围是
    • check(L) 函数:
      • 该函数判断 的前 个字符是否与 的前 个字符完全相同。
      • 如果 : 我们需要比较 。这可以通过比较 reverse(S) 的一个子串和 的一个前缀的哈希值来完成。
      • 如果 : 我们需要同时满足两个条件:
        1. (第一部分完全匹配)。
        2. (第二部分的延伸匹配)。 这两个条件都可以通过 的哈希比较来验证。
    • 二分查找结束后,我们得到了当前 对应的 LCP 长度。我们用它来更新全局的最大 LCP 和对应的最小
  3. 输出结果:

    • 循环结束后,输出记录下的最大 LCP 和最小

代码

#include <iostream>
#include <string>
#include <vector>
#include <algorithm>

using namespace std;

const long long BASE = 131; // A common prime for hashing
const long long MOD = 1e9 + 7;

struct StringHasher {
    vector<long long> h, p;
    StringHasher(const string& s) {
        int n = s.length();
        h.resize(n + 1, 0);
        p.resize(n + 1, 1);
        for (int i = 0; i < n; i++) {
            p[i + 1] = (p[i] * BASE) % MOD;
            h[i + 1] = (h[i] * BASE + s[i]) % MOD;
        }
    }
    long long get_hash(int l, int r) { // 1-indexed, inclusive [l, r]
        long long res = (h[r] - (h[l - 1] * p[r - l + 1]) % MOD + MOD) % MOD;
        return res;
    }
};

void solve() {
    int n;
    cin >> n;
    string s, t;
    cin >> s >> t;
    string s_rev = s;
    reverse(s_rev.begin(), s_rev.end());

    StringHasher hs(s), ht(t), hs_rev(s_rev);

    int max_lcp = 0;
    int best_k = 1;

    for (int k = 1; k <= n; k++) {
        int low = 0, high = n, current_lcp = 0;
        while (low <= high) {
            int mid = low + (high - low) / 2;
            if (mid == 0) {
                low = mid + 1;
                continue;
            }
            bool ok = false;
            if (mid <= k) {
                // hash of reverse(s[1..mid]) vs hash of t[1..mid]
                // reverse(s[1..mid]) is a substring of s_rev
                long long h_s_part = hs_rev.get_hash(n - k + 1, n - k + mid);
                long long h_t_part = ht.get_hash(1, mid);
                if (h_s_part == h_t_part) ok = true;
            } else { // mid > k
                // Check first k chars
                long long h_s_part1 = hs_rev.get_hash(n - k + 1, n);
                long long h_t_part1 = ht.get_hash(1, k);
                if (h_s_part1 == h_t_part1) {
                    // Check remaining mid-k chars
                    long long h_s_part2 = hs.get_hash(k + 1, mid);
                    long long h_t_part2 = ht.get_hash(k + 1, mid);
                    if (h_s_part2 == h_t_part2) ok = true;
                }
            }
            if (ok) {
                current_lcp = mid;
                low = mid + 1;
            } else {
                high = mid - 1;
            }
        }

        if (current_lcp > max_lcp) {
            max_lcp = current_lcp;
            best_k = k;
        }
    }
    cout << max_lcp << " " << best_k << endl;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    int t;
    cin >> t;
    while (t--) {
        solve();
    }
    return 0;
}
import java.io.*;
import java.util.*;

public class Main {
    static class StringHasher {
        long[] h, p;
        static final long BASE = 131;
        static final long MOD = (long)1e9 + 7;

        StringHasher(String s) {
            int n = s.length();
            h = new long[n + 1];
            p = new long[n + 1];
            p[0] = 1;
            for (int i = 0; i < n; i++) {
                p[i + 1] = (p[i] * BASE) % MOD;
                h[i + 1] = (h[i] * BASE + s.charAt(i)) % MOD;
            }
        }

        long getHash(int l, int r) { // 1-indexed, inclusive [l, r]
            long res = (h[r] - (h[l - 1] * p[r - l + 1]) % MOD + MOD) % MOD;
            return res;
        }
    }

    static void solve(BufferedReader br, PrintWriter out) throws IOException {
        int n = Integer.parseInt(br.readLine());
        String s = br.readLine();
        String t = br.readLine();
        String sRev = new StringBuilder(s).reverse().toString();

        StringHasher hs = new StringHasher(s);
        StringHasher ht = new StringHasher(t);
        StringHasher hsRev = new StringHasher(sRev);

        int maxLcp = 0;
        int bestK = 1;

        for (int k = 1; k <= n; k++) {
            int low = 0, high = n, currentLcp = 0;
            while (low <= high) {
                int mid = low + (high - low) / 2;
                if (mid == 0) {
                    low = mid + 1;
                    continue;
                }
                boolean ok = false;
                if (mid <= k) {
                    long hSPart = hsRev.getHash(n - k + 1, n - k + mid);
                    long hTPart = ht.getHash(1, mid);
                    if (hSPart == hTPart) ok = true;
                } else {
                    long hSPart1 = hsRev.getHash(n - k + 1, n);
                    long hTPart1 = ht.getHash(1, k);
                    if (hSPart1 == hTPart1) {
                        long hSPart2 = hs.getHash(k + 1, mid);
                        long hTPart2 = ht.getHash(k + 1, mid);
                        if (hSPart2 == hTPart2) ok = true;
                    }
                }
                if (ok) {
                    currentLcp = mid;
                    low = mid + 1;
                } else {
                    high = mid - 1;
                }
            }
            if (currentLcp > maxLcp) {
                maxLcp = currentLcp;
                bestK = k;
            }
        }
        out.println(maxLcp + " " + bestK);
    }

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        PrintWriter out = new PrintWriter(System.out);
        int t = Integer.parseInt(br.readLine());
        while (t-- > 0) {
            solve(br, out);
        }
        out.flush();
    }
}
import sys

BASE = 131
MOD = 10**9 + 7

class StringHasher:
    def __init__(self, s):
        n = len(s)
        self.h = [0] * (n + 1)
        self.p = [1] * (n + 1)
        for i in range(n):
            self.p[i + 1] = (self.p[i] * BASE) % MOD
            self.h[i + 1] = (self.h[i] * BASE + ord(s[i])) % MOD
    
    def get_hash(self, l, r): # 1-indexed, inclusive [l, r]
        if l > r: return 0
        res = (self.h[r] - (self.h[l - 1] * self.p[r - l + 1]) % MOD + MOD) % MOD
        return res

def solve():
    try:
        n_str = sys.stdin.readline()
        if not n_str: return
        n = int(n_str)
        s = sys.stdin.readline().strip()
        t = sys.stdin.readline().strip()
    except (IOError, ValueError):
        return
        
    s_rev = s[::-1]

    hs = StringHasher(s)
    ht = StringHasher(t)
    hs_rev = StringHasher(s_rev)

    max_lcp = 0
    best_k = 1

    for k in range(1, n + 1):
        low, high, current_lcp = 0, n, 0
        while low <= high:
            mid = (low + high) // 2
            if mid == 0:
                low = mid + 1
                continue
            
            ok = False
            if mid <= k:
                h_s_part = hs_rev.get_hash(n - k + 1, n - k + mid)
                h_t_part = ht.get_hash(1, mid)
                if h_s_part == h_t_part:
                    ok = True
            else: # mid > k
                h_s_part1 = hs_rev.get_hash(n - k + 1, n)
                h_t_part1 = ht.get_hash(1, k)
                if h_s_part1 == h_t_part1:
                    h_s_part2 = hs.get_hash(k + 1, mid)
                    h_t_part2 = ht.get_hash(k + 1, mid)
                    if h_s_part2 == h_t_part2:
                        ok = True
            
            if ok:
                current_lcp = mid
                low = mid + 1
            else:
                high = mid - 1

        if current_lcp > max_lcp:
            max_lcp = current_lcp
            best_k = k
            
    sys.stdout.write(f"{max_lcp} {best_k}\n")


def main():
    try:
        t_str = sys.stdin.readline()
        if not t_str: return
        t = int(t_str)
        for _ in range(t):
            solve()
    except (IOError, ValueError):
        return

main()

算法及复杂度

  • 算法:字符串哈希 + 二分查找
  • 时间复杂度:
    • 哈希预处理需要
    • 外层循环遍历
    • 内层对 LCP 长度进行二分查找,需要 次。
    • check 函数内部的哈希比较是 的。
    • 总时间复杂度为
  • 空间复杂度:
    • 需要存储 , , reverse(S) 的前缀哈希数组和 base 的幂,均为