#include<bits/stdc++.h> using namespace std; typedef long long ll; const int maxn = 2*1e6+9; int trie[maxn][26]; //字典树 int cntword[maxn]; //记录该单词出现次数 int fail[maxn]; //失败时的回溯指针 int cnt = 0; void insertWords(string s){ int root = 0; for(int i=0;i<s.size();i++){ int next = s[i] - 'a'; if(!trie[root][next]) trie[root][next] = ++cnt; root = trie[root][next]; } cntword[root]++; //当前节点单词数+1 } void getFail(){ queue <int>q; for(int i=0;i<26;i++){ //将第二层所有出现了的字母扔进队列 if(trie[0][i]){ fail[trie[0][i]] = 0; q.push(trie[0][i]); } } //fail[now] ->当前节点now的失败指针指向的地方 ////tire[now][i] -> 下一个字母为i+'a'的节点的下标为tire[now][i] while(!q.empty()){ int now = q.front(); q.pop(); for(int i=0;i<26;i++){ //查询26个字母 if(trie[now][i]){ //如果有这个子节点为字母i+'a',则 //让这个节点的失败指针指向(((他父亲节点)的失败指针所指向的那个节点)的下一个节点) //有点绕,为了方便理解特意加了括号 fail[trie[now][i]] = trie[fail[now]][i]; q.push(trie[now][i]); } else//否则就让当前节点的这个子节点 //指向当前节点fail指针的这个子节点 trie[now][i] = trie[fail[now]][i]; } } } int query(string s){ int now = 0,ans = 0; for(int i=0;i<s.size();i++){ //遍历文本串 now = trie[now][s[i]-'a']; //从s[i]点开始寻找 for(int j=now;j && cntword[j]!=-1;j=fail[j]){ //一直向下寻找,直到匹配失败(失败指针指向根或者当前节点已找过). ans += cntword[j]; cntword[j] = -1; //将遍历国后的节点标记,防止重复计算 } } return ans; } int main() { int n; string s; cin >> n; for(int i=0;i<n;i++){ cin >> s ; insertWords(s); } fail[0] = 0; getFail(); cin >> s ; cout << query(s) << endl; return 0; }
用结构体
#include<bits/stdc++.h> #define maxn 1000001 using namespace std; struct kkk{ int son[26],flag,fail; }trie[maxn]; int n,cnt; char s[1000001]; queue<int >q; void insert(char* s){ int u=1,len=strlen(s); for(int i=0;i<len;i++){ int v=s[i]-'a'; if(!trie[u].son[v])trie[u].son[v]=++cnt; u=trie[u].son[v]; } trie[u].flag++; } void getFail(){ for(int i=0;i<26;i++)trie[0].son[i]=1; //初始化0的所有儿子都是1 q.push(1);trie[1].fail=0; //将根压入队列 while(!q.empty()){ int u=q.front();q.pop(); for(int i=0;i<26;i++){ //遍历所有儿子 int v=trie[u].son[i]; //处理u的i儿子的fail,这样就可以不用记父亲了 int Fail=trie[u].fail; //就是fafail,trie[Fail].son[i]就是和v值相同的点 if(!v){trie[u].son[i]=trie[Fail].son[i];continue;} //不存在该节点,第二种情况 trie[v].fail=trie[Fail].son[i]; //第三种情况,直接指就可以了 q.push(v); //存在实节点才压入队列 } } } int query(char* s){ int u=1,ans=0,len=strlen(s); for(int i=0;i<len;i++){ int v=s[i]-'a'; int k=trie[u].son[v]; //跳Fail while(k>1&&trie[k].flag!=-1){ //经过就不统计了 ans+=trie[k].flag,trie[k].flag=-1; //累加上这个位置的模式串个数,标记已经过 k=trie[k].fail; //继续跳Fail } u=trie[u].son[v]; //到下一个儿子 } return ans; } int main(){ cnt=1; //代码实现细节,编号从1开始 scanf("%d",&n); for(int i=1;i<=n;i++){ scanf("%s",s); insert(s); } getFail(); scanf("%s",s); printf("%d\n",query(s)); return 0; }
优化ac自动机
p5357
题目链接
#include<bits/stdc++.h> #define maxn 2000001 using namespace std; char s[maxn],T[maxn]; int n,cnt,vis[200051],ans,in[maxn],Map[maxn]; struct kkk{ int son[26],fail,flag,ans; }trie[maxn]; queue<int>q; void insert(char* s,int num){ int u=1,len=strlen(s); for(int i=0;i<len;++i){ int v=s[i]-'a'; if(!trie[u].son[v])trie[u].son[v]=++cnt; u=trie[u].son[v]; } if(!trie[u].flag)trie[u].flag=num; Map[num]=trie[u].flag; } void getFail(){ for(int i=0;i<26;i++)trie[0].son[i]=1; q.push(1); while(!q.empty()){ int u=q.front();q.pop(); int Fail=trie[u].fail; for(int i=0;i<26;++i){ int v=trie[u].son[i]; if(!v){trie[u].son[i]=trie[Fail].son[i];continue;} trie[v].fail=trie[Fail].son[i]; in[trie[v].fail]++; q.push(v); } } } void topu(){ for(int i=1;i<=cnt;++i) if(in[i]==0)q.push(i); //将入度为0的点全部压入队列里 while(!q.empty()){ int u=q.front();q.pop();vis[trie[u].flag]=trie[u].ans; //如果有flag标记就更新vis数组 int v=trie[u].fail;in[v]--; //将唯一连出去的出边fail的入度减去(拓扑排序的操作) trie[v].ans+=trie[u].ans; //更新fail的ans值 if(in[v]==0)q.push(v); //拓扑排序常规操作 } } void query(char* s){ int u=1,len=strlen(s); for(int i=0;i<len;++i) u=trie[u].son[s[i]-'a'],trie[u].ans++; } int main(){ scanf("%d",&n); cnt=1; for(int i=1;i<=n;++i){ scanf("%s",s); insert(s,i); }getFail();scanf("%s",T); query(T);topu(); for(int i=1;i<=n;++i)printf("%d\n",vis[Map[i]]); }