题目链接
题目描述
给定一个字符串 和
组询问。每组询问给出两个整数
和
,要求计算
的第
个前缀
和第
个前缀
的最长公共 border 的长度。
解题思路
本题是 KMP 算法与树上最近公共祖先(LCA)算法的结合。
1. 从 Border 到失配树
一个字符串 的 border 是指既是
的真前缀又是其真后缀的子串。KMP 算法的前缀函数
计算的正是前缀
的最长 border 的长度。
一个重要的性质是:一个前缀 的所有 border 的长度,可以通过
函数的“失配链”得到。即
(这里假设
数组是 1-indexed)。
这个链式关系启发我们建立一棵失配树(Failure Tree 或 Border Tree):
- 建立
个节点,分别代表长度为
的前缀。节点
代表空串,作为树的根。
- 对于前缀
(长度为
,对应节点
),其最长 border 是
(长度为
)。这构成了一种父子关系。
- 因此,我们从节点
向节点
连接一条边。这样,对于所有
,我们就构建了一棵以
为根的树。
2. 从公共 Border 到 LCA
在构建的失配树上:
- 前缀
的所有 border 长度集合,等价于节点
的所有祖先节点(包括
自身)的集合。
- 同理,前缀
的所有 border 长度集合,等价于节点
的所有祖先节点的集合。
因此,“最长公共 border 的长度” 就转化为了 “节点 和
的所有公共祖先中,深度最大的那个节点的编号”。这正是最近公共祖先 (Lowest Common Ancestor, LCA) 的定义。
3. 算法实现:倍增法求 LCA
问题最终变为在失配树上对节点 进行 LCA 查询。
- 计算
数组并同步建树: 在
时间内计算1-indexed的
数组。在计算
的同时,我们知道了节点
的父节点是
,因此可以立即计算出节点
的深度
。这避免了额外的 DFS/BFS 遍历。
- LCA 预处理:
- 将
作为
up[i][0]
(节点的直接父节点)。
- 使用动态规划填充倍增数组
up[i][j]
(节点的第
个祖先),递推式为
up[i][j] = up[up[i][j-1]][j-1]
。时间。
- 将
- 回答查询:对于每组查询
,使用倍增法求
lca(up[p][0], up[q][0])
。时间。
代码
#include <iostream>
#include <vector>
#include <string>
#include <cmath>
#include <algorithm>
using namespace std;
const int MAXN = 1000005;
const int LOGN = 21; // ceil(log2(1000005))
int up[MAXN][LOGN];
int depth[MAXN];
int lca(int u, int v) {
if (depth[u] < depth[v]) swap(u, v);
for (int j = LOGN - 1; j >= 0; j--) {
if (depth[u] - (1 << j) >= depth[v]) {
u = up[u][j];
}
}
if (u == v) return u;
for (int j = LOGN - 1; j >= 0; j--) {
if (up[u][j] != up[v][j]) {
u = up[u][j];
v = up[v][j];
}
}
return up[u][0];
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
string s_in;
cin >> s_in;
int n = s_in.length();
string s = " " + s_in; // Convert to 1-indexed
// up[i][0] is the parent of i, which is pi[i]
// depth[0] is 0
for (int i = 2, j = 0; i <= n; i++) {
while (j > 0 && s[i] != s[j + 1]) {
j = up[j][0];
}
if (s[i] == s[j + 1]) {
j++;
}
up[i][0] = j;
depth[i] = depth[j] + 1;
}
for (int j = 1; j < LOGN; j++) {
for (int i = 1; i <= n; i++) {
up[i][j] = up[up[i][j - 1]][j - 1];
}
}
int q;
cin >> q;
while (q--) {
int p, q_val;
cin >> p >> q_val;
cout << lca(up[p][0], up[q_val][0]) << '\n';
}
return 0;
}
import java.io.*;
import java.util.*;
public class Main {
static final int LOGN = 21;
static int[][] up;
static int[] depth;
static int lca(int u, int v) {
if (depth[u] < depth[v]) {
int temp = u; u = v; v = temp;
}
for (int j = LOGN - 1; j >= 0; j--) {
if (depth[u] - (1 << j) >= depth[v]) {
u = up[u][j];
}
}
if (u == v) return u;
for (int j = LOGN - 1; j >= 0; j--) {
if (up[u][j] != up[v][j]) {
u = up[u][j];
v = up[v][j];
}
}
return up[u][0];
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
String s_in = br.readLine();
int n = s_in.length();
String s = " " + s_in;
up = new int[n + 1][LOGN];
depth = new int[n + 1];
for (int i = 2, j = 0; i <= n; i++) {
while (j > 0 && s.charAt(i) != s.charAt(j + 1)) {
j = up[j][0];
}
if (s.charAt(i) == s.charAt(j + 1)) {
j++;
}
up[i][0] = j;
depth[i] = depth[j] + 1;
}
for (int j = 1; j < LOGN; j++) {
for (int i = 1; i <= n; i++) {
up[i][j] = up[up[i][j - 1]][j - 1];
}
}
int q = Integer.parseInt(br.readLine());
PrintWriter out = new PrintWriter(System.out);
while (q-- > 0) {
StringTokenizer st = new StringTokenizer(br.readLine());
int p = Integer.parseInt(st.nextToken());
int q_val = Integer.parseInt(st.nextToken());
out.println(lca(up[p][0], up[q_val][0]));
}
out.flush();
}
}
import sys
def solve():
s_in = sys.stdin.readline().strip()
n = len(s_in)
s = " " + s_in
LOGN = (n + 1).bit_length()
up = [[0] * LOGN for _ in range(n + 1)]
depth = [0] * (n + 1)
j = 0
for i in range(2, n + 1):
while j > 0 and s[i] != s[j + 1]:
j = up[j][0]
if s[i] == s[j + 1]:
j += 1
up[i][0] = j
depth[i] = depth[j] + 1
for j in range(1, LOGN):
for i in range(1, n + 1):
up[i][j] = up[up[i][j - 1]][j - 1]
def lca(u, v):
if depth[u] < depth[v]:
u, v = v, u
for j in range(LOGN - 1, -1, -1):
if depth[u] - (1 << j) >= depth[v]:
u = up[u][j]
if u == v:
return u
for j in range(LOGN - 1, -1, -1):
if up[u][j] != up[v][j]:
u = up[u][j]
v = up[v][j]
return up[u][0]
q_str = sys.stdin.readline()
if not q_str: return
q = int(q_str)
lines = sys.stdin.readlines()
for line in lines:
p, q_val = map(int, line.split())
sys.stdout.write(str(lca(up[p][0], up[q_val][0])) + '\n')
solve()
算法及复杂度
- 算法:KMP前缀函数 + 失配树 + 倍增LCA
- 时间复杂度:
- 计算
数组并同步计算深度为
。
- LCA 预处理(倍增表)为
。
次查询,每次查询为
。
- 计算
- 空间复杂度:
- 主要开销在于存储大小为
的倍增
up
数组。
- 主要开销在于存储大小为