Rabin-Karp — Rolling Hash for Pattern Matching and Substring Search

Sanjeev Sharma
2 min read

Advertisement

Rolling Hash Formula

hash(s[l..r]) = s[l]*p^(r-l) + s[l+1]*p^(r-l-1) + ... + s[r]*p^0

Slide right by 1:
hash(s[l+1..r+1]) = (hash(s[l..r]) - s[l]*p^(r-l)) * p + s[r+1]

All arithmetic is performed modulo a large prime so that hash values stay bounded and fixed-width integers do not overflow.


Solutions

Python

def rabin_karp(text, pattern):
    """Return every index at which ``pattern`` occurs in ``text``.

    A polynomial rolling hash (base 31, modulo 1e9+7) is used to discard
    most windows in O(1); a window is only reported after a direct string
    comparison, so hash collisions never produce false matches.

    Args:
        text: The string to search in.
        pattern: The string to search for.

    Returns:
        A list of start indices ``i`` (ascending) such that
        ``text[i:i+len(pattern)] == pattern``. An empty pattern matches at
        every position ``0..len(text)``; a pattern longer than the text
        yields an empty list.
    """
    MOD = 10**9+7
    BASE = 31
    n, m = len(text), len(pattern)
    if m > n:
        return []
    if m == 0:
        # A zero-width window cannot be rolled; every position matches "".
        return list(range(n + 1))

    def char_val(c):
        return ord(c) - ord('a') + 1

    p_hash = t_hash = 0
    # zip stops after m characters because m <= n here.
    for p_ch, t_ch in zip(pattern, text):
        p_hash = (p_hash*BASE + char_val(p_ch)) % MOD
        t_hash = (t_hash*BASE + char_val(t_ch)) % MOD

    # Weight of the leading character of a window: BASE^(m-1) mod MOD.
    power = pow(BASE, m - 1, MOD)

    results = []
    if p_hash == t_hash and text[:m] == pattern:
        results.append(0)

    for i in range(1, n - m + 1):
        # Drop the outgoing character, shift, then append the incoming one.
        # Python's % always yields a non-negative result, so no extra
        # normalisation step is required.
        t_hash = (t_hash - char_val(text[i-1])*power) % MOD
        t_hash = (t_hash*BASE + char_val(text[i+m-1])) % MOD
        if t_hash == p_hash and text[i:i+m] == pattern:
            results.append(i)
    return results

JavaScript

/**
 * Rabin-Karp substring search.
 *
 * Uses a polynomial rolling hash (base 31, modulo 1e9+7) to skip most
 * windows; a window is reported only after a direct string comparison.
 *
 * @param {string} text - String to search in.
 * @param {string} pat  - String to search for.
 * @returns {number[]} Ascending start indices i with
 *   text.slice(i, i + pat.length) === pat. An empty pattern matches at every
 *   position 0..text.length; a pattern longer than the text yields [].
 */
function rabinKarp(text,pat){
    const MOD=1e9+7,BASE=31,n=text.length,m=pat.length;
    const res=[];
    // Guard: without this the pre-hash loop would index past the end of text.
    if(m>n)return res;
    // A zero-width window cannot be rolled; every position matches "".
    if(m===0){for(let i=0;i<=n;i++)res.push(i);return res;}
    const cv=c=>c.charCodeAt(0)-96;
    // JS % keeps the sign of the dividend; fold the result into [0, MOD) so
    // that hashes of equal windows are always equal, even for characters
    // whose value is negative (anything below 'a').
    const norm=x=>((x%MOD)+MOD)%MOD;
    let ph=0,th=0,pw=1;
    for(let i=0;i<m;i++){
        ph=norm(ph*BASE+cv(pat[i]));
        th=norm(th*BASE+cv(text[i]));
        if(i<m-1)pw=pw*BASE%MOD; // pw ends as BASE^(m-1) mod MOD
    }
    if(ph===th&&text.slice(0,m)===pat)res.push(0);
    for(let i=1;i<=n-m;i++){
        // Drop the outgoing character, shift, then append the incoming one.
        // All intermediate products stay far below 2^53, so they are exact.
        th=norm(th-cv(text[i-1])*pw);
        th=norm(th*BASE+cv(text[i+m-1]));
        if(ph===th&&text.slice(i,i+m)===pat)res.push(i);
    }
    return res;
}

Java

import java.util.*;
/**
 * Rabin-Karp substring search.
 *
 * <p>Uses a polynomial rolling hash (base 31, modulo 1e9+7) to skip most
 * windows; a window is reported only after a direct string comparison.
 *
 * @param text string to search in
 * @param pat  string to search for
 * @return ascending start indices i such that text.substring(i, i + pat.length())
 *         equals pat; an empty pattern matches at every position
 *         0..text.length(), and a pattern longer than the text yields an
 *         empty list
 */
