题目描述

一天小甲苯得到了一条神的指示,他要把神的指示写下来,但是又不能泄露天机,所以他要用一种方法把神的指示记下来。

神的指示是一个字符串,记为字符串 \(s_1\)\(s_1\) 仅包含小写字母 \(\texttt{a-z}\)

现在小甲苯想要写下神的指示,记为字符串 \(s_2\)\(s_2\) 仅包含小写字母 \(\texttt{a-z}\),要求 \(s_1\) 中的相邻的两个字母不能在 \(s_2\) 中相邻地出现。

现在给定 \(s_2\) 的长度,小甲苯想知道他有多少种方法可以将神的指示写下来。

输出种类数对 \(10^9 + 7\) 取模的结果。

输入格式

文件的第一行只有一个正整数 \(n\),代表字符串 \(s_2\) 的长度,\(n \le 10^{15}\)

第二行是一个字符串,代表字符串 \(s_1\)\(s_1\) 的长度不超过 \(100000\)

输出格式

输出一个整数,表示小甲苯可以写出的字符串的总数。

结果对 \(10^9 + 7\) 取模。

样例

样例输入

2
ab

样例输出

675

数据范围与提示

对于 \(30\%\) 的数据,\(n ≤ 100000\)

对于 \(100\%\) 的数据,\(n ≤ 10^{15}\)

题解

我的做法其实应该和网上的题解差不多。不过理解起来可能需要比较感性一点?

因为考虑的只有相邻两个字符,不难得知,可以把所有关系写成一个\(26\times 26\)的矩阵。

先思考一下正常怎么求解:\(f[i][j]=\sum f[i-1][k]*a[k][j]\)\(i\)位第几位数字,\(j\)为当前放的字母,\(k\)为上一位放的字母)

那么我们把\(f[i]\)视为一个\(1\times n\)的矩阵,实际上这就是一个矩阵乘法的过程。

利用矩阵快速幂求解即可。

但是我构造的初始的矩阵有点不同:初始矩阵为一个全\(1\)的矩阵,最后统计答案时答案为\(\sum f[i][i]\)

怎么理解呢?实际上就是第一位怎么放都行,然后矩阵乘法得到的矩阵中的\(c[i][j]\),实际上就是第一个矩阵的第\(i\)行与第二个矩阵的第\(j\)列的答案。所以最终的\(f[i][i]\)即为第一位填\(i\)可以得到的方案数。

那么为什么要这么构造就不难理解了。

#include <bits/stdc++.h>
using namespace std;

#define ll long long
const int N = 100010;
const ll mod = 1e9 + 7;

ll n;
char s[N];
struct mat {
    ll m[26][26];
};

mat operator * (mat a, mat b) {
    mat c;
    memset(c.m, 0, sizeof(c.m));
    for(int i = 0; i < 26; ++i) {
        for(int j = 0; j < 26; ++j) {
            for(int k = 0; k < 26; ++k) {
                c.m[i][j] = (c.m[i][j] + a.m[i][k] * b.m[k][j] % mod) % mod;
            }
        }
    }
    return c;
}

mat base, ans, now;

int idx(char c) {return c - 'a';} 

void print(mat a) {
    for(int i = 0; i < 26; ++i) {
        for(int j = 0; j < 26; ++j) printf("%d ", a.m[i][j]);
        puts("");
    }
}

int main() {
    scanf("%lld%s", &n, s + 1);
    for(int i = 0; i < 26; ++i) for(int j = 0; j < 26; ++j) base.m[i][j] = 1;
    int len = strlen(s + 1);
    for(int i = 1; i < len; ++i) {
        base.m[idx(s[i])][idx(s[i + 1])] = 0;
    }

    memset(ans.m, 0, sizeof(ans.m));
    memset(now.m, 0, sizeof(now.m));
    for(int i = 0; i < 26; ++i) ans.m[i][i] = 1;

    --n;
    while(n) {
        if(n & 1) ans = ans * base;
        base = base * base; n >>= 1;
    }
    for(int i = 0; i < 26; ++i) for(int j = 0; j < 26; ++j) now.m[i][j] = 1;
    now = now * ans;

    ll sum = 0;
    for(int i = 0; i < 26; ++i) sum = (sum + now.m[i][i]) % mod;
    printf("%lld\n", sum);
}