字典树—模板(静态数组实现)

数组变量

const int MAX=5e6+10;

int trie[MAX][26];//节点,trie[i][j] 表示结点 i 是否有去往字母 (j+'a') 的分支
int pass[MAX];//记录特定前缀的数量
int is_end[MAX];//记录是否是一个完整的字符串
int cnt=1;//当前节点编号,就是这颗字典树使用到了多少个结点

初始化

在多测时需要用到,将所有变量都清空

void Initial()//初始化
{
	for(int i=0;i<=cnt;i++)
	{
		for(int j=0;j<26;j++)trie[i][j]=0;
		pass[i]=is_end[i]=0;
	}
}

插入字符串

void insert(string& s)//创建字典树
{
	int cur=1;
	pass[cur]++;
	for(char ch:s)
	{
		int c=ch-'a';
		if(trie[cur][c]==0)
			trie[cur][c]=++cnt;
		cur=trie[cur][c];
		pass[cur]++;
	}
	is_end[cur]++;
}

查询前缀

bool findPre(string& s)//查询前缀 pre 是否存在
{
	int cur=1;
	for(char ch:s)
	{
		int c=ch-'a';
		if(trie[cur][c]==0)return false;
		cur=trie[cur][c];
	}
	return true;
}

查询字符串

bool find(string& s)//查询字符串 s 是否存在
{
	int cur=1;
	for(char ch:s)
	{
		int c=ch-'a';
		if(trie[cur][c]==0)return false;
		cur=trie[cur][c];
	}
	return is_end[cur]>0;
}

删除字符串

void Delete(string& s)
{
	if(find(s))//存在这个字符串才要删
	{
		int cur=1;
		for(char ch:s)
		{
			int c=ch-'a';
			if(--pass[trie[cur][c]] == 0)
			{
				trie[cur][c]=0;
				return ;
			}
			cur=trie[cur][c];
		}
		is_end[cur]--;
	}
}

字典树的实现


const int MAX=5e6+10;

class Solution {
  public:
    int trie[MAX][26];//节点
    int pass[MAX];//记录特定前缀的数量
    int is_end[MAX];//记录是否是一个完整的字符串
    int cnt = 1; //当前节点编号

    void insert(string& s) { //创建字典树
        int cur = 1;
        pass[cur]++;
        for (char ch : s) {
            int c = ch - 'a';
            if (trie[cur][c] == 0)
                trie[cur][c] = ++cnt;
            cur = trie[cur][c];
            pass[cur]++;
        }
        is_end[cur]++;
    }

    bool find(string& s) { //查询字符串 s 是否存在
        int cur = 1;
        for (char ch : s) {
            int c = ch - 'a';
            if (trie[cur][c] == 0)return false;
            cur = trie[cur][c];
        }
        return is_end[cur] > 0;
    }

    int cntPre(string& s) { //查询前缀 s count
        int cur = 1;
        for (char ch : s) {
            int c = ch - 'a';
            if (trie[cur][c] == 0)return 0;
            cur = trie[cur][c];
        }
        return pass[cur];
    }

    void Delete(string& s) {
        if (find(s)) {
            int cur = 1;
            for (char ch : s) {
                int c = ch - 'a';
                if (--pass[trie[cur][c]] == 0) {
                    trie[cur][c] = 0;
                    return ;
                }
                cur = trie[cur][c];
            }
            is_end[cur]--;
        }
    }
    vector<string> trieU(vector<vector<string> >& oper) {
        int n=oper.size();
        vector<string> ans;
        for(int i=0;i<n;i++)
        {
            string op=oper[i][0];
            if(op[0]=='1')
            {
                insert(oper[i][1]);
            }
            else if(op[0]=='2')
            {
                Delete(oper[i][1]);
            }
            else if(op[0]=='3')
            {
                if(find(oper[i][1]))ans.push_back("YES");
                else ans.push_back("NO");
            }
            else if(op[0]=='4') {
                ans.push_back(to_string(cntPre(oper[i][1])));
            }
        }
        return ans;
    }
};