PEEK134 【模板】Trie 字典树
题目链接
题目描述
给定 个模式串和
次查询。对于每次查询,给定一个文本串
,你需要统计有多少个模式串是以
作为前缀的。
解题思路
这是一个经典的使用 Trie 树(字典树或前缀树) 解决的问题。Trie 树是一种专门用于高效存储和检索字符串集合的数据结构,特别擅长处理前缀相关的查询。
Trie 树的结构和性质:
- 它是一个多叉树,树的根节点不代表任何字符。
- 从根节点到任意一个节点的路径构成一个字符串(通常是某个插入单词的前缀)。
- 每个节点包含一个指向子节点的指针数组(或哈希表)。数组的大小是字符集的大小(例如,对于小写字母是 26,对于大小写字母是 52)。
- 为了解决特定问题,节点上通常还会附带一些额外信息。
本题的解法:
-
构建 Trie 树并预处理计数
- 我们定义 Trie 树的每个节点。除了指向子节点的指针外,每个节点还需要一个计数器
count
。这个count
的含义是:有多少个插入的模式串经过了这个节点。 - 遍历
个模式串,将它们逐个插入 Trie 树中。
- 插入过程如下:从根节点开始,对于模式串的每一个字符,找到对应的子节点路径。如果路径不存在,则创建一个新节点。然后,沿着路径移动到子节点,并将当前节点的
count
值加一。 - 经过这个过程,每个节点
u
的count
值就精确地记录了数据集中以“根到u
的路径所代表的字符串”为前缀的模式串总数。
- 我们定义 Trie 树的每个节点。除了指向子节点的指针外,每个节点还需要一个计数器
-
处理查询
- 对于每个查询的文本串
,我们从 Trie 树的根节点开始,沿着
的字符路径进行匹配。
- 遍历
的每个字符,在当前节点查找是否存在通向该字符的子节点。
- 如果存在,就移动到该子节点。
- 如果不存在,说明没有任何模式串是以
作为前缀的。查询结果为 0,可以直接结束本次查询。
- 如果成功遍历完
的所有字符,我们会到达一个节点,设为
u
。这个节点u
所代表的路径就是字符串。
- 根据我们之前的定义,
u
节点的count
值就是以为前缀的模式串总数。我们直接返回这个值即可。
- 对于每个查询的文本串
这种方法的优势在于,查询的时间复杂度只与查询串的长度有关,与模式串的数量和长度无关,因此非常高效。
字符集映射:
由于字符串包含大小写英文字母,字符集大小为 52。我们需要一个函数将字符 'a'-'z'
和 'A'-'Z'
映射到 0-51
的索引。
'a' - 'z'
->0 - 25
'A' - 'Z'
->26 - 51
代码
#include <iostream>
#include <vector>
#include <string>
using namespace std;
const int ALPHABET_SIZE = 52;
// 字符到索引的映射
int char_to_index(char c) {
if (c >= 'a' && c <= 'z') {
return c - 'a';
} else {
return c - 'A' + 26;
}
}
struct TrieNode {
TrieNode* children[ALPHABET_SIZE];
int count;
TrieNode() {
count = 0;
for (int i = 0; i < ALPHABET_SIZE; ++i) {
children[i] = nullptr;
}
}
};
void insert(TrieNode* root, const string& key) {
TrieNode* curr = root;
for (char c : key) {
int index = char_to_index(c);
if (!curr->children[index]) {
curr->children[index] = new TrieNode();
}
curr = curr->children[index];
curr->count++;
}
}
int search_prefix(TrieNode* root, const string& key) {
TrieNode* curr = root;
for (char c : key) {
int index = char_to_index(c);
if (!curr->children[index]) {
return 0;
}
curr = curr->children[index];
}
return curr->count;
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n, m;
cin >> n >> m;
TrieNode* root = new TrieNode();
for (int i = 0; i < n; ++i) {
string pattern;
cin >> pattern;
insert(root, pattern);
}
for (int i = 0; i < m; ++i) {
string text;
cin >> text;
cout << search_prefix(root, text) << '\n';
}
// Note: In a real-world scenario, you'd need to deallocate the Trie nodes.
// For competitive programming, this is often omitted.
return 0;
}
import java.util.Scanner;
public class Main {
static final int ALPHABET_SIZE = 52;
static int charToIndex(char c) {
if (c >= 'a' && c <= 'z') {
return c - 'a';
} else {
return c - 'A' + 26;
}
}
static class TrieNode {
TrieNode[] children = new TrieNode[ALPHABET_SIZE];
int count = 0;
}
static void insert(TrieNode root, String key) {
TrieNode curr = root;
for (char c : key.toCharArray()) {
int index = charToIndex(c);
if (curr.children[index] == null) {
curr.children[index] = new TrieNode();
}
curr = curr.children[index];
curr.count++;
}
}
static int searchPrefix(TrieNode root, String key) {
TrieNode curr = root;
for (char c : key.toCharArray()) {
int index = charToIndex(c);
if (curr.children[index] == null) {
return 0;
}
curr = curr.children[index];
}
return curr.count;
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int m = sc.nextInt();
TrieNode root = new TrieNode();
for (int i = 0; i < n; i++) {
String pattern = sc.next();
insert(root, pattern);
}
for (int i = 0; i < m; i++) {
String text = sc.next();
System.out.println(searchPrefix(root, text));
}
}
}
import sys
ALPHABET_SIZE = 52
def char_to_index(c):
if 'a' <= c <= 'z':
return ord(c) - ord('a')
else:
return ord(c) - ord('A') + 26
def make_trie_node():
return {'children': [None] * ALPHABET_SIZE, 'count': 0}
def insert(root, key):
curr = root
for char in key:
index = char_to_index(char)
if curr['children'][index] is None:
curr['children'][index] = make_trie_node()
curr = curr['children'][index]
curr['count'] += 1
def search_prefix(root, key):
curr = root
for char in key:
index = char_to_index(char)
if curr['children'][index] is None:
return 0
curr = curr['children'][index]
return curr['count']
def main():
try:
n, m = map(int, sys.stdin.readline().split())
root = make_trie_node()
for _ in range(n):
pattern = sys.stdin.readline().strip()
insert(root, pattern)
for _ in range(m):
text = sys.stdin.readline().strip()
print(search_prefix(root, text))
except (IOError, ValueError):
return
if __name__ == "__main__":
main()
算法及复杂度
- 算法:Trie 树(字典树)。
- 时间复杂度:
- 构建:设所有模式串的总长度为
。构建 Trie 树的时间复杂度为
。
- 查询:设查询串的长度为
。每次查询的时间复杂度为
。
次查询的总时间复杂度为
。
- 构建:设所有模式串的总长度为
- 空间复杂度:Trie 树的空间复杂度与所有模式串的总长度成正比,即
。