Matrix Exponentiation: Fibonacci and Linear Recurrences

Sanjeev Sharma
3 min read

Advertisement

Matrix Exponentiation

Transform a linear recurrence of order k into a k×k matrix form and use binary exponentiation (repeated squaring) to compute the n-th term in O(k^3 log n) time.

Fibonacci in O(log n)

Standard Fibonacci: F(n) = F(n-1) + F(n-2), with F(0) = 0 and F(1) = 1.

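Matrix form:

$$
\begin{pmatrix} F(n+1) \\ F(n) \end{pmatrix}
= \begin{pmatrix} 1 & 1 \\ 1 & 0 \end{pmatrix}^{n}
\begin{pmatrix} F(1) \\ F(0) \end{pmatrix},
\qquad
\begin{pmatrix} 1 & 1 \\ 1 & 0 \end{pmatrix}^{n}
= \begin{pmatrix} F(n+1) & F(n) \\ F(n) & F(n-1) \end{pmatrix} \quad (n \ge 1).
$$

Hence the top-left entry of the matrix raised to the power n-1 is F(n), which is the entry the implementations below return.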

Python Implementation

# Modulus applied to every matrix entry so numbers stay bounded (10**9 + 7).
MOD = 1000000007

def mat_mul(A, B, mod=MOD):
    """Multiply two n x n matrices (n = len(A)), reducing every entry modulo `mod`.

    Args:
        A, B: square matrices given as lists of rows.
        mod: modulus applied after every accumulation step.

    Returns:
        A new n x n list-of-lists holding (A @ B) % mod.

    Zero entries of A are skipped, which saves work for sparse transition matrices.
    """
    size = len(A)
    product = [[0] * size for _ in range(size)]
    for row_idx in range(size):
        a_row = A[row_idx]
        out_row = product[row_idx]
        for mid in range(size):
            a_val = a_row[mid]
            if a_val == 0:
                continue
            b_row = B[mid]
            for col in range(size):
                out_row[col] = (out_row[col] + a_val * b_row[col]) % mod
    return product

def mat_pow(M, p, mod=MOD):
    """Raise the square matrix M to the power p, entries reduced modulo `mod`.

    Uses square-and-multiply, so only O(log p) calls to mat_mul are made.
    For p <= 0 the identity matrix is returned. M itself is never modified;
    the local name is simply rebound to successive squares.
    """
    size = len(M)
    acc = [[int(r == c) for c in range(size)] for r in range(size)]
    base, exp = M, p
    while exp > 0:
        if exp & 1:
            acc = mat_mul(acc, base, mod)
        base = mat_mul(base, base, mod)
        exp >>= 1
    return acc

def fibonacci(n, mod=MOD):
    """Return F(n) modulo `mod` using O(log n) 2x2 matrix products.

    Inputs n <= 1 are returned unchanged (F(0) = 0, F(1) = 1).
    """
    if n > 1:
        # [[1, 1], [1, 0]] ** (n - 1) holds F(n) in its top-left corner.
        return mat_pow([[1, 1], [1, 0]], n - 1, mod)[0][0]
    return n

# Tribonacci: T(n) = T(n-1) + T(n-2) + T(n-3), with T(0) = 0, T(1) = T(2) = 1
def tribonacci(n, mod=MOD):
    """Return the n-th Tribonacci number modulo `mod` in O(log n) 3x3 matrix products.

    The companion matrix M = [[1,1,1],[1,0,0],[0,1,0]] advances the state
    column [T(k+2), T(k+1), T(k)] by one step, so

        [T(n), T(n-1), T(n-2)] = M^(n-2) . [T(2), T(1), T(0)] = M^(n-2) . [1, 1, 0].

    T(n) is therefore the first row of R = M^(n-2) dotted with [1, 1, 0],
    i.e. R[0][0] + R[0][1]. The top-left entry alone is not enough:
    for n = 3 it gives 1, while T(3) = 2.

    Args:
        n: index of the term; expected to be >= 0.
        mod: modulus applied to all intermediate values.
    """
    if n == 0: return 0
    if n <= 2: return 1
    M = [[1, 1, 1], [1, 0, 0], [0, 1, 0]]
    R = mat_pow(M, n - 2, mod)
    # Initial state is [T(2), T(1), T(0)] = [1, 1, 0]; R[0][2] multiplies T(0) = 0.
    return (R[0][0] + R[0][1]) % mod

# Demo: prints F(10) = 55, then F(10**18) mod MOD. The second call needs only
# about 60 squarings because the exponent is halved on every step.
print(*map(fibonacci, (10, 10**18)), sep="\n")

C++ Implementation

#include <bits/stdc++.h>
using namespace std;

// Modulus applied to every matrix entry (10^9 + 7).
constexpr long long MOD = 1000000007LL;
// Matrix of residues modulo MOD, indexed as M[row][col].
using Matrix = vector<vector<long long>>;

// Returns the product A * B with every entry reduced modulo MOD.
//
// A is r x m and B is m x c; the result is r x c. The k x k case used by
// matpow() is the common one, but rectangular shapes are accepted as well,
// e.g. to apply a transition matrix to a column state vector. The inner
// dimension is taken from B, and each row of A must have exactly B.size()
// entries (checked with assert in debug builds).
//
// Entries are expected to lie in [0, MOD). Each product A[i][k] * B[k][j] is
// then below MOD^2 (~1e18), and adding a reduced running sum still fits in a
// long long.
Matrix multiply(const Matrix& A, const Matrix& B) {
    const size_t rows = A.size();
    const size_t inner = B.size();
    const size_t cols = B.empty() ? 0 : B[0].size();
    Matrix C(rows, vector<long long>(cols, 0));
    for (size_t i = 0; i < rows; i++) {
        assert(A[i].size() == inner && "A's column count must equal B's row count");
        for (size_t k = 0; k < inner; k++) {
            const long long a = A[i][k];
            if (!a) continue;  // skip zero terms; transition matrices are often sparse
            for (size_t j = 0; j < cols; j++)
                C[i][j] = (C[i][j] + a * B[k][j]) % MOD;
        }
    }
    return C;
}

// Raises the square matrix M to the power p (entries mod MOD) by binary
// exponentiation: about log2(p) squarings, each an O(k^3) multiply for a
// k x k matrix. For p <= 0 the identity matrix is returned.
Matrix matpow(Matrix M, long long p) {
    const int n = M.size();
    Matrix result(n, vector<long long>(n, 0));
    for (int d = 0; d < n; ++d) result[d][d] = 1;  // start from the identity
    for (; p > 0; p >>= 1) {
        if (p & 1) result = multiply(result, M);
        M = multiply(M, M);
    }
    return result;
}

long long fibonacci(long long n) {
    if (n <= 1) return n;
    Matrix M = {{1,1},{1,0}};
    return matpow(M, n-1)[0][0];
}

Java Implementation

/**
 * Matrix exponentiation modulo {@link #MOD}, applied to Fibonacci numbers.
 */
public class MatrixExp {
    static final long MOD = 1_000_000_007L;

    /**
     * Multiplies two n x n matrices (n = A.length), reducing every entry modulo MOD.
     * Zero entries of A are skipped.
     *
     * @param A left operand
     * @param B right operand
     * @return a new n x n matrix holding (A * B) mod MOD
     */
    static long[][] multiply(long[][] A, long[][] B) {
        final int size = A.length;
        final long[][] out = new long[size][size];
        for (int r = 0; r < size; r++) {
            final long[] aRow = A[r];
            final long[] outRow = out[r];
            for (int m = 0; m < size; m++) {
                final long coef = aRow[m];
                if (coef == 0) continue;
                final long[] bRow = B[m];
                for (int c = 0; c < size; c++) {
                    outRow[c] = (outRow[c] + coef * bRow[c]) % MOD;
                }
            }
        }
        return out;
    }

    /**
     * Raises the square matrix M to the power p by square-and-multiply.
     * Returns the identity matrix when p <= 0.
     *
     * @param M base matrix (not modified)
     * @param p exponent
     * @return M^p with entries reduced modulo MOD
     */
    static long[][] matpow(long[][] M, long p) {
        final int size = M.length;
        long[][] acc = new long[size][size];
        for (int d = 0; d < size; d++) acc[d][d] = 1;
        long[][] base = M;
        for (long e = p; e > 0; e >>= 1) {
            if ((e & 1) == 1) acc = multiply(acc, base);
            base = multiply(base, base);
        }
        return acc;
    }

    /**
     * Returns F(n) mod MOD using [[1,1],[1,0]]^(n-1), whose top-left entry is F(n).
     * Inputs n <= 1 are returned unchanged.
     *
     * @param n index of the Fibonacci number
     * @return F(n) modulo MOD, or n itself when n <= 1
     */
    static long fibonacci(long n) {
        if (n > 1) {
            return matpow(new long[][] {{1, 1}, {1, 0}}, n - 1)[0][0];
        }
        return n;
    }
}

LeetCode Problems

  • 509. Fibonacci Number — warm-up
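  • 1137. N-th Tribonacci Number — 3×3 matrix exponentiation
  • 70. Climbing Stairs — the Fibonacci sequence shifted by one: ways(n) = F(n+1)
  • 935. Knight Dialer — raise the 10×10 knight-move adjacency matrix over the keypad digits to the power n − 1 to count dialable numbers of length n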

Complexity

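  • k×k matrix: O(k^3 log n) per query — about log2 n squarings plus at most as many extra products, each costing O(k^3)
  • Fibonacci: each 2×2 product takes 2^3 = 8 multiplications, so O(8 log n) = O(log n)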

Advertisement

Sanjeev Sharma

Written by

Sanjeev Sharma

Full Stack Engineer · E-mopro