KMP算法

我的腿让我停下,可是心却不允许我那么做。

  • 引言
    字符串匹配。给你两个字符串,寻找其中一个字符串是否包含另一个字符串,如果包含,返回包含的起始位置.
char *str = "bacbababadababacambabacaddababacasdsd";  
char *ptr = "ababaca";
  • 暴力解法

如果当前字符匹配成功(即S[i] == P[j]),则i++,j++,继续匹配下一个字符;
如果失配(即S[i]! = P[j]),令i = i - (j - 1),j = 0。相当于每次匹配失败时,i 回溯,j 被置为0。

来看看时间复杂度:最坏情况下为O(n*m)
所以有没有一种改进的算法

  • 改进方法
    可以实现复杂度为O(m+n),为何简化了时间复杂度:
    KMP算法主要是取消了指针的回溯,充分利用了目标字符串ptr的性质(比如里面部分字符串的重复性,即使不存在重复字段,在比较时,实现最大的移动量)。每趟匹配过程中出现字符比较不等时,不回溯主指针i,利用已得到的“部分匹配”结果将模式向右滑动尽可能远的一段距离,继续进行比较。
    具体概念上的我不想深究,从代码开始理解

代码

  • KmpSearch函数

    • 假设现在文本串S匹配到 i 位置,模式串P匹配到 j 位置
      • 如果j = -1,或者当前字符匹配成功(即S[i] == P[j]),都令i++,j++,继续匹配下一个字符;
      • 如果j != -1,且当前字符匹配失败(即S[i] != P[j]),则令 i 不变,j = next[j]。此举意味着失配时,模式串P相对于文本串S向右移动了j - next [j] 位。
        换言之,当匹配失败时,模式串向右移动的位数为:失配字符所在位置 - 失配字符对应的next 值,即移动的实际位数为:j - next[j],且此值大于等于1。
  • 所以next 数组各值的含义:代表当前字符之前的字符串中,有多大长度的相同前缀后缀。例如如果next [j] = k,代表j 之前的字符串中有最大长度为k 的相同前缀后缀。
    此也意味着在某个字符失配时,该字符对应的next 值会告诉你下一步匹配中,模式串应该跳到哪个位置(跳到next [j] 的位置)。如果next [j] 等于0或-1,则跳到模式串的开头字符,若next [j] = k 且 k > 0,代表下次匹配跳到j 之前的某个字符,而不是跳到开头,且具体跳过了k 个字符。

int KmpSearch(char* s, char* p)  
{  
    int i = 0;  
    int j = 0;  
    int sLen = strlen(s);  
    int pLen = strlen(p);  
    while (i < sLen && j < pLen)  
    {  
        //①如果j = -1,或者当前字符匹配成功(即S[i] == P[j]),都令i++,j++      
        if (j == -1 || s[i] == p[j])  
        {  
            i++;  
            j++;  
        }  
        else  
        {  
            //②如果j != -1,且当前字符匹配失败(即S[i] != P[j]),则令 i 不变,j = next[j]      
            //next[j]即为j所对应的next值        
            j = next[j];  
        }  
    }  
    if (j == pLen)  
        return i - j;  
    else  
        return -1;  
}  

  • getnext()函数
    • (1) next[0] = -1;
    • (2) 设next[j] = k,则next[j+1] = ?
      令j=j+1,k=k+1;
      • 若pk=pj,则有“p1…pk-1pk”=“pj-k+1…pj-1pj” ,
        next[j]=k;
      • 若pk+1≠pj+1,可把求next值问题看成是一个模式匹配问题,整个模式串既是主串,又是子串。
        即使得k=next[k],回溯求得最长前缀等于最长后缀的下标
j 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
模式串 0 a b c a a b b c a b c a a b d a
next[j] -1 0 0 0 1 1 2 0 0 1 2 3 4 5 6 0 1
void getnext(int*next,char*ctr){
    next[0]=-1;
    int j=0,k=-1,len=strlen(ctr);
    while(j<len){
        if(k==-1||ctr[j]==ctr[k]){
            j++;
            k++;
            next[j]=k;
        }
        else k=next[k];
    }
}

完整代码

#include <bits/stdc++.h>
using namespace std;
const int maxn=1e6+7;
int a[maxn],b[maxn];
int nxt[10005],n,m;
void getnext(){
    int i=0,j=-1;
    nxt[0]=-1;
    while(i<m){
        if(j==-1||b[i]==b[j]){
            i++,j++;
            if(b[i]==b[j])nxt[i]=nxt[j];
            else nxt[i]=j;
        }
        else j=nxt[j];
    }
}
int kmp(){
    int i=0,j=0;
    getnext();
    while(i<n){
        if(a[i]==b[j]||j==-1)i++,j++;
        else j=nxt[j];
        if(j==m)return i-j+1;
    }
    return -1;
}
int main(){
    int t;
    scanf("%d",&t);
    while(t--){
        scanf("%d%d",&n,&m);
        for(int i=0;i<n;i++)scanf("%d",&a[i]);
        for(int i=0;i<m;i++)scanf("%d",&b[i]);
        if(n<m)printf("-1\n");
        else printf("%d\n",kmp());
    }
    return 0;
}

练习

  1. 牛客三 E

题目:给你一个字符串S,你要对字符串S的每一位i将前i位的字符串移动到尾部形成一个新的字符串,如果形成的字符串相同则归为一类Li。现在让你将Li类按照字典序排序,并让你输出每一类的数量和每一类中字符串对应的下标i

很多人用哈希,正解是KMP的next数组的运用.通过观察,题目问的就是求字符串的循环节,而next数组就是前后缀的运用。然后有个结论:如果对于一个长度为L的字符串,如果L%(L-next[L])==0则代表它具有循环节,且循环节的长度为L-next[L]。

#include <bits/stdc++.h>
using namespace std;
const int maxn=1e6+7;
char b[maxn];
int net[maxn],len;
template<class T>
void read(T &res)
{
    res = 0;
    char c = getchar();
    T f = 1;
    while(c < '0' || c > '9')
    {
        if(c == '-') f = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9')
    {
        res = res * 10 + c - '0';
        c = getchar();
    }
    res *= f;
}
template<class T>
void out(T x)
{
    if(x < 0)
    {
        putchar('-');
        x = -x;
    }
    if(x >= 10)
    {
        out(x / 10);
    }
    putchar('0' + x % 10);
}
void get_next(char*ctr)
{
    net[0]=-1;
    int j=0,k=-1;
    len=strlen(ctr);
    while(j<len)
    {
        if(k==-1||ctr[j]==ctr[k])
        {
            j++;
            k++;
            net[j]=k;
        }
        else k=net[k];
    }
}
int main()
{
    scanf("%s",b);
    get_next(b);
    int tmp=len-net[len];
    if(len%tmp==0&&tmp!=len)
    {
        out(tmp);
        puts(" ");
        for(int i=0; i<tmp; i++)
        {
            out(len/tmp);
            printf(" ");
            for(int j=i;j<len;j+=tmp){

                out(j);
                //
                printf(" ");
            }
            printf("\n");
        }
    }
    else{
        out(len);
   // puts(" ");
         printf("\n");
        for(int i=0;i<len;i++){
                out(1);
                printf(" ");
            out(i);
            printf("\n");
        }

    }
    //
    return 0;
}