[模板] 矩阵加速递推运算

之前我们已经学习了矩阵乘法与矩阵快速幂,下面我们来学习如何通过矩阵快速幂来优化递推,先从最简单的递推数列开始。

Fibonacci 数列


$$f_0 = 0,f_1 = 1 $$
对于 $i \gt 2$ 有
$$ f_i = f_{i-1} +f_{n+2} $$

如果通过简单的递推求解第 $n$ 项的话,时间复杂度为 $O(n)$

矩阵加速线性递推

我们可以将 Fibonacci 的数列递推过程看做一个矩阵 $
\left[
\begin{matrix}
f_{n-2} & f_{n-1} \\
\end{matrix}
\right]
$ 乘上一个转移矩阵 $A$,得到 $
\left[
\begin{matrix}
f_{n-1} & f_{n} \\
\end{matrix}
\right]
$

通过推导我们可以得出转移矩阵

$$
A =
\left[
\begin{matrix}
0 & 1 \\
1 & 1 \\
\end{matrix}
\right]
$$

我们便可通过矩阵快速幂来加速递推过程,复杂度 $O(2^3 \ log \ n)$

代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>

using namespace std;
const int MAXN = 10 + 10;

int n, MOD;

struct Matrix{
    int n, m;
    int data[MAXN][MAXN];

    Matrix(int _n, int _m){
        n = _n;
        m = _m;
        memset(data, 0, sizeof(data));
    }

    Matrix operator + (Matrix a){
        Matrix ans(n, m);
        for(register int i=1; i<=n; i++){
            for(register int j=1; j<=m; j++){
                ans.data[i][j] = (data[i][j] + a.data[i][j]) % MOD;
            }
        }

        return ans;
    } 

    Matrix operator * (Matrix a){
        Matrix ans(n, a.m);

        for(register int i=1; i<=n; i++){
            for(register int j=1; j<=a.m; j++){
                long long t = 0;

                for(register int k=1; k<=m; k++){
                    t = (t + ((long long)data[i][k] * a.data[k][j]) % MOD) % MOD; 
                }

                ans.data[i][j] = t;
            }
        }

        return ans;
    }
};

Matrix pow(Matrix a, int n){
    Matrix ans(a.n, a.n), tmp(a.n, a.n);

    for(register int i=1; i<=a.n; i++)
        ans.data[i][i] = 1;

    memcpy(tmp.data, a.data, sizeof(tmp.data));

    while(n){
        if(n & 1)
            ans = ans * tmp;

        n >>= 1;
        tmp = tmp * tmp;
    }

    return ans;
}


int main(){
    scanf("%d%d", &n, &MOD);

    Matrix start(1, 2);
    start.data[1][1] = 0;
    start.data[1][2] = 1;

    Matrix transform(2, 2);
    transform.data[1][1] = 0;
    transform.data[1][2] = 1;
    transform.data[2][1] = 1;
    transform.data[2][2] = 1;

    Matrix ans = start * pow(transform, n);

    printf("%d", ans.data[1][1]);
    return 0;
}