题目链接

萌芽

题目描述

管理一棵字符串树(本质为一棵前缀树/Trie)。树中每条自根到叶的路径对应一条已存储的字符串。

现在需要连续提交 个字符串 。对于当前提交的字符串

  • 若树中不存在 为前缀的任何字符串,则将 插入到字符串树中。
  • 否则,输出树中以 为前缀的字符串数量(不插入)。

请你依次处理所有询问并输出对应结果。

解题思路

这道题目的核心是处理字符串前缀的查询和插入操作,这是 Trie (前缀树) 数据结构的典型应用场景。

Trie 是一种树形结构,它的每个节点代表一个字符串前缀。从根节点到任意一个节点的路径,构成了一个字符串。每个节点通常包含一个指向子节点的指针数组(或哈希表)和一个标记,用来记录是否有字符串在此节点结束,以及以此为前缀的字符串数量等信息。

算法步骤

  1. 构建 Trie 节点: 我们需要定义一个 Trie 节点结构。每个节点需要包含以下信息:

    • children: 一个大小为 26 的数组(或哈希表),用于存放指向其子节点的指针,分别对应 'a' 到 'z'。
    • count: 一个整数,用于记录以从根节点到当前节点所表示的前缀为前缀的字符串总数。
  2. 插入操作 (Insert):

    • 要将一个字符串 s 插入 Trie,我们从根节点开始,逐个遍历 s 的字符。
    • 对于每个字符 c,我们检查当前节点是否已经有指向对应 c 的子节点。
    • 如果没有,就创建一个新的 Trie 节点并连接。
    • 然后,移动到该子节点,并将其 count 值加一。
    • 重复此过程,直到 s 的所有字符都处理完毕。
  3. 查询操作 (Query):

    • 要查询以字符串 q 为前缀的字符串数量,我们同样从根节点开始,逐个遍历 q 的字符。
    • 对于每个字符 c,我们沿着指向 c 的子节点路径向下移动。
    • 如果在任何一步,对应的子节点不存在,说明 Trie 中没有任何字符串以 q 为前缀。
    • 如果成功遍历完 q 的所有字符并到达了某个节点,那么该节点的 count 值就是以 q 为前缀的字符串总数。
  4. 主逻辑:

    • 首先,初始化一个空的 Trie (只有一个根节点)。
    • 读取 个初始字符串,并将它们一一插入到 Trie 中。
    • 然后,循环 次,处理每个查询字符串 q
      • 调用查询函数,查找以 q 为前缀的字符串数量。
      • 如果返回的数量为 0,说明不存在这样的前缀,则调用插入函数将 q 插入 Trie。
      • 如果返回的数量大于 0,则直接输出这个数量。

这种方法可以高效地处理所有操作。单次插入或查询的复杂度与字符串的长度成正比,与 Trie 中已有的字符串数量无关。

代码

#include <iostream>
#include <string>
#include <vector>
#include <memory>

using namespace std;

const int ALPHABET_SIZE = 26;
const int MAX_NODES = 2000005; // (10^4 + 10^4) * 100 + 5

struct TrieNode {
    int children[ALPHABET_SIZE];
    int count;
};

TrieNode trie_pool[MAX_NODES];
int node_idx = 1; // 0 is root, nodes are allocated from 1

void insert(const string& key) {
    int current_node = 0;
    for (char ch : key) {
        int index = ch - 'a';
        if (trie_pool[current_node].children[index] == 0) {
            trie_pool[current_node].children[index] = node_idx++;
        }
        current_node = trie_pool[current_node].children[index];
        trie_pool[current_node].count++;
    }
}

int query(const string& key) {
    int current_node = 0;
    for (char ch : key) {
        int index = ch - 'a';
        if (trie_pool[current_node].children[index] == 0) {
            return 0;
        }
        current_node = trie_pool[current_node].children[index];
    }
    return trie_pool[current_node].count;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n, m;
    cin >> n >> m;

    // Root is at index 0, implicitly initialized to 0s

    for (int i = 0; i < n; ++i) {
        string s;
        cin >> s;
        insert(s);
    }

    for (int i = 0; i < m; ++i) {
        string q;
        cin >> q;
        int result = query(q);
        if (result == 0) {
            insert(q);
        } else {
            cout << result << "\n";
        }
    }

    return 0;
}
import java.util.Scanner;

public class Main {
    static final int ALPHABET_SIZE = 26;
    static final int MAX_NODES = 2000005; // (10^4 + 10^4) * 100 + 5
    
    static class TrieNode {
        int[] children = new int[ALPHABET_SIZE];
        int count = 0;
    }

    static TrieNode[] trie = new TrieNode[MAX_NODES];
    static int nodeIdx = 1; // 0 is root, nodes are allocated from 1

    static {
        for (int i = 0; i < MAX_NODES; i++) {
            trie[i] = new TrieNode();
        }
    }

    public static void insert(String key) {
        int currentNode = 0; // Start from root
        for (char ch : key.toCharArray()) {
            int index = ch - 'a';
            if (trie[currentNode].children[index] == 0) {
                trie[currentNode].children[index] = nodeIdx++;
            }
            currentNode = trie[currentNode].children[index];
            trie[currentNode].count++;
        }
    }

    public static int query(String key) {
        int currentNode = 0; // Start from root
        for (char ch : key.toCharArray()) {
            int index = ch - 'a';
            if (trie[currentNode].children[index] == 0) {
                return 0;
            }
            currentNode = trie[currentNode].children[index];
        }
        return trie[currentNode].count;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int m = sc.nextInt();

        // Root is at index 0, already initialized
        
        for (int i = 0; i < n; i++) {
            String s = sc.next();
            insert(s);
        }

        for (int i = 0; i < m; i++) {
            String q = sc.next();
            int result = query(q);
            if (result == 0) {
                insert(q);
            } else {
                System.out.println(result);
            }
        }
    }
}
import sys

class TrieNode:
    def __init__(self):
        self.children = {}
        self.count = 0

def insert(root, key):
    p_crawl = root
    for char in key:
        if char not in p_crawl.children:
            p_crawl.children[char] = TrieNode()
        p_crawl = p_crawl.children[char]
        p_crawl.count += 1

def query(root, key):
    p_crawl = root
    for char in key:
        if char not in p_crawl.children:
            return 0
        p_crawl = p_crawl.children[char]
    return p_crawl.count

def main():
    try:
        input = sys.stdin.readline
        n, m = map(int, input().split())
        
        root = TrieNode()

        for _ in range(n):
            s = input().strip()
            insert(root, s)

        for _ in range(m):
            q = input().strip()
            result = query(root, q)
            if result == 0:
                insert(root, q)
            else:
                sys.stdout.write(str(result) + '\n')

    except (IOError, ValueError):
        return

main()

算法及复杂度

  • 算法:Trie (前缀树)
  • 时间复杂度:,其中 是所有输入字符串(包括初始字符串和查询字符串)的总长度。每次插入或查询操作的复杂度都与当前字符串的长度成正比。
  • 空间复杂度:,在最坏情况下(所有字符串没有公共前缀),Trie 的节点数约等于所有字符串的总长度。