PEEK134 【模板】Trie 字典树

题目链接

PEEK134 【模板】Trie 字典树

题目描述

给定 个模式串和 次查询。对于每次查询,给定一个文本串 ,你需要统计有多少个模式串是以 作为前缀的。

解题思路

这是一个经典的使用 Trie 树(字典树或前缀树) 解决的问题。Trie 树是一种专门用于高效存储和检索字符串集合的数据结构,特别擅长处理前缀相关的查询。

Trie 树的结构和性质

  • 它是一个多叉树,树的根节点不代表任何字符。
  • 从根节点到任意一个节点的路径构成一个字符串(通常是某个插入单词的前缀)。
  • 每个节点包含一个指向子节点的指针数组(或哈希表)。数组的大小是字符集的大小(例如,对于小写字母是 26,对于大小写字母是 52)。
  • 为了解决特定问题,节点上通常还会附带一些额外信息。

本题的解法

  1. 构建 Trie 树并预处理计数

    • 我们定义 Trie 树的每个节点。除了指向子节点的指针外,每个节点还需要一个计数器 count。这个 count 的含义是:有多少个插入的模式串经过了这个节点
    • 遍历 个模式串,将它们逐个插入 Trie 树中。
    • 插入过程如下:从根节点开始,对于模式串的每一个字符,找到对应的子节点路径。如果路径不存在,则创建一个新节点。然后,沿着路径移动到子节点,并将当前节点的 count 值加一。
    • 经过这个过程,每个节点 ucount 值就精确地记录了数据集中以“根到 u 的路径所代表的字符串”为前缀的模式串总数。
  2. 处理查询

    • 对于每个查询的文本串 ,我们从 Trie 树的根节点开始,沿着 的字符路径进行匹配。
    • 遍历 的每个字符,在当前节点查找是否存在通向该字符的子节点。
      • 如果存在,就移动到该子节点。
      • 如果不存在,说明没有任何模式串是以 作为前缀的。查询结果为 0,可以直接结束本次查询。
    • 如果成功遍历完 的所有字符,我们会到达一个节点,设为 u。这个节点 u 所代表的路径就是字符串
    • 根据我们之前的定义,u 节点的 count 值就是以 为前缀的模式串总数。我们直接返回这个值即可。

这种方法的优势在于,查询的时间复杂度只与查询串的长度有关,与模式串的数量和长度无关,因此非常高效。

字符集映射: 由于字符串包含大小写英文字母,字符集大小为 52。我们需要一个函数将字符 'a'-'z''A'-'Z' 映射到 0-51 的索引。

  • 'a' - 'z' -> 0 - 25
  • 'A' - 'Z' -> 26 - 51

代码

#include <iostream>
#include <vector>
#include <string>

using namespace std;

const int ALPHABET_SIZE = 52;

// 字符到索引的映射
int char_to_index(char c) {
    if (c >= 'a' && c <= 'z') {
        return c - 'a';
    } else {
        return c - 'A' + 26;
    }
}

struct TrieNode {
    TrieNode* children[ALPHABET_SIZE];
    int count;

    TrieNode() {
        count = 0;
        for (int i = 0; i < ALPHABET_SIZE; ++i) {
            children[i] = nullptr;
        }
    }
};

void insert(TrieNode* root, const string& key) {
    TrieNode* curr = root;
    for (char c : key) {
        int index = char_to_index(c);
        if (!curr->children[index]) {
            curr->children[index] = new TrieNode();
        }
        curr = curr->children[index];
        curr->count++;
    }
}

int search_prefix(TrieNode* root, const string& key) {
    TrieNode* curr = root;
    for (char c : key) {
        int index = char_to_index(c);
        if (!curr->children[index]) {
            return 0;
        }
        curr = curr->children[index];
    }
    return curr->count;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n, m;
    cin >> n >> m;

    TrieNode* root = new TrieNode();
    for (int i = 0; i < n; ++i) {
        string pattern;
        cin >> pattern;
        insert(root, pattern);
    }

    for (int i = 0; i < m; ++i) {
        string text;
        cin >> text;
        cout << search_prefix(root, text) << '\n';
    }

    // Note: In a real-world scenario, you'd need to deallocate the Trie nodes.
    // For competitive programming, this is often omitted.

    return 0;
}
import java.util.Scanner;

public class Main {
    static final int ALPHABET_SIZE = 52;

    static int charToIndex(char c) {
        if (c >= 'a' && c <= 'z') {
            return c - 'a';
        } else {
            return c - 'A' + 26;
        }
    }

    static class TrieNode {
        TrieNode[] children = new TrieNode[ALPHABET_SIZE];
        int count = 0;
    }

    static void insert(TrieNode root, String key) {
        TrieNode curr = root;
        for (char c : key.toCharArray()) {
            int index = charToIndex(c);
            if (curr.children[index] == null) {
                curr.children[index] = new TrieNode();
            }
            curr = curr.children[index];
            curr.count++;
        }
    }

    static int searchPrefix(TrieNode root, String key) {
        TrieNode curr = root;
        for (char c : key.toCharArray()) {
            int index = charToIndex(c);
            if (curr.children[index] == null) {
                return 0;
            }
            curr = curr.children[index];
        }
        return curr.count;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int m = sc.nextInt();

        TrieNode root = new TrieNode();
        for (int i = 0; i < n; i++) {
            String pattern = sc.next();
            insert(root, pattern);
        }

        for (int i = 0; i < m; i++) {
            String text = sc.next();
            System.out.println(searchPrefix(root, text));
        }
    }
}
import sys

ALPHABET_SIZE = 52

def char_to_index(c):
    if 'a' <= c <= 'z':
        return ord(c) - ord('a')
    else:
        return ord(c) - ord('A') + 26

def make_trie_node():
    return {'children': [None] * ALPHABET_SIZE, 'count': 0}

def insert(root, key):
    curr = root
    for char in key:
        index = char_to_index(char)
        if curr['children'][index] is None:
            curr['children'][index] = make_trie_node()
        curr = curr['children'][index]
        curr['count'] += 1

def search_prefix(root, key):
    curr = root
    for char in key:
        index = char_to_index(char)
        if curr['children'][index] is None:
            return 0
        curr = curr['children'][index]
    return curr['count']

def main():
    try:
        n, m = map(int, sys.stdin.readline().split())
        
        root = make_trie_node()
        for _ in range(n):
            pattern = sys.stdin.readline().strip()
            insert(root, pattern)
            
        for _ in range(m):
            text = sys.stdin.readline().strip()
            print(search_prefix(root, text))
    except (IOError, ValueError):
        return

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:Trie 树(字典树)。
  • 时间复杂度
    • 构建:设所有模式串的总长度为 。构建 Trie 树的时间复杂度为
    • 查询:设查询串的长度为 。每次查询的时间复杂度为 次查询的总时间复杂度为
  • 空间复杂度:Trie 树的空间复杂度与所有模式串的总长度成正比,即