public List<Integer> rabinKarp(String text,String pat){
    final long MOD=1_000_000_007L,BASE=31;
    final int n=text.length(),m=pat.length();
    List<Integer>res=new ArrayList<>();
    // Guard: without this the pre-hash loop would call charAt past the end of text.
    if(m>n)return res;
    // A zero-width window cannot be rolled; every position matches "".
    if(m==0){for(int i=0;i<=n;i++)res.add(i);return res;}
    long ph=0,th=0,pw=1;
    for(int i=0;i<m;i++){
        // floorMod keeps hashes in [0, MOD) even when a character value is
        // negative (anything below 'a'); plain % would keep the sign.
        ph=Math.floorMod(ph*BASE+(pat.charAt(i)-'a'+1),MOD);
        th=Math.floorMod(th*BASE+(text.charAt(i)-'a'+1),MOD);
        if(i<m-1)pw=pw*BASE%MOD; // pw ends as BASE^(m-1) mod MOD
    }
    if(ph==th&&text.substring(0,m).equals(pat))res.add(0);
    for(int i=1;i<=n-m;i++){
        // Drop the outgoing character, shift, then append the incoming one.
        th=Math.floorMod(th-(text.charAt(i-1)-'a'+1)*pw,MOD);
        th=Math.floorMod(th*BASE+(text.charAt(i+m-1)-'a'+1),MOD);
        if(ph==th&&text.substring(i,i+m).equals(pat))res.add(i);
    }
    return res;
}

C++

#include <vector>
#include <string>
using namespace std;
// Rabin-Karp substring search.
//
// Uses a polynomial rolling hash (base 31, modulo 1e9+7) to skip most
// windows; a window is reported only after a direct string comparison.
//
// Returns the ascending start indices i with text.substr(i, pat.size()) == pat.
// An empty pattern matches at every position 0..text.size(); a pattern longer
// than the text yields an empty vector.
vector<int> rabinKarp(string text,string pat){
    const long long MOD=1000000007LL,BASE=31;
    const long long n=text.size(),m=pat.size();
    vector<int>res;
    // Guard: without this the pre-hash loop would read past the end of text.
    if(m>n)return res;
    // A zero-width window cannot be rolled; every position matches "".
    if(m==0){for(int i=0;i<=n;i++)res.push_back(i);return res;}
    // C++ % truncates toward zero; fold into [0, MOD) so that hashes of equal
    // windows are always equal, even for characters below 'a'.
    auto norm=[MOD](long long x){return ((x%MOD)+MOD)%MOD;};
    long long pw=1,ph=0,th=0;
    for(int i=0;i<m;i++){
        ph=norm(ph*BASE+(pat[i]-'a'+1));
        th=norm(th*BASE+(text[i]-'a'+1));
        if(i<m-1)pw=pw*BASE%MOD; // pw ends as BASE^(m-1) mod MOD
    }
    if(ph==th&&text.compare(0,m,pat)==0)res.push_back(0);
    for(int i=1;i<=n-m;i++){
        // Drop the outgoing character, shift, then append the incoming one.
        th=norm(th-(long long)(text[i-1]-'a'+1)*pw);
        th=norm(th*BASE+(text[i+m-1]-'a'+1));
        if(ph==th&&text.compare(i,m,pat)==0)res.push_back(i);
    }
    return res;
}

C

/* C: same rolling hash with long long arithmetic */

Application: Repeated DNA Sequences

def findRepeatedDnaSequences(s):
    """Return all 10-letter substrings of ``s`` that occur more than once.

    Each window is encoded with a base-4 rolling hash over the digits
    A=1, C=2, G=3, T=4. The largest possible value for ten such digits
    (1,398,100) is below the modulus and the encoding is positional, so two
    different windows never share a hash and no string verification is needed.

    Args:
        s: DNA string made of the characters ``A``, ``C``, ``G``, ``T``.

    Returns:
        A list of the distinct repeated 10-letter substrings, in arbitrary
        order.

    Raises:
        KeyError: If ``s`` contains a character other than A, C, G or T.
    """
    MOD = 10**9+7
    BASE = 4
    WINDOW = 10
    mapping = {'A':1,'C':2,'G':3,'T':4}
    # Weight of the leading character inside a full window: BASE^(WINDOW-1).
    top = pow(BASE, WINDOW - 1, MOD)
    seen = set()
    res = set()
    h = 0
    for i, ch in enumerate(s):
        if i >= WINDOW:
            # Remove the outgoing character BEFORE shifting, while it still
            # carries weight BASE^(WINDOW-1); removing it after the shift with
            # that same weight would leave residue in the hash.
            h = (h - mapping[s[i-WINDOW]]*top) % MOD
        h = (h*BASE + mapping[ch]) % MOD
        if i >= WINDOW - 1:
            if h in seen:
                res.add(s[i-WINDOW+1:i+1])
            else:
                seen.add(h)
    return list(res)

Advertisement

Sanjeev Sharma

Written by

Sanjeev Sharma

Full Stack Engineer · E-mopro