AC自动机(加强版)
有 N 个由小写字母组成的模式串以及一个文本串 T 。每个模式串可能会在文本串中出现多次。你需要找出哪些模式串在文本串 T 中出现的次数最多。
每组数据的第一行为一个正整数 N ,表示共有 N 个模式串, 1 ≤ N ≤ 150 。
接下去 N 行,每行一个长度小于等于 70 的模式串。下一行是一个长度小于等于 10^6 的文本串 T 。
输入结束标志为 N=0
对于每组数据,第一行输出模式串最多出现的次数,接下去若干行每行输出一个出现次数最多的模式串,按输入顺序排列。
正解部分
以此题为例说一下 AC 自动机的相关内容 .
1.建立自动机
- 首先将所有的 模式串 加入 Trie树 .
- 按 BFS序 寻找 fail指针,
设当前 BFS 到了 Trie树 中的 i 节点, 代表了字符串 S, 结尾字符为 c, 其父节点为 fa[i],
若 c 不存在, c 的 fail指针 指向的是 Trie树 中除了 S 以外存在的字符串中与 S 有最长公共后缀的字符串对应的 Trie树节点 j .
于是 c 的 fail[i]就可以表示为: fail[i]=ch[fail[fa[i]]][c] . - 若 fail[i] 连向 终止节点, 则该点为 “伪终止节点”, 表示到了 i节点, 等同到了一个终止节点, 需要计数 .
- 为了使下方匹配更加迅速, 记录 i 节点不断跳 fail 遇到的第一个 终止节点 为 last[i],
更新方法: last[i]=end[fail[i]]?fail[i]:last[fail[i]] .
2.匹配字符串
从 文本串 起始位置, Trie树根节点开始匹配, 到 i点的时候,
统计 i点能连到的所有 终止节点 即可, 如下 .
int tmp = i;
while(i){
统计i点答案;
i = last[i];
}
实现部分
#include<bits/stdc++.h>
#define reg register
const int maxn = 1e6 + 10;
int N;
char T[maxn];
char s[155][75];
struct Asw{ int cnt, id; } Ans[maxn];
bool cmp(Asw a, Asw b){ return a.cnt==b.cnt?a.id<b.id:a.cnt > b.cnt; }
struct Ac_auto{
int node_cnt;
int end[20004];
struct Trie{ int vis[30], nxt, is_end, last;} node[20004];
void Init(){
node_cnt = 0;
memset(end, 0, sizeof end);
for(reg int i = 0; i < 20004; i ++)
memset(node[i].vis, 0, sizeof node[i].vis),
node[i].nxt = node[i].is_end = node[i].last = 0;
}
void Add(char *s, int id){
int cur = 0, size = strlen(s);
for(reg int i = 0; i < size; i ++){
int t = s[i] - 'a';
if(!node[cur].vis[t]) node[cur].vis[t] = ++ node_cnt;
cur = node[cur].vis[t];
}
node[cur].is_end ++, end[cur] = id;
}
void BFS(){
std::queue <int> Q;
for(reg int i = 0; i < 26; i ++)
if(node[0].vis[i]) Q.push(node[0].vis[i]);
while(!Q.empty()){
int ft = Q.front(); Q.pop();
node[ft].last = node[node[ft].nxt].is_end?node[ft].nxt:node[node[ft].nxt].last;
for(reg int i = 0; i < 26; i ++)
if(node[ft].vis[i]){
node[node[ft].vis[i]].nxt = node[node[ft].nxt].vis[i];
Q.push(node[ft].vis[i]);
}
else node[ft].vis[i] = node[node[ft].nxt].vis[i];
}
}
void Find(char *T){
int cur = 0, size = strlen(T);
for(reg int i = 0; i < size; i ++){
int t = T[i] - 'a';
cur = node[cur].vis[t];
int tmp = cur;
while(tmp){
if(node[tmp].is_end) Ans[end[tmp]].cnt += node[tmp].is_end;
tmp = node[tmp].last;
}
}
}
} Ac_t;
void Work(){
Ac_t.Init();
for(reg int i = 1; i <= N; i ++) Ans[i].cnt = 0, Ans[i].id = i;
for(reg int i = 1; i <= N; i ++) scanf("%s", s[i]), Ac_t.Add(s[i], i);
Ac_t.BFS();
scanf("%s", T);
Ac_t.Find(T);
std::sort(Ans+1, Ans+N+1, cmp);
int t = 1;
printf("%d\n", Ans[1].cnt);
printf("%s\n", s[Ans[1].id]);
while(t < N && Ans[t].cnt == Ans[t+1].cnt) printf("%s\n", s[Ans[++ t].id]);
}
int main(){
while(~scanf("%d", &N) && N) Work();
return 0;
}