「LeetCode」快速幂

2020-04-23
摘要: 本文讨论了快速幂及其简单应用。

快速幂的定义#

快速幂,二进制取幂(Binary Exponentiation,也称平方法),是一个在 $\Theta(\log(n))$ 的时间内计算 $a^n$ 的小技巧,而暴力的计算需要 $\Theta(n)$ 的时间。而这个技巧也常常用在非计算的场景,因为它可以应用在任何具有结合律的运算中。其中显然的是它可以应用于模意义下取幂矩阵幂等运算,我们接下来会讨论。

算法#

计算 $a$ 的 $n$ 次方表示将 $n$ 个 $a$ 乘在一起:$a^{n} =a \times a \cdots \times a$ 。然而当 $a,n$ 太大的时侯,这种方法就不太适用了。不过我们知道: $a^{b+c} = a^b \cdot a^c,a^{2b} = a^b \cdot a^b = (a^b)^2$ 。二进制取幂的想法是,我们将取幂的任务按照指数的 二进制表示 来分割成更小的任务。

首先我们将 $n$ 表示为 2 进制,举一个例子:

$$ 3^{13} = 3^{(1101)_2} = 3^8 \cdot 3^4 \cdot 3^1 $$

因为 $n$ 有 $\lfloor \log_2 n \rfloor + 1$ 个二进制位,因此当我们知道了 $a^1, a^2, a^4, a^8, \dots, a^{2^{\lfloor \log_2 n \rfloor}}$ 后,我们只用计算 $\Theta(\log n)$ 次乘法就可以计算出 $a^n$ 。

于是我们只需要知道一个快速的方法来计算上述 3 的 $2^k$ 次幂的序列。这个问题很简单,因为序列中(除第一个)任意一个元素就是其前一个元素的平方。举一个例子:

$$ \begin{align} 3^1 &= 3 \\ 3^2 &= \left(3^1\right)^2 = 3^2 = 9 \\ 3^4 &= \left(3^2\right)^2 = 9^2 = 81 \\ 3^8 &= \left(3^4\right)^2 = 81^2 = 6561 \end{align} $$

因此为了计算 $3^{13}$ ,我们只需要将对应二进制位为 1 的整系数幂乘起来就行了:

$$ 3^{13} = 6561 \cdot 81 \cdot 3 = 1594323 $$

将上述过程说得形式化一些,如果把 $n$ 写作二进制为 $(n_tn_{t-1}\cdots n_1n_0)_2$ ,那么有:

$$ n = n_t2^t + n_{t-1}2^{t-1} + n_{t-2}2^{t-2} + \cdots + n_12^1 + n_02^0 $$

其中 $n_i\in{0,1}$ 。那么就有

$$ \begin{aligned} a^n & = (a^{n_t 2^t + \cdots + n_0 2^0})\\ & = a^{n_0 2^0} \times a^{n_1 2^1}\times \cdots \times a^{n_t2^t} \end{aligned} $$

根据上式我们发现,原问题被我们转化成了形式相同的子问题的乘积,并且我们可以在常数时间内从 $2^i$ 项推出 $2^{i+1}$ 项。

这个算法的复杂度是 $\Theta(\log n)$ 的,我们计算了 $\Theta(\log n)$ 个 $2^k$ 次幂的数,然后花费 $\Theta(\log n)$ 的时间选择二进制为 1 对应的幂来相乘。

示例代码#

一般函数名取为 bin_pow 或者quick_pow。 前者意为二进制取幂,后者意为快速幂。

第一种实现方法是递归形式的。

第二种实现方法是非递归式的。它在循环的过程中将二进制位为 $1$ 时对应的幂累乘到答案中。尽管两者的理论复杂度是相同的,但第二种在实践过程中的速度是比第一种更快的,因为递归会花费一定的开销。

第三个函数是用于计算 $x^n\bmod m$ 时使用的。

long long quick_pow_res(long long a, long long n) {
    if (n == 0) return 1;
    long long res = quick_pow_res(a, n / 2);
    if (n % 2)
        return res * res * a;
    else
        return res * res;
}

long long quick_pow(long long a, long long n) {
    long long res = 1;
    while(n > 0) {
        if (n & 1) res = res * a;
        a = a * a;
        n >>= 1;
    }

    return res;
}

long long quick_pow_mod(long long a, long long n, long long m) {
    a %= m;
    long long res = 1;
    while (n > 0) {
        if (n & 1) res = res * a % m;
        a = a * a % m;
        n >>= 1;
    }
  
    // if n == 0, m == 1, we should return 1 % 1 = 0.
  return res % m;
}

应用#

矩阵幂#

除了用于数字的直接计算,快速幂可以应用在任何支持结合律的运算中。我们考虑斐波那契数列的计算:

$$f(n)=f(n-1)+f(n-2)$$

如果写成矩阵,那就是:

$$[f(n-2),f(n-1)]\times\left[ \begin{matrix} 0 & 1 \\ 1 & 1 \end{matrix} \right]=[f(n-1),f(n-2)+f(n-1)]=[f(n-1),f(n)]$$

如果把矩阵 $\left[ \begin{matrix} 0 & 1 \\ 1 & 1 \end{matrix} \right]$ 记作 $A$,则有:

$$[f(1),f(2)]\times A^{n-2}=[f(n-1),f(n)]$$

所以,我们可以对中间 $A^{n-2}$ 的运算运用快速幂。我们先给出矩阵运算(加法、乘法)的写法:

typedef long long ll;
typedef vector<ll> vl;
typedef vector<vector<ll>> vvl;


