题意:
给出 m 个只由 A,T,C,G 组成的字符串(每个的长度不大于 10),求一个长度为 n 且也只由 A,T,C,G 组成但不包含之前所给的 m个字符串的序列的种类数。
思路:
借助离散数学中可达矩阵的概念。先把m个字符串建立成AC自动机中的字典树,同时对于字符串的结尾节点,进行标记。一个节点要想一步到达另的一个节点但不形成要求不能出现的字符串,那么这两个节点都不能是要求字符串的结尾。还有一种情况就是字符串之间的相互包含问题,此时就要借助于 fail 指针来处理。因为fail指针会把有相同后缀的节点连起来,那么当一个节点不是字符串的结尾节点,但是其fail指针所指向的节点是,那么该节点同样不能满足要求。这样把不满足要求的节点所在的行和列从矩阵中去掉后,就形成了满足要求的一步可达矩阵,用矩阵快速幂就可以求出n步可达矩阵,对第一行求和就是答案。
代码:
#include <cstdio>
#include <cstring>
#include <queue>
#include <string>
using namespace std;
typedef long long ll;
const int mod=100000;
queue<int>que;
char ss[15];
int trie[105][4],fail[105];
bool flag[105];
int cnt;
struct matrix
{
ll mat[105][105];
void clc()
{
for(int i=0;i<=cnt;i++)
for(int j=0;j<=cnt;j++)
mat[i][j]=0;
}
matrix operator *(const matrix b)const
{
matrix res;
res.clc();
for(int i=0;i<=cnt;i++)
{
for(int k=0;k<=cnt;k++)
{
if(mat[i][k]>0)
{
for(int j=0;j<=cnt;j++)
res.mat[i][j]=(res.mat[i][j]+mat[i][k]*b.mat[k][j]%mod)%mod;
}
}
}
return res;
}
};
int id(char ch)
{
if(ch=='A') return 0;
else if(ch=='C') return 1;
else if(ch=='T') return 2;
else if(ch=='G') return 3;
}
void build(char s[])
{
int p=0,len=strlen(s);
for(int i=0;i<len;i++)
{
int t=id(s[i]);
if(trie[p][t]==0)
trie[p][t]=++cnt;
p=trie[p][t];
}
flag[p]=1;
}
void bfs()
{
while(!que.empty())
que.pop();
for(int i=0;i<4;i++)
{
if(trie[0][i])
{
fail[trie[0][i]]=0;
que.push(trie[0][i]);
}
}
while(!que.empty())
{
int now=que.front();
que.pop();
if(fail[now]&&flag[fail[now]])//包含情况
flag[now]=1;
for(int i=0;i<4;i++)
{
if(trie[now][i])
{
fail[trie[now][i]]=trie[fail[now]][i];
que.push(trie[now][i]);
}
else
trie[now][i]=trie[fail[now]][i];
}
}
}
void init(matrix &a)
{
a.clc();
for(int i=0;i<=cnt;i++)
{
for(int j=0;j<4;j++)
{
if(!flag[i]&&!flag[trie[i][j]])//计数
a.mat[i][trie[i][j]]++;
}
}
}
matrix mpow(matrix a,int b)
{
matrix res;
res.clc();
for(int i=0;i<=cnt;i++)
res.mat[i][i]=1;
while(b)
{
if(b&1)
res=res*a;
a=a*a;
b>>=1;
}
return res;
}
int main()
{
int m,n;
while(scanf("%d%d",&m,&n)!=EOF)
{
cnt=0;
memset(flag,0,sizeof(flag));
memset(trie,0,sizeof(trie));
for(int i=1;i<=m;i++)
{
scanf("%s",ss);
build(ss);
}
bfs();
matrix A;
init(A);
matrix res=mpow(A,n);
ll ans=0;
for(int i=0;i<=cnt;i++)
ans=(ans+res.mat[0][i])%mod;
printf("%lld\n",ans);
}
return 0;
}