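由题意可知,$S_n = S_{n-1} + F_n$,且 $F_n = F_{n-1} + F_{n-2}$(其中 $F_1 = F_2 = 1$,$S_n = F_1 + F_2 + \cdots + F_n$)。因此 $S_n = S_{n-1} + F_{n-1} + F_{n-2}$,也就是说,$S_n$ 可以通过 $S$ 的前一项和 $F$ 的前两项推出来。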

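那么可以得到如下矩阵形式:

$$
\begin{bmatrix} S_n & F_n & F_{n-1} \end{bmatrix}
=
\begin{bmatrix} S_{n-1} & F_{n-1} & F_{n-2} \end{bmatrix}
\begin{bmatrix} 1 & 0 & 0 \\ 1 & 1 & 1 \\ 1 & 1 & 0 \end{bmatrix}
$$

从初始向量 $\begin{bmatrix} S_2 & F_2 & F_1 \end{bmatrix} = \begin{bmatrix} 2 & 1 & 1 \end{bmatrix}$ 出发,右乘该矩阵的 $n-2$ 次幂即可得到 $\begin{bmatrix} S_n & F_n & F_{n-1} \end{bmatrix}$。这个矩阵幂可以通过矩阵快速幂在 $O(\log n)$ 次矩阵乘法内求出。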

import sys
from collections import Counter
from heapq import heappop, heappush
from math import inf, lcm, comb
from typing import List

# Faster input: rebind the builtin ``input`` to ``sys.stdin.readline``.
# Unlike the builtin, the returned line keeps its trailing newline, so
# callers must strip or split it (``str.split()`` discards it).
input = sys.stdin.readline
# Section marker (a bare string expression, not a docstring): start of the
# matrix fast-exponentiation helpers.
"""矩阵快速幂部分"""
def matrix_multiply(A, B, mod=None):
    """Return the matrix product ``A @ B``, optionally reduced modulo ``mod``.

    Matrices are represented as lists of row lists.

    Args:
        A: Left operand with ``rows_A`` rows and ``cols_A`` columns.
        B: Right operand; must have exactly ``cols_A`` rows whenever ``A``
            is non-empty.
        mod: If not ``None``, every entry of the result is reduced modulo
            ``mod``.

    Returns:
        A new ``rows_A x cols_B`` list-of-lists matrix. Inputs are not
        modified.

    Raises:
        ValueError: If ``A`` is non-empty and its column count differs from
            the row count of ``B``.
    """
    rows_A = len(A)
    cols_A = len(A[0]) if rows_A > 0 else 0
    rows_B = len(B)
    cols_B = len(B[0]) if rows_B > 0 else 0
    # Without this check a too-narrow A silently ignores trailing rows of B
    # (wrong answer), and a too-wide A fails later with an opaque IndexError.
    if rows_A > 0 and cols_A != rows_B:
        raise ValueError(
            f"incompatible shapes: {rows_A}x{cols_A} @ {rows_B}x{cols_B}"
        )
    result = []
    for row in A:
        new_row = []
        for j in range(cols_B):
            # Python ints never overflow, so reducing once per entry gives
            # the same value as reducing after every term.
            total = sum(row[k] * B[k][j] for k in range(cols_A))
            new_row.append(total % mod if mod is not None else total)
        result.append(new_row)
    return result
def matrix_power(matrix, power, mod=None):
    """Raise a square matrix to a non-negative integer power.

    Uses binary exponentiation (repeated squaring), so it performs
    O(log power) calls to ``matrix_multiply``.

    Args:
        matrix: Square ``n x n`` matrix as a list of row lists. Not modified.
        power: Non-negative integer exponent.
        mod: If not ``None``, all arithmetic is reduced modulo ``mod``.

    Returns:
        A new ``n x n`` matrix equal to ``matrix ** power`` (the identity,
        reduced modulo ``mod``, when ``power`` is 0).

    Raises:
        ValueError: If ``power`` is negative or ``matrix`` is not square.
    """
    if power < 0:
        # Previously a negative power silently returned the identity.
        raise ValueError(f"power must be non-negative, got {power}")
    n = len(matrix)
    if any(len(row) != n for row in matrix):
        raise ValueError("matrix_power requires a square matrix")
    # ``1 % mod`` keeps the identity consistent with the modulus: with
    # mod == 1 every entry, including the diagonal, reduces to 0.
    one = 1 if mod is None else 1 % mod
    result = [[one if i == j else 0 for j in range(n)] for i in range(n)]
    base = matrix if mod is None else [[x % mod for x in row] for row in matrix]
    while power > 0:
        if power % 2 == 1:
            result = matrix_multiply(result, base, mod)
        power //= 2
        if power > 0:  # skip the final squaring, whose result is never used
            base = matrix_multiply(base, base, mod)
    return result
# End-of-section marker; must sit at column 0 because it is a module-level
# statement (any indentation here is an IndentationError).
"""矩阵快速幂部分结束。"""
def fibonacci_prefix_sum(n, mod):
    """Return ``(F_1 + F_2 + ... + F_n) % mod`` where ``F_1 = F_2 = 1``.

    The row vector ``[S_k, F_k, F_{k-1}]`` (``S_k`` = prefix sum) evolves as
    ``[S_k, F_k, F_{k-1}] = [S_{k-1}, F_{k-1}, F_{k-2}] @ M`` with
    ``M = [[1, 0, 0], [1, 1, 1], [1, 1, 0]]``. Starting from
    ``[S_2, F_2, F_1] = [2, 1, 1]``, ``S_n`` is the first entry of
    ``[2, 1, 1] @ M ** (n - 2)``.

    Args:
        n: Number of Fibonacci terms to sum.
        mod: Modulus applied to the result.

    Returns:
        The prefix sum ``S_n`` reduced modulo ``mod``.
    """
    small = [0, 1, 2]  # S_0, S_1, S_2
    if n <= 2:
        return small[n] % mod
    start = [[2, 1, 1]]  # [S_2, F_2, F_1]
    transition = [
        [1, 0, 0],
        [1, 1, 1],
        [1, 1, 0],
    ]
    row = matrix_multiply(start, matrix_power(transition, n - 2, mod), mod)
    return row[0][0]


if __name__ == '__main__':
    # Input: "n m" on one line; output: S_n mod m.
    n, m = map(int, input().split())
    print(fibonacci_prefix_sum(n, m))