题目链接

小红的字符串匹配

题目描述

给定一个主字符串 和一个长度限制 。对于一个查询字符串 ,如果 满足以下两个条件之一,则称小红喜欢

  1. 的某一个长度不小于 的前缀是 的一个子串。
  2. 的某一个长度不小于 的后缀是 的一个子串。

现在有 次询问,每次给定一个字符串 ,需要判断小红是否喜欢它。

解题思路

本题的核心是需要对每个查询字符串 ,快速判断其一系列前缀和后缀是否存在于主字符串 中。由于查询总长度很大,对每个前缀/后缀都进行独立的子串搜索(如 KMP)是不可行的。我们需要一个能够一次性处理所有子串查询的数据结构。

数据结构选择:后缀自动机 (SAM)

后缀自动机 (Suffix Automaton, SAM) 是处理字符串子串问题的强大工具。为一个字符串 构建 SAM 后,我们可以实现以下功能:

  • 的时间内判断任意字符串 是否为 的子串。
  • SAM 的状态数和转移数都是 级别的,空间开销合理。

简化查询条件

原始的查询条件是“是否存在一个长度不小于 的前缀/后缀...”。我们可以将其简化:

  1. 前缀条件:如果存在一个长度为 的前缀是 的子串,那么我们只需要找到**《 的最长前缀》中,同时也是 的子串的那个**,然后检查其长度是否大于等于 即可。如果这个最长的前缀满足条件,那么就找到了一个;如果连它都不满足,更短的前缀(只要长度还大于等于k)也不会改变结果,而更长的前缀已经不是S的子串了。
  2. 后缀条件:同理,我们只需要找到**《 的最长后缀》中,同时也是 的子串的那个**,然后检查其长度是否大于等于

如何高效检查后缀?

使用 SAM 可以轻松地在 时间内找到 的最长前缀,该前缀也是 的子串。但如何处理后缀呢? 一个巧妙的技巧是利用字符串反转的性质:一个字符串 的后缀,就是其反转串 的前缀。 因此,要找 的最长后缀(同时是 的子串),等价于找 的最长前缀(同时是 的子串)。

最终算法

  1. 预处理阶段:

    • 为原始主字符串 构建一个后缀自动机,记为 sam_s
    • 反转得到 ,并为 构建第二个后缀自动机,记为 sam_s_rev
    • 整个预处理的时间和空间复杂度均为
  2. 查询阶段:

    • 对于每一个查询字符串
    • 检查前缀:使用 sam_s,通过简单的节点遍历,找到 的最长前缀(同时是 的子串)的长度,记为 max_prefix_len。这个过程耗时
    • 如果 max_prefix_len >= k,则说明小红喜欢 。输出 "YES" 并处理下一个查询。
    • 检查后缀:如果前缀不满足条件,则将 反转得到
    • 使用 sam_s_rev,找到 的最长前缀(同时是 的子串)的长度,记为 max_suffix_len。这等价于 的最长后缀(同时是 的子串)的长度。这个过程也耗时
    • 如果 max_suffix_len >= k,则说明小红喜欢 。输出 "YES"。
    • 如果前后缀都不满足条件,输出 "NO"。

这个算法的总时间复杂度为 ,能够轻松通过本题的数据范围。

代码

#include <iostream>
#include <string>
#include <vector>
#include <map>
#include <algorithm>

using namespace std;

const int MAX_LEN = 200005; // |S| <= 10^5, so 2*|S| is enough for SAM states

struct SAM_Node {
    int len, link;
    map<char, int> next;
};

SAM_Node sam[MAX_LEN * 2];
int sz, last;

void sam_init() {
    sz = 1;
    last = 0;
    sam[0].len = 0;
    sam[0].link = -1;
    sam[0].next.clear();
}