vvl MatrixPlus(vvl& A, vvl& B) {
  int size = A.size();
  vvl res = vvl(size, vl(size, 0));

  for (int i = 0; i < size; ++i)
    for (int j = 0; j < size; ++j)
      // res[i][j] = (A[i][j] + B[i][j]) % p;
      res[i][j] = A[i][j] + B[i][j];
  return res;
}

vvl MatrixMultiply(vvl& A, vvl& B) {
  int a = A.size(), b = B.size(), c = B[0].size();
  vvl C(a, vl(c, 0));
  for(int i = 0; i < a; i++) {
      for(int j = 0; j < c; j++) {
          for(int k = 0; k < b; k++) {
            // (C[i][j] += A[i][k] * B[k][j]) %= p;
            C[i][j] += A[i][k] * B[k][j];
          }
      }
  }

  return C;
}

再给出快速矩阵幂(注意,我们需要先构造一个单位矩阵):

vvl MatrixPow(vvl& m, int n) {
  int size = m.size();
  vvl res = vvl(size, vl(size, 0));
  for(int i = 0; i < size; i++) res[i][i] = 1; // 构造单位矩阵

  while(n > 0) {
      if(n & 1) {
          res = MatrixMultiply(res, m);
      }
      m = MatrixMultiply(m, m);
      n >>= 1;
  }

  return res;
}

例题#

Luogu P1226 快速幂||取余运算#

Luogu P1226

#include <iostream>
#include <stdio.h>

using namespace std;

long long quick_pow_mod(long long a, long long n, long long m) {
    a %= m;
    long long res = 1;
    while(n != 0) {
        if(n & 1) res = res * a % m;
        a = a * a % m;
        n >>= 1;
    }
    
    // if n == 0, m == 1, we should return 1 % 1 = 0.
    return res % m;
}

int main() {
    ios::sync_with_stdio(false);
    
    long long b, p, q;
    cin >> b >> p >> q;
    printf("%lld^%lld mod %lld=%lld", b, p, q, quick_pow_mod(b, p, q));
    return 0;
}

LeetCode 50. Pow(x, n)#

LeetCode 50. Pow(x, n)

注意我们需要用 long long 存储 n,因为 -n 会溢出。

class Solution {
public:
    double quick_pow(double a, long long n) {
        double res = 1;
        while(n != 0) {
            if(n & 1) res = res * a;
            a = a * a;
            n >>= 1;
        }

        return res;
    }
  
    double myPow(double x, int n) {
        long long N = n;
        if(n < 0) {
            N = -N;
            x = 1/x;
        }
        return quick_pow(x, N);
    }
};

LeetCode 372. 超级次方#

LeetCode 372.超级次方

这题只是把指数以数组形式读入(因为它可能很大)。我们只需要做一个简单的分解就可以找到思路:

我们只需要从数组的最后向前遍历,同时记录当前的权重 $\mathrm{weight}=8^1, 8^{10}, 8^{100}$(这个过程可以用快速幂),再计算一次 $\mathrm{weight}^{\mathrm{index[}i]}$ 即可。由于所有操作都是乘法,因此每次操作都对 $m$ 取模即可。

class Solution {
public:
    int quick_pow(int a, int n, int m) {
        a %= m;
        int res = 1;
        while(n != 0) {
            if(n & 1) res = res * a % m;
            a = a * a % m;
            n >>= 1;
        }

        return res % m;
    }

    int superPow(int a, vector<int>& b) {
        int m = 1337, n = b.size(), res = 1;
        for(int i = n-1; i >=0; --i) {
            res = res * quick_pow(a, b[i], m) % m;
            a = quick_pow(a, 10, m);
        }

        return res;
    }
};

LeetCode 面试题08.01 三步问题#

面试题 08.01. 三步问题

如果只是线性的动态规划,我们需要 $O(n)$ 的时间,但是用矩阵幂就是 $O(\log(n))$ 的时间。

递推方程:

$$f(n)=f(n-3)+f(n-2)+f(n-1)$$

于是:

$$[f(n-3),f(n-2),f(n-1)]\times\left[ \begin{matrix} 0 & 0 & 1 \\ 1 & 0 & 1 \\ 0 & 1 & 1\end{matrix}\right]=[f(n-2),f(n-1),f(n)]$$

同样地,我们把矩阵记作 $A$,于是有:

$$[f(1),f(2),f(3)]\times A^{n-3}=[f(n-2),f(n-1),f(n)]$$

typedef long long ll;
typedef vector<ll> vl;
typedef vector<vector<ll>> vvl;
const ll p = 1e9 + 7;

class Solution {
public:
    vvl MatrixMultiply(vvl& A, vvl& B) {
        int a = A.size(), b = A[0].size(), c = B[0].size();
        vvl C(a, vl(c, 0));
        for(int i = 0; i < a; i++) {
            for(int j = 0; j < c; j++) {
                for(int k = 0; k < b; k++) {
                    (C[i][j] += A[i][k] * B[k][j]) %= p;
                }
            }
        }

        return C;
    }

    vvl MatrixPow(vvl& m, int n) {
        int size = m.size();
        vvl res = vvl(size, vl(size, 0));
        for(int i = 0; i < size; i++) res[i][i] = 1; // 构造单位矩阵

        while(n > 0) {
            if(n & 1) {
                res = MatrixMultiply(res, m);
            }
            m = MatrixMultiply(m, m);
            n >>= 1;
        }

        return res;
    }

    int waysToStep(int n) {
        vl f = {1, 2, 4};
        if(n <= 3) return f[n-1];
        vvl factor = {{0, 0, 1}, {1, 0, 1}, {0, 1, 1}};
        vvl res = MatrixPow(factor, n-3);
        long long ans = 0;
        for(int i = 0; i < 3; i++) {
            (ans += res[i][2] * f[i]) %= p;
        }

        return ans;
    }
};