#include<bits/stdc++.h> using namespace std; const int maxn =1000000+100; const int SIGMA_SIZE=26; const int maxnode =1000000+100; int n,ans; bool vis[maxn]; int ch[maxnode][SIGMA_SIZE]; int val[maxnode]; int idx(char c){return c-'a';}; struct Trie{ int sz; Trie(){ sz = 1; memset(ch[0],0,sizeof(ch[0])); memset(vis,0,sizeof(vis));}; void insert(char *s){ int u=0,n=strlen(s); for(int i=0;i<n;i++){ int c=idx(s[i]); if(!ch[u][c]){ memset(ch[sz],0,sizeof(ch[sz])); val[sz] = 0; ch[u][c] = sz++; } u=ch[u][c]; } val[u]++; } }; //AC自动机 int last[maxn],f[maxn]; void print(int j){ if(j&&!vis[j]){ ans+=val[j]; vis[j]=1; print(last[j]); } } int getFail(){ queue<int>q; f[0] = 0; for(int c=0;c<SIGMA_SIZE;c++){ int u = ch[0][c]; if(u) { f[u]=0; q.push(u); last[u] = 0;} } while(!q.empty()){ int r=q.front(); q.pop(); for(int c=0;c<SIGMA_SIZE;c++){ int u=ch[r][c]; if(!u){ ch[r][c]=ch[f[r]][c]; continue; } q.push(u); int v=f[r]; //while(v&&!ch[v][c]) v=f[v]; f[u] = ch[v][c]; last[u] = val[f[u]]?f[u]:last[f[u]]; } } } void find_T(char* T){ int n=strlen(T); int j=0; for(int i=0;i<n;i++){ int c=idx(T[i]); j=ch[j][c]; if(val[j]) print(j); else if(last[j]) print(last[j]); } } char tmp[105]; char text[1000000+1000]; int main(){ int T; cin>>T; while(T--){ scanf("%d",&n); Trie trie; ans = 0; for(int i=0;i<n;i++){ scanf("%s",tmp); trie.insert(tmp); } getFail(); scanf("%s",text); find_T(text); cout<<ans<<endl; } return 0; }
注释后:
#include<bits/stdc++.h> using namespace std; const int maxn =1000000+100; const int SIGMA_SIZE=26; const int maxnode =1000000+100; int n,ans; bool vis[maxn]; int ch[maxnode][SIGMA_SIZE]; int val[maxnode]; int idx(char c){return c-'a';}; //将字符转为数字,注意变通 struct Trie{ int sz; Trie(){ sz = 1; memset(ch[0],0,sizeof(ch[0])); memset(vis,0,sizeof(vis));}; void insert(char *s){ int u=0,n=strlen(s); for(int i=0;i<n;i++){ int c=idx(s[i]); if(!ch[u][c]){ memset(ch[sz],0,sizeof(ch[sz])); val[sz] = 0; ch[u][c] = sz++; } u=ch[u][c]; } val[u]++; } }; //AC自动机 int last[maxn],f[maxn]; void print(int j){ if(j&&!vis[j]){ ans+=val[j]; vis[j]=1; print(last[j]); } } int getFail(){ queue<int>q; f[0] = 0; for(int c=0;c<SIGMA_SIZE;++c){ int u = ch[0][c]; if(u) { f[u]=0; q.push(u); last[u] = 0;} } while(!q.empty()){ int r=q.front(); q.pop(); for(int c=0;c<SIGMA_SIZE;c++){ int u=ch[r][c]; if(!u){ //失配时直接连接fail指针 ch[r][c]=ch[f[r]][c]; continue; } q.push(u); int v=f[r]; //while(v&&!ch[v][c]) v=f[v]; //多余 f[u] = ch[v][c]; //孩子结点和它还有它长辈的fail指针的字符都是相同的,除了根结点 //如果父亲节点的fail指针 v 的孩子结点没有节点 c ,那么就访问 v 的fail指针继续这类操作直到当前节点是 0 或者 //这个节点的孩子节点有节点 c ,然后 u 的fail指针f[u]指向这个 c 节点 //(因为处理这个字典树是由浅到深处理的,所以ch[v][c]已经处理过了,直接f[u]=ch[v][c]就可以了 ) last[u] = val[f[u]]?f[u]:last[f[u]]; //记录单词的结尾的位置--假设该字符是某个单词的结尾 //(如果这个字符不是任何单词的结尾就指向根结点 0 ) //如果该结点的fail指针是单词的结尾,指向fail指针继续匹配, //否则看fail指针前面是否有字符是单词的结尾(因为是先处理父亲再处理儿子的,所以很容易做到) //如果有就指向最近的一个单词结尾,没有就指向空 } } } //儿子的fail指针根据父节点和长辈(包括父节点) 的fail指针来确定的,如果儿子有fail指针,儿子和fail指针指向的字符一定是相同的 void find_T(char* T){ int n=strlen(T); int j=0; for(int i=0;i<n;i++){ int c=idx(T[i]); j=ch[j][c]; //如果再失配,j会变为0重新开始匹配 //getfail函数已经预处理过了,预处理时如果失配会指向fail指针 //所以现在如果再失配则表示和所有的fail指针都失配,就重新开始匹配 if(val[j]) print(j); //匹配到该字符时可能已经完全匹配到多个单词了,全部标记 else if(last[j]) print(last[j]); //同上 } } char tmp[105]; char text[1000000+1000]; int main(){ int T; cin>>T; while(T--){ scanf("%d",&n); Trie trie; ans = 0; for(int i=0;i<n;i++){ scanf("%s",tmp); trie.insert(tmp); } getFail(); scanf("%s",text); find_T(text); cout<<ans<<endl; } return 0; }
将模板改了改写了hdu2896,模板不太熟悉,还有待优化202ms:
#include<bits/stdc++.h> #define maxn 100100 #define SIGMA_SIZE 128 #define maxnode 100100 using namespace std; int n,ans[3],cnt=0,b[3]; bool vis[maxn]; int ch[maxnode][SIGMA_SIZE]; int val[maxnode]; int idx(char c){return c-'a';}; struct Trie{ int sz; Trie(){ sz = 1; memset(ch[0],0,sizeof(ch[0])); memset(vis,0,sizeof(vis));}; void insert(char *s){ int u=0,n=strlen(s); for(int i=0;i<n;i++){ int c=idx(s[i]); if(!ch[u][c]){ memset(ch[sz],0,sizeof(ch[sz])); val[sz] = 0; ch[u][c] = sz++; } u=ch[u][c]; } val[u]=++cnt; } }; int last[maxn],f[maxn]; void print(int j){ if(j&&!vis[j]){ ans[++cnt]=val[j]; vis[j]=1; b[cnt]=j; print(last[j]); } } int getFail(){ queue<int>q; f[0] = 0; for(int c=0;c<SIGMA_SIZE;c++){ int u = ch[0][c]; if(u) { f[u]=0; q.push(u); last[u] = 0;} } while(!q.empty()){ int r=q.front(); q.pop(); for(int c=0;c<SIGMA_SIZE;++c){ int u=ch[r][c]; if(!u){ ch[r][c]=ch[f[r]][c]; continue; } q.push(u); int v=f[r]; while(v&&!ch[v][c]) v=f[v]; f[u] = ch[v][c]; last[u] = val[f[u]]?f[u]:last[f[u]]; } } } void find_T(char* T){ int n=strlen(T); int j=0; for(int i=0;i<n;i++){ int c=idx(T[i]); j=ch[j][c]; if(val[j]) print(j); else if(last[j]) print(last[j]); } } char tmp[205]; char text[10005]; int main(){ int n,m,j=0,ans1=0; scanf("%d",&n); Trie trie; for(int i=0;i<n;i++){ scanf("%s",tmp); trie.insert(tmp); } getFail(); scanf("%d",&m); while(m--) { scanf("%s",text); cnt=-1,++j; find_T(text); if(cnt>=0) { sort(ans,ans+3); printf("web %d:",j); for(int i=0;i<3;++i) { if(ans[i]>0) { printf(" %d",ans[i]); ans[i]=0; } vis[b[i]]=false; b[i]=0; } puts(""); ++ans1; } } printf("total: %d\n",ans1); return 0; }
hdu3065,有个坑点就是要多组输入,140ms--不会优化了:
#include<bits/stdc++.h> #define maxn 100100 #define SIGMA_SIZE 26 #define maxnode 100100 using namespace std; int n,ans[1005],cnt=0; int ch[maxnode][SIGMA_SIZE]; int val[maxnode]; int idx(char c){ if(c<'A'||c>'Z') return -1; return c-'A'; }; struct Trie{ int sz; Trie(){ sz = 1; memset(ch[0],0,sizeof(ch[0]));}; void insert(char *s){ int u=0,n=strlen(s); for(int i=0;i<n;i++){ int c=idx(s[i]); if(!ch[u][c]){ memset(ch[sz],0,sizeof(ch[sz])); val[sz] = 0; ch[u][c] = sz++; } u=ch[u][c]; } val[u]=++cnt; } }; int last[maxn],f[maxn]; void print(int j){ if(j){ ++ans[val[j]]; print(last[j]); } } int getFail(){ queue<int>q; f[0] = 0; for(int c=0;c<SIGMA_SIZE;c++){ int u = ch[0][c]; if(u) { f[u]=0; q.push(u); last[u] = 0;} } while(!q.empty()){ int r=q.front(); q.pop(); for(int c=0;c<SIGMA_SIZE;++c){ int u=ch[r][c]; if(!u){ ch[r][c]=ch[f[r]][c]; continue; } q.push(u); int v=f[r]; while(v&&!ch[v][c]) v=f[v]; f[u] = ch[v][c]; last[u] = val[f[u]]?f[u]:last[f[u]]; } } } void find_T(char* T){ int n=strlen(T); int j=0; for(int i=0;i<n;i++){ int c=idx(T[i]); if(c==-1) { j=0; continue; } j=ch[j][c]; if(val[j]) print(j); else if(last[j]) print(last[j]); } } char tmp[1005][55]; char text[2000005]; int main(){ while(~scanf("%d",&n)) { Trie trie; cnt=0; for(int i=1;i<=n;i++){ scanf("%s",tmp[i]); trie.insert(tmp[i]); } getFail(); scanf("%s",text); find_T(text); for(int i=1;i<=n;++i) { if(ans[i]>0) { printf("%s: %d\n",tmp[i],ans[i]); ans[i]=0; } } } return 0; }