void sam_extend(char c) {
    int cur = sz++;
    sam[cur].len = sam[last].len + 1;
    sam[cur].next.clear();
    int p = last;
    while (p != -1 && sam[p].next.find(c) == sam[p].next.end()) {
        sam[p].next[c] = cur;
        p = sam[p].link;
    }
    if (p == -1) {
        sam[cur].link = 0;
    } else {
        int q = sam[p].next[c];
        if (sam[q].len == sam[p].len + 1) {
            sam[cur].link = q;
        } else {
            int clone = sz++;
            sam[clone].len = sam[p].len + 1;
            sam[clone].next = sam[q].next;
            sam[clone].link = sam[q].link;
            while (p != -1 && sam[p].next[c] == q) {
                sam[p].next[c] = clone;
                p = sam[p].link;
            }
            sam[q].link = clone;
            sam[cur].link = clone;
        }
    }
    last = cur;
}

int get_longest_match(const string& t) {
    int v = 0, l = 0;
    for (char c : t) {
        if (sam[v].next.count(c)) {
            v = sam[v].next[c];
            l++;
        } else {
            break;
        }
    }
    return l;
}


int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    string s;
    cin >> s;
    int q, k;
    cin >> q >> k;

    // Build SAM for S
    sam_init();
    for (char c : s) {
        sam_extend(c);
    }
    vector<SAM_Node> sam_s(sam, sam + sz);
    
    // Build SAM for reversed S
    reverse(s.begin(), s.end());
    sam_init();
    for (char c : s) {
        sam_extend(c);
    }
    vector<SAM_Node> sam_s_rev(sam, sam + sz);


    while (q--) {
        string t;
        cin >> t;

        // Check prefix
        int v = 0, l = 0;
        for (char c : t) {
            if (sam_s[v].next.count(c)) {
                v = sam_s[v].next.at(c);
                l++;
            } else {
                break;
            }
        }
        
        if (l >= k) {
            cout << "YES\n";
            continue;
        }

        // Check suffix
        reverse(t.begin(), t.end());
        v = 0, l = 0;
        for (char c : t) {
            if (sam_s_rev[v].next.count(c)) {
                v = sam_s_rev[v].next.at(c);
                l++;
            } else {
                break;
            }
        }

        if (l >= k) {
            cout << "YES\n";
        } else {
            cout << "NO\n";
        }
    }

    return 0;
}
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.HashMap;
import java.util.Map;
import java.util.StringTokenizer;

public class Main {

    static class SAM {
        static class Node {
            int len, link;
            Map<Character, Integer> next = new HashMap<>();
        }

        Node[] nodes;
        int sz, last;

        public SAM(int maxLen) {
            // Allocate enough space for nodes
            nodes = new Node[maxLen * 2 + 2];
            for (int i = 0; i < maxLen * 2 + 2; i++) {
                nodes[i] = new Node();
            }
            nodes[0].len = 0;
            nodes[0].link = -1;
            sz = 1;
            last = 0;
        }

        public void extend(char c) {
            int cur = sz++;
            nodes[cur].len = nodes[last].len + 1;
            int p = last;
            while (p != -1 && !nodes[p].next.containsKey(c)) {
                nodes[p].next.put(c, cur);
                p = nodes[p].link;
            }
            if (p == -1) {
                nodes[cur].link = 0;
            } else {
                int q = nodes[p].next.get(c);
                if (nodes[q].len == nodes[p].len + 1) {
                    nodes[cur].link = q;
                } else {
                    int clone = sz++;
                    nodes[clone].len = nodes[p].len + 1;
                    nodes[clone].next.putAll(nodes[q].next);
                    nodes[clone].link = nodes[q].link;
                    while (p != -1 && nodes[p].next.get(c) == q) {
                        nodes[p].next.put(c, clone);
                        p = nodes[p].link;
                    }
                    nodes[q].link = clone;
                    nodes[cur].link = clone;
                }
            }
            last = cur;
        }

