题目链接

【模板】失配树

题目描述

给定一个字符串 组询问。每组询问给出两个整数 ,要求计算 的第 个前缀 和第 个前缀 的最长公共 border 的长度。

解题思路

本题是 KMP 算法与树上最近公共祖先(LCA)算法的结合。

1. 从 Border 到失配树

一个字符串 的 border 是指既是 的真前缀又是其真后缀的子串。KMP 算法的前缀函数 计算的正是前缀 最长 border 的长度。

一个重要的性质是:一个前缀 所有 border 的长度,可以通过 函数的“失配链”得到。即 (这里假设 数组是 1-indexed)。

这个链式关系启发我们建立一棵失配树(Failure Tree 或 Border Tree):

  • 建立 个节点,分别代表长度为 的前缀。节点 代表空串,作为树的根。
  • 对于前缀 (长度为 ,对应节点 ),其最长 border 是 (长度为 )。这构成了一种父子关系。
  • 因此,我们从节点 向节点 连接一条边。这样,对于所有 ,我们就构建了一棵以 为根的树。

2. 从公共 Border 到 LCA

在构建的失配树上:

  • 前缀 的所有 border 长度集合,等价于节点 的所有祖先节点(包括 自身)的集合。
  • 同理,前缀 的所有 border 长度集合,等价于节点 的所有祖先节点的集合。

因此,“最长公共 border 的长度” 就转化为了 “节点 的所有公共祖先中,深度最大的那个节点的编号”。这正是最近公共祖先 (Lowest Common Ancestor, LCA) 的定义。

3. 算法实现:倍增法求 LCA

问题最终变为在失配树上对节点 进行 LCA 查询。

  1. 计算 数组并同步建树: 在 时间内计算1-indexed的 数组。在计算 的同时,我们知道了节点 的父节点是 ,因此可以立即计算出节点 的深度 。这避免了额外的 DFS/BFS 遍历。
  2. LCA 预处理
    • 作为 up[i][0](节点 的直接父节点)。
    • 使用动态规划填充倍增数组 up[i][j](节点 的第 个祖先),递推式为 up[i][j] = up[up[i][j-1]][j-1] 时间。
  3. 回答查询:对于每组查询 ,使用倍增法求 lca(up[p][0], up[q][0]) 时间。

代码

#include <iostream>
#include <vector>
#include <string>
#include <cmath>
#include <algorithm>

using namespace std;

const int MAXN = 1000005;
const int LOGN = 21; // ceil(log2(1000005))

int up[MAXN][LOGN];
int depth[MAXN];

int lca(int u, int v) {
    if (depth[u] < depth[v]) swap(u, v);

    for (int j = LOGN - 1; j >= 0; j--) {
        if (depth[u] - (1 << j) >= depth[v]) {
            u = up[u][j];
        }
    }

    if (u == v) return u;

    for (int j = LOGN - 1; j >= 0; j--) {
        if (up[u][j] != up[v][j]) {
            u = up[u][j];
            v = up[v][j];
        }
    }
    return up[u][0];
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    string s_in;
    cin >> s_in;
    int n = s_in.length();
    string s = " " + s_in; // Convert to 1-indexed

    // up[i][0] is the parent of i, which is pi[i]
    // depth[0] is 0
    for (int i = 2, j = 0; i <= n; i++) {
        while (j > 0 && s[i] != s[j + 1]) {
            j = up[j][0];
        }
        if (s[i] == s[j + 1]) {
            j++;
        }
        up[i][0] = j;
        depth[i] = depth[j] + 1;
    }

    for (int j = 1; j < LOGN; j++) {
        for (int i = 1; i <= n; i++) {
            up[i][j] = up[up[i][j - 1]][j - 1];
        }
    }

    int q;
    cin >> q;
    while (q--) {
        int p, q_val;
        cin >> p >> q_val;
        cout << lca(up[p][0], up[q_val][0]) << '\n';
    }

    return 0;
}
import java.io.*;
import java.util.*;

public class Main {
    static final int LOGN = 21;
    static int[][] up;
    static int[] depth;

    static int lca(int u, int v) {
        if (depth[u] < depth[v]) {
            int temp = u; u = v; v = temp;
        }

        for (int j = LOGN - 1; j >= 0; j--) {
            if (depth[u] - (1 << j) >= depth[v]) {
                u = up[u][j];
            }
        }

        if (u == v) return u;

        for (int j = LOGN - 1; j >= 0; j--) {
            if (up[u][j] != up[v][j]) {
                u = up[u][j];
                v = up[v][j];
            }
        }
        return up[u][0];
    }

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        String s_in = br.readLine();
        int n = s_in.length();
        String s = " " + s_in;

        up = new int[n + 1][LOGN];
        depth = new int[n + 1];

        for (int i = 2, j = 0; i <= n; i++) {
            while (j > 0 && s.charAt(i) != s.charAt(j + 1)) {
                j = up[j][0];
            }
            if (s.charAt(i) == s.charAt(j + 1)) {
                j++;
            }
            up[i][0] = j;
            depth[i] = depth[j] + 1;
        }

        for (int j = 1; j < LOGN; j++) {
            for (int i = 1; i <= n; i++) {
                up[i][j] = up[up[i][j - 1]][j - 1];
            }
        }

        int q = Integer.parseInt(br.readLine());
        PrintWriter out = new PrintWriter(System.out);
        while (q-- > 0) {
            StringTokenizer st = new StringTokenizer(br.readLine());
            int p = Integer.parseInt(st.nextToken());
            int q_val = Integer.parseInt(st.nextToken());
            out.println(lca(up[p][0], up[q_val][0]));
        }
        out.flush();
    }
}
import sys

def solve():
    s_in = sys.stdin.readline().strip()
    n = len(s_in)
    s = " " + s_in
    
    LOGN = (n + 1).bit_length()
    up = [[0] * LOGN for _ in range(n + 1)]
    depth = [0] * (n + 1)

    j = 0
    for i in range(2, n + 1):
        while j > 0 and s[i] != s[j + 1]:
            j = up[j][0]
        if s[i] == s[j + 1]:
            j += 1
        up[i][0] = j
        depth[i] = depth[j] + 1

    for j in range(1, LOGN):
        for i in range(1, n + 1):
            up[i][j] = up[up[i][j - 1]][j - 1]

    def lca(u, v):
        if depth[u] < depth[v]:
            u, v = v, u
        
        for j in range(LOGN - 1, -1, -1):
            if depth[u] - (1 << j) >= depth[v]:
                u = up[u][j]
        
        if u == v:
            return u
            
        for j in range(LOGN - 1, -1, -1):
            if up[u][j] != up[v][j]:
                u = up[u][j]
                v = up[v][j]
                
        return up[u][0]

    q_str = sys.stdin.readline()
    if not q_str: return
    q = int(q_str)
    
    lines = sys.stdin.readlines()
    for line in lines:
        p, q_val = map(int, line.split())
        sys.stdout.write(str(lca(up[p][0], up[q_val][0])) + '\n')

solve()

算法及复杂度

  • 算法:KMP前缀函数 + 失配树 + 倍增LCA
  • 时间复杂度:
    • 计算 数组并同步计算深度为
    • LCA 预处理(倍增表)为
    • 次查询,每次查询为
  • 空间复杂度:
    • 主要开销在于存储大小为 的倍增 up 数组。