Matrix Exponentiation: Fibonacci and Linear Recurrences
Advertisement
Matrix Exponentiation
Rewrite a linear recurrence of order k as a k×k matrix acting on the vector of its last k terms, then use binary exponentiation on that matrix to compute the n-th term in O(k^3 log n) time.
Fibonacci in O(log n)
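Standard Fibonacci: $F(n) = F(n-1) + F(n-2)$, with $F(0) = 0$ and $F(1) = 1$.
Matrix form:
$$\begin{pmatrix} F(n+1) \\ F(n) \end{pmatrix} = \begin{pmatrix} 1 & 1 \\ 1 & 0 \end{pmatrix}^{n} \begin{pmatrix} F(1) \\ F(0) \end{pmatrix}, \qquad \begin{pmatrix} 1 & 1 \\ 1 & 0 \end{pmatrix}^{n} = \begin{pmatrix} F(n+1) & F(n) \\ F(n) & F(n-1) \end{pmatrix} \quad (n \ge 1).$$
So $F(n)$ is the top-left entry of the $(n-1)$-th power of the matrix, which is the entry the implementations below read.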
Python Implementation
# Prime modulus shared by every matrix routine in this listing.
MOD = 10**9 + 7


def mat_mul(A, B, mod=MOD):
    """Return the product A * B of two square matrices, reduced modulo ``mod``.

    The size n is taken from ``len(A)``; both inputs are treated as n x n
    lists of lists. A fresh matrix is returned and neither input is modified.
    """
    size = len(A)
    product = [[0] * size for _ in range(size)]
    for row in range(size):
        a_row = A[row]
        out_row = product[row]
        for mid in range(size):
            coeff = a_row[mid]
            if not coeff:
                # A zero coefficient contributes nothing to this output row.
                continue
            b_row = B[mid]
            for col in range(size):
                out_row[col] = (out_row[col] + coeff * b_row[col]) % mod
    return product
def mat_pow(M, p, mod=MOD):
    """Raise the square matrix ``M`` to the power ``p``, reducing modulo ``mod``.

    Uses square-and-multiply, so only O(log p) calls to ``mat_mul`` are made.
    When ``p <= 0`` the loop never runs and the identity matrix is returned.
    The caller's ``M`` is not modified, because ``mat_mul`` builds a new matrix.
    """
    size = len(M)
    acc = [[int(r == c) for c in range(size)] for r in range(size)]
    base, exp = M, p
    while exp > 0:
        if exp & 1:
            # This bit of the exponent is set: fold the current power in.
            acc = mat_mul(acc, base, mod)
        base = mat_mul(base, base, mod)
        exp >>= 1
    return acc
def fibonacci(n, mod=MOD):
    """Return the n-th Fibonacci number modulo ``mod`` (F(0) = 0, F(1) = 1).

    For n >= 2 the answer is the top-left entry of [[1, 1], [1, 0]]^(n-1),
    because that power equals [[F(n), F(n-1)], [F(n-1), F(n-2)]].
    Inputs n <= 1 are returned unchanged, without reduction.
    """
    if n > 1:
        step = [[1, 1], [1, 0]]
        return mat_pow(step, n - 1, mod)[0][0]
    return n
# Tribonacci: T(n) = T(n-1) + T(n-2) + T(n-3), with T(0) = 0, T(1) = T(2) = 1.
def tribonacci(n, mod=MOD):
    """Return the n-th Tribonacci number T(n) modulo ``mod``.

    The companion matrix M maps the state [T(k), T(k-1), T(k-2)] to
    [T(k+1), T(k), T(k-1)]. Applying M^(n-2) to the initial state
    [T(2), T(1), T(0)] = [1, 1, 0] therefore yields T(n) as the dot product
    of the first row of M^(n-2) with [1, 1, 0], which is R[0][0] + R[0][1].
    Inputs n <= 2 (other than 0) return 1, matching the base cases.
    """
    if n == 0:
        return 0
    if n <= 2:
        return 1
    M = [[1, 1, 1], [1, 0, 0], [0, 1, 0]]
    R = mat_pow(M, n - 2, mod)
    # Dot the first row with the initial state [1, 1, 0]; the T(0) = 0 term drops out.
    return (R[0][0] + R[0][1]) % mod
# Demo: F(10) = 55, then F(10**18) mod MOD, which needs only ~60 squarings (O(log n)).
for demo_n in (10, 10**18):
    print(fibonacci(demo_n))
C++ Implementation
#include <bits/stdc++.h>
using namespace std;
// Prime modulus applied to every entry of a matrix product.
const long long MOD = 1000000007LL;
// Dense square matrix stored row-major as nested vectors.
using Matrix = std::vector<std::vector<long long>>;
// Returns A * B with every entry reduced modulo MOD.
// Both matrices are treated as square with size A.size(); a new matrix is returned.
// Assuming entries are already below MOD, each product stays under 2^63.
Matrix multiply(const Matrix& A, const Matrix& B) {
    const int n = A.size();
    Matrix C(n, std::vector<long long>(n, 0));
    for (int row = 0; row < n; ++row) {
        const auto& aRow = A[row];
        auto& cRow = C[row];
        for (int mid = 0; mid < n; ++mid) {
            const long long coeff = aRow[mid];
            if (coeff == 0) continue;  // a zero term adds nothing to this row
            const auto& bRow = B[mid];
            for (int col = 0; col < n; ++col)
                cRow[col] = (cRow[col] + coeff * bRow[col]) % MOD;
        }
    }
    return C;
}
// Returns M^p (entries mod MOD) by binary exponentiation; p <= 0 yields the identity.
// M is taken by value, so squaring it never touches the caller's matrix.
Matrix matpow(Matrix M, long long p) {
    const int n = M.size();
    Matrix R(n, std::vector<long long>(n, 0));
    for (int d = 0; d < n; ++d) R[d][d] = 1;  // start from the identity
    for (; p > 0; p >>= 1) {
        if (p & 1) R = multiply(R, M);
        M = multiply(M, M);
    }
    return R;
}
long long fibonacci(long long n) {
if (n <= 1) return n;
Matrix M = {{1,1},{1,0}};
return matpow(M, n-1)[0][0];
}
Java Implementation
/**
 * Modular matrix exponentiation helpers and an O(log n) Fibonacci built on them.
 */
public class MatrixExp {
    static final long MOD = 1_000_000_007L;

    /**
     * Returns A * B with every entry reduced modulo MOD.
     * Both matrices are treated as square with size A.length; a new matrix is returned.
     */
    static long[][] multiply(long[][] A, long[][] B) {
        final int n = A.length;
        final long[][] C = new long[n][n];
        for (int row = 0; row < n; row++) {
            final long[] cRow = C[row];
            for (int mid = 0; mid < n; mid++) {
                final long coeff = A[row][mid];
                if (coeff == 0) {
                    continue; // a zero term adds nothing to this row
                }
                final long[] bRow = B[mid];
                for (int col = 0; col < n; col++) {
                    cRow[col] = (cRow[col] + coeff * bRow[col]) % MOD;
                }
            }
        }
        return C;
    }

    /**
     * Returns M^p (entries mod MOD) by binary exponentiation; p <= 0 yields the identity.
     */
    static long[][] matpow(long[][] M, long p) {
        final int n = M.length;
        long[][] R = new long[n][n];
        for (int d = 0; d < n; d++) {
            R[d][d] = 1;
        }
        for (long e = p; e > 0; e >>= 1) {
            if ((e & 1L) != 0) {
                R = multiply(R, M);
            }
            M = multiply(M, M);
        }
        return R;
    }

    /**
     * Returns F(n) mod MOD with F(0) = 0 and F(1) = 1; inputs n <= 1 are returned as-is.
     * For n >= 2 this is the top-left entry of [[1,1],[1,0]]^(n-1).
     */
    static long fibonacci(long n) {
        if (n > 1) {
            long[][] step = {{1, 1}, {1, 0}};
            return matpow(step, n - 1)[0][0];
        }
        return n;
    }
}
LeetCode Problems
- 509. Fibonacci Number — warm-up
- 1137. N-th Tribonacci Number — matrix exp
- 70. Climbing Stairs — same recurrence as Fibonacci (the number of ways to climb n steps is F(n+1))
- 935. Knight Dialer — raise the 10×10 knight-move transition matrix of the keypad to the (n−1)-th power, then sum its entries
Complexity
- k x k matrix: O(k^3 log n) per query
- Fibonacci: each 2×2 product costs 8 scalar multiplications, and binary exponentiation needs about 2·log2(n) products, so the total is O(log n)
Advertisement