题目链接
题目描述
给定两个长度为 的字符串
和
。对于每一个
,我们通过翻转
的前
个字符得到一个新的字符串
。任务是找到所有
中,使得
和
的最长公共前缀 (LCP) 最长的那个,并输出这个最长的 LCP 长度以及达到该长度的最小的
。
解题思路
这是一个可以通过字符串哈希 (String Hashing) 和二分查找 (Binary Search) 高效解决的问题。
对于每个翻转长度 (从
到
),我们都会生成一个新的字符串
,并需要计算它与字符串
的最长公共前缀 (LCP)。
的构成是:
的前
个字符翻转后的结果,拼接上
剩下的
个字符。
- 即
(使用1-based索引)。
一个一个地生成 再去比较 LCP 会导致
的复杂度,对于
的范围来说太慢了。我们需要一个更快的方法来计算或比较前缀。
字符串哈希提供了一个 的方法(在
预处理后)来比较任意两个子串是否相等。我们可以利用这一点。
对于每个 ,我们想找到最大的
使得
。这个
是单调的(如果长度为
的前缀匹配,那么所有长度小于
的前缀也都匹配),因此我们可以对
进行二分查找。
算法步骤
-
预处理 (
):
- 为了处理字符串哈希,我们需要选择一个质数
base
和一个大质数mod
。 - 计算字符串
、字符串
、以及
的反转串
reverse(S)
的前缀哈希值数组。同时预计算base
的各次幂。 - 这样,我们就可以在
时间内查询任意子串的哈希值。
- 为了处理字符串哈希,我们需要选择一个质数
-
主循环 (
):
- 遍历所有可能的翻转长度
从
到
。
- 对于每个
,二分查找
的长度
,范围是
。
check(L)
函数:- 该函数判断
的前
个字符是否与
的前
个字符完全相同。
- 如果
: 我们需要比较
和
。这可以通过比较
reverse(S)
的一个子串和的一个前缀的哈希值来完成。
- 如果
: 我们需要同时满足两个条件:
(第一部分完全匹配)。
(第二部分的延伸匹配)。 这两个条件都可以通过
的哈希比较来验证。
- 该函数判断
- 二分查找结束后,我们得到了当前
对应的 LCP 长度。我们用它来更新全局的最大 LCP 和对应的最小
。
- 遍历所有可能的翻转长度
-
输出结果:
- 循环结束后,输出记录下的最大 LCP 和最小
。
- 循环结束后,输出记录下的最大 LCP 和最小
代码
#include <iostream>
#include <string>
#include <vector>
#include <algorithm>
using namespace std;
const long long BASE = 131; // A common prime for hashing
const long long MOD = 1e9 + 7;
struct StringHasher {
vector<long long> h, p;
StringHasher(const string& s) {
int n = s.length();
h.resize(n + 1, 0);
p.resize(n + 1, 1);
for (int i = 0; i < n; i++) {
p[i + 1] = (p[i] * BASE) % MOD;
h[i + 1] = (h[i] * BASE + s[i]) % MOD;
}
}
long long get_hash(int l, int r) { // 1-indexed, inclusive [l, r]
long long res = (h[r] - (h[l - 1] * p[r - l + 1]) % MOD + MOD) % MOD;
return res;
}
};
void solve() {
int n;
cin >> n;
string s, t;
cin >> s >> t;
string s_rev = s;
reverse(s_rev.begin(), s_rev.end());
StringHasher hs(s), ht(t), hs_rev(s_rev);
int max_lcp = 0;
int best_k = 1;
for (int k = 1; k <= n; k++) {
int low = 0, high = n, current_lcp = 0;
while (low <= high) {
int mid = low + (high - low) / 2;
if (mid == 0) {
low = mid + 1;
continue;
}
bool ok = false;
if (mid <= k) {
// hash of reverse(s[1..mid]) vs hash of t[1..mid]
// reverse(s[1..mid]) is a substring of s_rev
long long h_s_part = hs_rev.get_hash(n - k + 1, n - k + mid);
long long h_t_part = ht.get_hash(1, mid);
if (h_s_part == h_t_part) ok = true;
} else { // mid > k
// Check first k chars
long long h_s_part1 = hs_rev.get_hash(n - k + 1, n);
long long h_t_part1 = ht.get_hash(1, k);
if (h_s_part1 == h_t_part1) {
// Check remaining mid-k chars
long long h_s_part2 = hs.get_hash(k + 1, mid);
long long h_t_part2 = ht.get_hash(k + 1, mid);
if (h_s_part2 == h_t_part2) ok = true;
}
}
if (ok) {
current_lcp = mid;
low = mid + 1;
} else {
high = mid - 1;
}
}
if (current_lcp > max_lcp) {
max_lcp = current_lcp;
best_k = k;
}
}
cout << max_lcp << " " << best_k << endl;
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int t;
cin >> t;
while (t--) {
solve();
}
return 0;
}
import java.io.*;
import java.util.*;
public class Main {
static class StringHasher {
long[] h, p;
static final long BASE = 131;
static final long MOD = (long)1e9 + 7;
StringHasher(String s) {
int n = s.length();
h = new long[n + 1];
p = new long[n + 1];
p[0] = 1;
for (int i = 0; i < n; i++) {
p[i + 1] = (p[i] * BASE) % MOD;
h[i + 1] = (h[i] * BASE + s.charAt(i)) % MOD;
}
}
long getHash(int l, int r) { // 1-indexed, inclusive [l, r]
long res = (h[r] - (h[l - 1] * p[r - l + 1]) % MOD + MOD) % MOD;
return res;
}
}
static void solve(BufferedReader br, PrintWriter out) throws IOException {
int n = Integer.parseInt(br.readLine());
String s = br.readLine();
String t = br.readLine();
String sRev = new StringBuilder(s).reverse().toString();
StringHasher hs = new StringHasher(s);
StringHasher ht = new StringHasher(t);
StringHasher hsRev = new StringHasher(sRev);
int maxLcp = 0;
int bestK = 1;
for (int k = 1; k <= n; k++) {
int low = 0, high = n, currentLcp = 0;
while (low <= high) {
int mid = low + (high - low) / 2;
if (mid == 0) {
low = mid + 1;
continue;
}
boolean ok = false;
if (mid <= k) {
long hSPart = hsRev.getHash(n - k + 1, n - k + mid);
long hTPart = ht.getHash(1, mid);
if (hSPart == hTPart) ok = true;
} else {
long hSPart1 = hsRev.getHash(n - k + 1, n);
long hTPart1 = ht.getHash(1, k);
if (hSPart1 == hTPart1) {
long hSPart2 = hs.getHash(k + 1, mid);
long hTPart2 = ht.getHash(k + 1, mid);
if (hSPart2 == hTPart2) ok = true;
}
}
if (ok) {
currentLcp = mid;
low = mid + 1;
} else {
high = mid - 1;
}
}
if (currentLcp > maxLcp) {
maxLcp = currentLcp;
bestK = k;
}
}
out.println(maxLcp + " " + bestK);
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
PrintWriter out = new PrintWriter(System.out);
int t = Integer.parseInt(br.readLine());
while (t-- > 0) {
solve(br, out);
}
out.flush();
}
}
import sys
BASE = 131
MOD = 10**9 + 7
class StringHasher:
def __init__(self, s):
n = len(s)
self.h = [0] * (n + 1)
self.p = [1] * (n + 1)
for i in range(n):
self.p[i + 1] = (self.p[i] * BASE) % MOD
self.h[i + 1] = (self.h[i] * BASE + ord(s[i])) % MOD
def get_hash(self, l, r): # 1-indexed, inclusive [l, r]
if l > r: return 0
res = (self.h[r] - (self.h[l - 1] * self.p[r - l + 1]) % MOD + MOD) % MOD
return res
def solve():
try:
n_str = sys.stdin.readline()
if not n_str: return
n = int(n_str)
s = sys.stdin.readline().strip()
t = sys.stdin.readline().strip()
except (IOError, ValueError):
return
s_rev = s[::-1]
hs = StringHasher(s)
ht = StringHasher(t)
hs_rev = StringHasher(s_rev)
max_lcp = 0
best_k = 1
for k in range(1, n + 1):
low, high, current_lcp = 0, n, 0
while low <= high:
mid = (low + high) // 2
if mid == 0:
low = mid + 1
continue
ok = False
if mid <= k:
h_s_part = hs_rev.get_hash(n - k + 1, n - k + mid)
h_t_part = ht.get_hash(1, mid)
if h_s_part == h_t_part:
ok = True
else: # mid > k
h_s_part1 = hs_rev.get_hash(n - k + 1, n)
h_t_part1 = ht.get_hash(1, k)
if h_s_part1 == h_t_part1:
h_s_part2 = hs.get_hash(k + 1, mid)
h_t_part2 = ht.get_hash(k + 1, mid)
if h_s_part2 == h_t_part2:
ok = True
if ok:
current_lcp = mid
low = mid + 1
else:
high = mid - 1
if current_lcp > max_lcp:
max_lcp = current_lcp
best_k = k
sys.stdout.write(f"{max_lcp} {best_k}\n")
def main():
try:
t_str = sys.stdin.readline()
if not t_str: return
t = int(t_str)
for _ in range(t):
solve()
except (IOError, ValueError):
return
main()
算法及复杂度
- 算法:字符串哈希 + 二分查找
- 时间复杂度:
- 哈希预处理需要
。
- 外层循环遍历
从
到
。
- 内层对 LCP 长度进行二分查找,需要
次。
check
函数内部的哈希比较是的。
- 总时间复杂度为
。
- 哈希预处理需要
- 空间复杂度:
- 需要存储
,
,
reverse(S)
的前缀哈希数组和base
的幂,均为。
- 需要存储