#include<bits/stdc++.h>
using namespace std;
const int maxn =1000000+100;
const int SIGMA_SIZE=26;
const int maxnode =1000000+100;
int n,ans;
bool vis[maxn];
int ch[maxnode][SIGMA_SIZE];
int val[maxnode];
int idx(char c){return c-'a';};
struct Trie{
int sz;
Trie(){ sz = 1; memset(ch[0],0,sizeof(ch[0])); memset(vis,0,sizeof(vis));};
void insert(char *s){
int u=0,n=strlen(s);
for(int i=0;i<n;i++){
int c=idx(s[i]);
if(!ch[u][c]){
memset(ch[sz],0,sizeof(ch[sz]));
val[sz] = 0;
ch[u][c] = sz++;
}
u=ch[u][c];
}
val[u]++;
}
};
//AC自动机
int last[maxn],f[maxn];
void print(int j){
if(j&&!vis[j]){
ans+=val[j]; vis[j]=1;
print(last[j]);
}
}
int getFail(){
queue<int>q;
f[0] = 0;
for(int c=0;c<SIGMA_SIZE;c++){
int u = ch[0][c];
if(u) { f[u]=0; q.push(u); last[u] = 0;}
}
while(!q.empty()){
int r=q.front(); q.pop();
for(int c=0;c<SIGMA_SIZE;c++){
int u=ch[r][c];
if(!u){
ch[r][c]=ch[f[r]][c];
continue;
}
q.push(u);
int v=f[r];
//while(v&&!ch[v][c]) v=f[v];
f[u] = ch[v][c];
last[u] = val[f[u]]?f[u]:last[f[u]];
}
}
}
void find_T(char* T){
int n=strlen(T);
int j=0;
for(int i=0;i<n;i++){
int c=idx(T[i]);
j=ch[j][c];
if(val[j]) print(j);
else if(last[j]) print(last[j]);
}
}
char tmp[105];
char text[1000000+1000];
int main(){
int T; cin>>T;
while(T--){
scanf("%d",&n);
Trie trie;
ans = 0;
for(int i=0;i<n;i++){
scanf("%s",tmp);
trie.insert(tmp);
}
getFail();
scanf("%s",text);
find_T(text);
cout<<ans<<endl;
}
return 0;
}注释后:
#include<bits/stdc++.h>
using namespace std;
const int maxn =1000000+100;
const int SIGMA_SIZE=26;
const int maxnode =1000000+100;
int n,ans;
bool vis[maxn];
int ch[maxnode][SIGMA_SIZE];
int val[maxnode];
int idx(char c){return c-'a';}; //将字符转为数字,注意变通
struct Trie{
int sz;
Trie(){ sz = 1; memset(ch[0],0,sizeof(ch[0])); memset(vis,0,sizeof(vis));};
void insert(char *s){
int u=0,n=strlen(s);
for(int i=0;i<n;i++){
int c=idx(s[i]);
if(!ch[u][c]){
memset(ch[sz],0,sizeof(ch[sz]));
val[sz] = 0;
ch[u][c] = sz++;
}
u=ch[u][c];
}
val[u]++;
}
};
//AC自动机
int last[maxn],f[maxn];
void print(int j){
if(j&&!vis[j]){
ans+=val[j]; vis[j]=1;
print(last[j]);
}
}
int getFail(){
queue<int>q;
f[0] = 0;
for(int c=0;c<SIGMA_SIZE;++c){
int u = ch[0][c];
if(u) { f[u]=0; q.push(u); last[u] = 0;}
}
while(!q.empty()){
int r=q.front(); q.pop();
for(int c=0;c<SIGMA_SIZE;c++){
int u=ch[r][c];
if(!u){ //失配时直接连接fail指针
ch[r][c]=ch[f[r]][c];
continue;
}
q.push(u);
int v=f[r];
//while(v&&!ch[v][c]) v=f[v]; //多余
f[u] = ch[v][c]; //孩子结点和它还有它长辈的fail指针的字符都是相同的,除了根结点
//如果父亲节点的fail指针 v 的孩子结点没有节点 c ,那么就访问 v 的fail指针继续这类操作直到当前节点是 0 或者
//这个节点的孩子节点有节点 c ,然后 u 的fail指针f[u]指向这个 c 节点
//(因为处理这个字典树是由浅到深处理的,所以ch[v][c]已经处理过了,直接f[u]=ch[v][c]就可以了 )
last[u] = val[f[u]]?f[u]:last[f[u]]; //记录单词的结尾的位置--假设该字符是某个单词的结尾
//(如果这个字符不是任何单词的结尾就指向根结点 0 )
//如果该结点的fail指针是单词的结尾,指向fail指针继续匹配,
//否则看fail指针前面是否有字符是单词的结尾(因为是先处理父亲再处理儿子的,所以很容易做到)
//如果有就指向最近的一个单词结尾,没有就指向空
}
}
}
//儿子的fail指针根据父节点和长辈(包括父节点) 的fail指针来确定的,如果儿子有fail指针,儿子和fail指针指向的字符一定是相同的
void find_T(char* T){
int n=strlen(T);
int j=0;
for(int i=0;i<n;i++){
int c=idx(T[i]);
j=ch[j][c]; //如果再失配,j会变为0重新开始匹配
//getfail函数已经预处理过了,预处理时如果失配会指向fail指针
//所以现在如果再失配则表示和所有的fail指针都失配,就重新开始匹配
if(val[j]) print(j); //匹配到该字符时可能已经完全匹配到多个单词了,全部标记
else if(last[j]) print(last[j]); //同上
}
}
char tmp[105];
char text[1000000+1000];
int main(){
int T; cin>>T;
while(T--){
scanf("%d",&n);
Trie trie;
ans = 0;
for(int i=0;i<n;i++){
scanf("%s",tmp);
trie.insert(tmp);
}
getFail();
scanf("%s",text);
find_T(text);
cout<<ans<<endl;
}
return 0;
}
将模板改了改写了hdu2896,模板不太熟悉,还有待优化202ms:
#include<bits/stdc++.h>
#define maxn 100100
#define SIGMA_SIZE 128
#define maxnode 100100
using namespace std;
int n,ans[3],cnt=0,b[3];
bool vis[maxn];
int ch[maxnode][SIGMA_SIZE];
int val[maxnode];
int idx(char c){return c-'a';};
struct Trie{
int sz;
Trie(){ sz = 1; memset(ch[0],0,sizeof(ch[0])); memset(vis,0,sizeof(vis));};
void insert(char *s){
int u=0,n=strlen(s);
for(int i=0;i<n;i++){
int c=idx(s[i]);
if(!ch[u][c]){
memset(ch[sz],0,sizeof(ch[sz]));
val[sz] = 0;
ch[u][c] = sz++;
}
u=ch[u][c];
}
val[u]=++cnt;
}
};
int last[maxn],f[maxn];
void print(int j){
if(j&&!vis[j]){
ans[++cnt]=val[j]; vis[j]=1;
b[cnt]=j;
print(last[j]);
}
}
int getFail(){
queue<int>q;
f[0] = 0;
for(int c=0;c<SIGMA_SIZE;c++){
int u = ch[0][c];
if(u) { f[u]=0; q.push(u); last[u] = 0;}
}
while(!q.empty()){
int r=q.front(); q.pop();
for(int c=0;c<SIGMA_SIZE;++c){
int u=ch[r][c];
if(!u){
ch[r][c]=ch[f[r]][c];
continue;
}
q.push(u);
int v=f[r];
while(v&&!ch[v][c]) v=f[v];
f[u] = ch[v][c];
last[u] = val[f[u]]?f[u]:last[f[u]];
}
}
}
void find_T(char* T){
int n=strlen(T);
int j=0;
for(int i=0;i<n;i++){
int c=idx(T[i]);
j=ch[j][c];
if(val[j]) print(j);
else if(last[j]) print(last[j]);
}
}
char tmp[205];
char text[10005];
int main(){
int n,m,j=0,ans1=0;
scanf("%d",&n);
Trie trie;
for(int i=0;i<n;i++){
scanf("%s",tmp);
trie.insert(tmp);
}
getFail();
scanf("%d",&m);
while(m--) {
scanf("%s",text);
cnt=-1,++j;
find_T(text);
if(cnt>=0) {
sort(ans,ans+3);
printf("web %d:",j);
for(int i=0;i<3;++i) {
if(ans[i]>0) {
printf(" %d",ans[i]);
ans[i]=0;
}
vis[b[i]]=false;
b[i]=0;
}
puts("");
++ans1;
}
}
printf("total: %d\n",ans1);
return 0;
}hdu3065,有个坑点就是要多组输入,140ms--不会优化了:
#include<bits/stdc++.h>
#define maxn 100100
#define SIGMA_SIZE 26
#define maxnode 100100
using namespace std;
int n,ans[1005],cnt=0;
int ch[maxnode][SIGMA_SIZE];
int val[maxnode];
int idx(char c){
if(c<'A'||c>'Z') return -1;
return c-'A';
};
struct Trie{
int sz;
Trie(){ sz = 1; memset(ch[0],0,sizeof(ch[0]));};
void insert(char *s){
int u=0,n=strlen(s);
for(int i=0;i<n;i++){
int c=idx(s[i]);
if(!ch[u][c]){
memset(ch[sz],0,sizeof(ch[sz]));
val[sz] = 0;
ch[u][c] = sz++;
}
u=ch[u][c];
}
val[u]=++cnt;
}
};
int last[maxn],f[maxn];
void print(int j){
if(j){
++ans[val[j]];
print(last[j]);
}
}
int getFail(){
queue<int>q;
f[0] = 0;
for(int c=0;c<SIGMA_SIZE;c++){
int u = ch[0][c];
if(u) { f[u]=0; q.push(u); last[u] = 0;}
}
while(!q.empty()){
int r=q.front(); q.pop();
for(int c=0;c<SIGMA_SIZE;++c){
int u=ch[r][c];
if(!u){
ch[r][c]=ch[f[r]][c];
continue;
}
q.push(u);
int v=f[r];
while(v&&!ch[v][c]) v=f[v];
f[u] = ch[v][c];
last[u] = val[f[u]]?f[u]:last[f[u]];
}
}
}
void find_T(char* T){
int n=strlen(T);
int j=0;
for(int i=0;i<n;i++){
int c=idx(T[i]);
if(c==-1) {
j=0;
continue;
}
j=ch[j][c];
if(val[j]) print(j);
else if(last[j]) print(last[j]);
}
}
char tmp[1005][55];
char text[2000005];
int main(){
while(~scanf("%d",&n)) {
Trie trie;
cnt=0;
for(int i=1;i<=n;i++){
scanf("%s",tmp[i]);
trie.insert(tmp[i]);
}
getFail();
scanf("%s",text);
find_T(text);
for(int i=1;i<=n;++i) {
if(ans[i]>0) {
printf("%s: %d\n",tmp[i],ans[i]);
ans[i]=0;
}
}
}
return 0;
}
京公网安备 11010502036488号