跳转至

矩阵加速递推

题目

斐波那契数列是满足如下性质的一个数列:

\[F_n = \left\{\begin{aligned} 1 \space (n \le 2) \\ F_{n-1}+F_{n-2} \space (n\ge 3) \end{aligned}\right.\]

若现在需要求出 \(F_n \bmod 10^9 + 7\) 的值。

对于 \(100\%\) 的数据,\(1\le n < 2^{63}\)

分析

因为\(1\le n < 2^{63}\),暴力可以直接放弃。

这时候,我们就要引入一种递推加速技术 矩阵快速幂

矩阵快速幂分为矩阵以及快速幂两部分,接下来,我们将分开进行讲解。

快速幂 OI-WIKI

计算 a 的 n 次方表示将 n 个 a 乘在一起:

\[a^{n}=\underset{n}{\underbrace{a\times a\times a\cdots\times a}} \]

整体的时间复杂度在\(\Theta(n)\),当n达到\(10^{18}\)的规模时便会超时

不过我们知道:

\[a^{b+c} = a^b \cdot a^c,\,\,a^{2b} = a^b \cdot a^b = (a^b)^2\]

二进制取幂的想法是,我们将取幂的任务按照指数的二进制表示来分割成更小的任务。

举个例子

首先我们将 n 表示为 2 进制,举个例子:

\[ 3^{13} = 3^{(1101)_2} = 3^8 \cdot 3^4 \cdot 3^1 \]

因为 n 有 \(\lfloor \log_2 n \rfloor + 1\) 个二进制位,因此当我们知道了 \(a^1, a^2, a^4, a^8, \dots, a^{2^{\lfloor \log_2 n \rfloor}}\) 后,我们只用计算 \(\Theta(\log n)\) 次乘法就可以计算出 \(a^n\)

于是我们只需要知道一个快速的方法来计算上述 3 的 2^k 次幂的序列。这个问题很简单,因为序列中(除第一个)任意一个元素就是其前一个元素的平方。举一个例子:

\[3^1 = 3\]
\[3^2 = \left(3^1\right)^2 = 3^2 = 9 \]
\[3^4 = \left(3^2\right)^2 = 9^2 = 81 \]
\[3^8 = \left(3^4\right)^2 = 81^2 = 6561\]

因此为了计算 \(3^{13}\),我们只需要将对应二进制位为 1 的整系数幂乘起来就行了:

\[3^{13} = 6561 \cdot 81 \cdot 3 = 1594323\]

将上述过程说得形式化一些,如果把 n 写作二进制为 \((n_tn_{t-1}\cdots n_1n_0)_ 2\) ,那么有:

\[n = n_t2^t + n_{t-1}2^{t-1} + n_{t-2}2^{t-2} + \cdots + n_12^1 + n_02^0\]

其中 \(n_i\in\{0,1\}\)。那么就有

\[a^n = (a^{n_t 2^t + \cdots + n_0 2^0})\]
\[ \ = a^{n_0 2^0} \times a^{n_1 2^1}\times \cdots \times a^{n_t2^t}\]

根据上式我们发现,原问题被我们转化成了形式相同的子问题的乘积,并且我们可以在常数时间内从 \(2^i\) 项推出 \(2^{i+1}\) 项。

这个算法的复杂度是 \(\Theta(\log n)\) 的,我们计算了 \(\theta(\log n)\)\(2^k\) 次幂的数,然后花费 \(\Theta(\log n)\) 的时间选择二进制为 1 对应的幂来相乘。

矩阵

矩阵乘法

【矩阵乘法】

构造中间矩阵

考虑初始矩阵

\[base=\begin{bmatrix} F_{n-1} \\ F_{n-2}\end{bmatrix}\]
\[ans=\begin{bmatrix} F_n \\ F_{n-1}\end{bmatrix}\]

因为初始矩阵以及新矩阵均为2 by 1的矩阵

可推出中间矩阵为2 by 2的矩阵

\[\because F_n=F_{n-1}+F_{n-2}\]
\[\therefore \begin{bmatrix} F_n \\ F_{n-1}\end{bmatrix}=\begin{bmatrix} F_{n-1}+F_{n-2} \\ F_{n-1}\end{bmatrix}=\begin{bmatrix} F_{n-1}\times1+F_{n-2}\times1 \\ F_{n-1}\times1+F_{n-2}\times0\end{bmatrix}\]

最后推出来的矩阵形式便是中间矩阵与初始\(base\)矩阵相乘

将系数进行提取后可得中间矩阵为

\[\begin{bmatrix}1&1 \\1&0\end{bmatrix}\]

\[\begin{bmatrix} F_n \\ F_{n-1}\end{bmatrix}=\begin{bmatrix} F_{n-1} \\ F_{n-2}\end{bmatrix}\times\begin{bmatrix}1&1 \\1&0\end{bmatrix}\]

矩阵快速幂

在推导出中间矩阵后便可从初始的\(F_1,F_2\)推导出\(F_n\)

\[ \begin{bmatrix} F _n \\ F_{n-1}\end{bmatrix}=\begin{bmatrix} F_{2} \\ F_1\end{bmatrix} \times\underset{n-2}{\underbrace{\begin{bmatrix}1&1 \\1&0\end{bmatrix}\times \cdots \times \begin{bmatrix}1&1 \\1&0\end{bmatrix}}}=\begin{bmatrix} F_2 \\ F_1\end{bmatrix} \times\begin{bmatrix}1&1 \\1&0\end{bmatrix}^{n-2}\]

可以看到,我们成功将其表示成了幂的形式

使用快速幂进行加速即可

Code:

# include<cstdio>
# include<cstring>
using namespace std;
long long n,a[3],mul[3][3],res[3][3],tmp[3][3],tp[3];
void mul_1()
{
    memset(tmp,0,sizeof(tmp));
    for(register int i=1;i<=2;i+=1)
        for(register int j=1;j<=2;j+=1)
            for(register int k=1;k<=2;k+=1)
                tmp[i][j]=(tmp[i][j]+res[i][k]*mul[k][j])%1000000007;
    for(register int i=1;i<=2;i+=1)
        for(register int j=1;j<=2;j+=1)
            res[i][j]=tmp[i][j];
}
void mul_2()
{
    memset(tmp,0,sizeof(tmp));
    for(register int i=1;i<=2;i+=1)
        for(register int j=1;j<=2;j+=1)
            for(register int k=1;k<=2;k+=1)
                tmp[i][j]=(tmp[i][j]+mul[i][k]*mul[k][j])%1000000007;
    for(register int i=1;i<=2;i+=1)
        for(register int j=1;j<=2;j+=1)
            mul[i][j]=tmp[i][j];
}
void solve()
{
    for(register int i=1;i<=2;i+=1)
        for(register int j=1;j<=2;j+=1)
            tp[i]=(tp[i]+res[i][j]*a[j])%1000000007;
    printf("%lld\n",tp[1]);
}
int main()
{
    scanf("%lld",&n);
    if(n<=2)printf("1\n");
    else
    {
        a[1]=a[2]=1;
        for(register int i=1;i<=2;i+=1)
            res[i][i]=1;
        for(register int i=1;i<=2;i+=1)
            for(register int j=1;j<=2;j+=1)
                mul[i][j]=1;
        mul[2][2]=0;
        n-=2;
        while(n)
        {
            if(n&1)mul_1();
            n>>=1;
            mul_2();
        }
        solve();
    }
    return 0;
}