        public int getLongestMatch(String t) {
            int v = 0, l = 0;
            for (char c : t.toCharArray()) {
                if (nodes[v].next.containsKey(c)) {
                    v = nodes[v].next.get(c);
                    l++;
                } else {
                    break;
                }
            }
            return l;
        }
    }

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        String s = br.readLine();
        StringTokenizer st = new StringTokenizer(br.readLine());
        int q = Integer.parseInt(st.nextToken());
        int k = Integer.parseInt(st.nextToken());

        SAM sam_s = new SAM(s.length());
        for (char c : s.toCharArray()) {
            sam_s.extend(c);
        }

        String s_rev = new StringBuilder(s).reverse().toString();
        SAM sam_s_rev = new SAM(s.length());
        for (char c : s_rev.toCharArray()) {
            sam_s_rev.extend(c);
        }

        StringBuilder output = new StringBuilder();
        for (int i = 0; i < q; i++) {
            String t = br.readLine();
            if (t == null || t.isEmpty()) continue;
            
            if (sam_s.getLongestMatch(t) >= k) {
                output.append("YES\n");
                continue;
            }

            String t_rev = new StringBuilder(t).reverse().toString();
            if (sam_s_rev.getLongestMatch(t_rev) >= k) {
                output.append("YES\n");
            } else {
                output.append("NO\n");
            }
        }
        System.out.print(output.toString());
    }
}
import sys

class SuffixAutomaton:
    def __init__(self, max_len):
        self.nodes = [{'len': 0, 'link': -1, 'next': {}} for _ in range(max_len * 2 + 2)]
        self.sz = 1
        self.last = 0

    def extend(self, char):
        cur = self.sz
        self.sz += 1
        self.nodes[cur]['len'] = self.nodes[self.last]['len'] + 1
        
        p = self.last
        while p != -1 and char not in self.nodes[p]['next']:
            self.nodes[p]['next'][char] = cur
            p = self.nodes[p]['link']
        
        if p == -1:
            self.nodes[cur]['link'] = 0
        else:
            q = self.nodes[p]['next'][char]
            if self.nodes[q]['len'] == self.nodes[p]['len'] + 1:
                self.nodes[cur]['link'] = q
            else:
                clone = self.sz
                self.sz += 1
                self.nodes[clone]['len'] = self.nodes[p]['len'] + 1
                self.nodes[clone]['link'] = self.nodes[q]['link']
                self.nodes[clone]['next'] = self.nodes[q]['next'].copy()
                
                while p != -1 and self.nodes[p]['next'].get(char) == q:
                    self.nodes[p]['next'][char] = clone
                    p = self.nodes[p]['link']
                self.nodes[q]['link'] = clone
                self.nodes[cur]['link'] = clone
        self.last = cur

    def get_longest_match(self, t):
        v, l = 0, 0
        for char in t:
            if char in self.nodes[v]['next']:
                v = self.nodes[v]['next'][char]
                l += 1
            else:
                break
        return l

def solve():
    s = sys.stdin.readline().strip()
    q, k = map(int, sys.stdin.readline().split())

    sam_s = SuffixAutomaton(len(s))
    for char in s:
        sam_s.extend(char)

    s_rev = s[::-1]
    sam_s_rev = SuffixAutomaton(len(s_rev))
    for char in s_rev:
        sam_s_rev.extend(char)
    
    # Use readlines() to handle all queries at once
    lines = sys.stdin.readlines()
    for t in lines:
        t = t.strip()
        if not t: continue
        
        if sam_s.get_longest_match(t) >= k:
            print("YES")
            continue
        
        t_rev = t[::-1]
        if sam_s_rev.get_longest_match(t_rev) >= k:
            print("YES")
        else:
            print("NO")

solve()

算法及复杂度

  • 算法:后缀自动机 (SAM)
  • 时间复杂度
    • 构建两个 SAM 分别需要 的时间。
    • 对于每个查询字符串 ,我们需要遍历它两次(一次正序,一次逆序),每次遍历的复杂度是 。所有查询的总复杂度是
  • 空间复杂度
    • 两个 SAM 的状态和转移总数都是线性的,与 成正比。查询字符串的空间可以复用。