Segment Trees & BIT — Master Recap & Cheatsheet

Sanjeev Sharma
3 min read


Segment Trees & BIT Master Cheatsheet


BIT (Fenwick Tree) Template

class BIT:
    def __init__(self, n):
        self.n = n
        self.tree = [0] * (n + 1)

    def update(self, i, delta):
        # i is 1-indexed
        while i <= self.n:
            self.tree[i] += delta
            i += i & (-i)

    def query(self, i):
        # prefix sum [1..i]
        s = 0
        while i > 0:
            s += self.tree[i]
            i -= i & (-i)
        return s

    def range_query(self, l, r):
        # sum of [l..r], both ends 1-indexed and inclusive
        return self.query(r) - self.query(l - 1)
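The BIT template starts from all zeros, so loading an existing array with n calls to `update` costs O(n log n). A common O(n) alternative walks the nodes in increasing order and pushes each node's finished sum to the next node that covers it, `i + (i & -i)`. A sketch, with hypothetical helper names `build_bit` and `prefix_sum` (not part of the template above):

```python
def build_bit(arr):
    """Build a 1-indexed Fenwick tree from a 0-indexed list in O(n)."""
    n = len(arr)
    tree = [0] * (n + 1)
    for i in range(1, n + 1):
        tree[i] += arr[i - 1]        # children (all < i) were already added, so tree[i] is now complete
        parent = i + (i & -i)        # next node whose range covers node i
        if parent <= n:
            tree[parent] += tree[i]  # push this node's finished sum upward
    return tree


def prefix_sum(tree, i):
    """Sum of the first i elements (positions 1..i), same loop as BIT.query."""
    s = 0
    while i > 0:
        s += tree[i]
        i -= i & -i
    return s


tree = build_bit([3, 2, -1, 6, 5, 4, -3, 3])
print(prefix_sum(tree, 4))  # 3 + 2 - 1 + 6 = 10
```

The resulting array is identical to what n successive `update` calls would produce, so it can be dropped into `BIT.tree` directly.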

Segment Tree Template

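class SegTree:
    # Sum segment tree over positions 0..n-1.
    # Root is node 1; node k has children 2k and 2k+1.
    def __init__(self, n):
        self.tree = [0] * (4 * n)  # 4n slots always suffice for this layout
        self.n = n

    def update(self, node, start, end, idx, val):
        # Point assignment arr[idx] = val; call as update(1, 0, n - 1, idx, val)
        if start == end:
            self.tree[node] = val
            return
        mid = (start + end) // 2
        if idx <= mid:
            self.update(2*node, start, mid, idx, val)
        else:
            self.update(2*node+1, mid+1, end, idx, val)
        self.tree[node] = self.tree[2*node] + self.tree[2*node+1]

    def query(self, node, start, end, l, r):
        # Sum of arr[l..r] inclusive; call as query(1, 0, n - 1, l, r)
        if r < start or end < l:
            return 0  # no overlap
        if l <= start and end <= r:
            return self.tree[node]  # segment lies fully inside [l, r]
        mid = (start + end) // 2
        return (self.query(2*node, start, mid, l, r) +
                self.query(2*node+1, mid+1, end, l, r))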
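The SegTree template stores sums and starts from a zero array. For the min/max queries mentioned later, the same node layout works if the merge becomes `min` and the no-overlap case returns a neutral element (`math.inf`) instead of 0. A sketch, assuming 0-indexed positions and a hypothetical `MinSegTree` class that builds from an initial array in O(n) and hides the `node, start, end` arguments behind small wrappers:

```python
import math


class MinSegTree:
    """Range-minimum segment tree over positions 0..n-1 (root node 1, children 2k, 2k+1)."""

    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [math.inf] * (4 * max(self.n, 1))
        if self.n:
            self._build(arr, 1, 0, self.n - 1)

    def _build(self, arr, node, start, end):
        # Fill leaves from arr, then each parent with the min of its two children.
        if start == end:
            self.tree[node] = arr[start]
            return
        mid = (start + end) // 2
        self._build(arr, 2 * node, start, mid)
        self._build(arr, 2 * node + 1, mid + 1, end)
        self.tree[node] = min(self.tree[2 * node], self.tree[2 * node + 1])

    def update(self, idx, val):
        # Point assignment arr[idx] = val.
        self._update(1, 0, self.n - 1, idx, val)

    def _update(self, node, start, end, idx, val):
        if start == end:
            self.tree[node] = val
            return
        mid = (start + end) // 2
        if idx <= mid:
            self._update(2 * node, start, mid, idx, val)
        else:
            self._update(2 * node + 1, mid + 1, end, idx, val)
        self.tree[node] = min(self.tree[2 * node], self.tree[2 * node + 1])

    def query(self, l, r):
        # Minimum of arr[l..r], inclusive.
        return self._query(1, 0, self.n - 1, l, r)

    def _query(self, node, start, end, l, r):
        if r < start or end < l:
            return math.inf  # neutral element for min (the sum tree returns 0)
        if l <= start and end <= r:
            return self.tree[node]
        mid = (start + end) // 2
        return min(self._query(2 * node, start, mid, l, r),
                   self._query(2 * node + 1, mid + 1, end, l, r))


st = MinSegTree([5, 2, 8, 6, 1, 9])
print(st.query(0, 3))  # 2
st.update(1, 7)        # array is now [5, 7, 8, 6, 1, 9]
print(st.query(0, 3))  # 5
```

Swapping `min` for `max` and `math.inf` for `-math.inf` gives the range-maximum version.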

2D BIT Template

class BIT2D:
    def __init__(self, m, n):
        self.m, self.n = m, n
        self.tree = [[0]*(n+1) for _ in range(m+1)]

    def update(self, r, c, delta):
        # add delta at cell (r, c); r and c are 1-indexed
        i = r
        while i <= self.m:
            j = c
            while j <= self.n:
                self.tree[i][j] += delta
                j += j & (-j)
            i += i & (-i)

    def query(self, r, c):
        # sum over the rectangle [1..r] x [1..c]
        s = 0
        i = r
        while i > 0:
            j = c
            while j > 0:
                s += self.tree[i][j]
                j -= j & (-j)
            i -= i & (-i)
        return s
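`BIT2D.query` only returns sums anchored at (1, 1). An arbitrary rectangle [r1..r2] x [c1..c2] follows from four prefix queries by inclusion-exclusion, mirroring `range_query` in the 1D template. A sketch; `rect_sum` is a hypothetical method name, and the class body repeats the template so the block runs on its own:

```python
class BIT2D:
    """1-indexed 2D Fenwick tree, as in the template, plus a rectangle query."""

    def __init__(self, m, n):
        self.m, self.n = m, n
        self.tree = [[0] * (n + 1) for _ in range(m + 1)]

    def update(self, r, c, delta):
        i = r
        while i <= self.m:
            j = c
            while j <= self.n:
                self.tree[i][j] += delta
                j += j & -j
            i += i & -i

    def query(self, r, c):
        s, i = 0, r
        while i > 0:
            j = c
            while j > 0:
                s += self.tree[i][j]
                j -= j & -j
            i -= i & -i
        return s

    def rect_sum(self, r1, c1, r2, c2):
        # Whole prefix, minus the strips above and to the left, plus the corner removed twice.
        return (self.query(r2, c2) - self.query(r1 - 1, c2)
                - self.query(r2, c1 - 1) + self.query(r1 - 1, c1 - 1))


grid = [[3, 0, 1],
        [5, 6, 3],
        [1, 2, 0]]
bit = BIT2D(3, 3)
for r, row in enumerate(grid, start=1):
    for c, v in enumerate(row, start=1):
        bit.update(r, c, v)
print(bit.rect_sum(2, 2, 3, 3))  # 6 + 3 + 2 + 0 = 11
```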

Difference Array Template

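# 0-indexed: arr has n elements and 0 <= l <= r < n
diff = [0] * (n + 1)  # extra slot so diff[r+1] is valid when r == n - 1
# Range add arr[l..r] += val, O(1) per update:
diff[l] += val
diff[r+1] -= val
# Reconstruct once after all updates, O(n):
running = 0
for i in range(n):
    running += diff[i]  # running = total added to position i
    arr[i] += running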
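The reconstruction loop is O(n), which is fine when every range add happens before any read. When adds and point reads are interleaved, the same two writes can go into a BIT over the difference array, so the value at position i becomes a prefix sum and both operations cost O(log n). A sketch using 1-indexed positions; `RangeAddBIT`, `range_add` and `point_query` are hypothetical names:

```python
class RangeAddBIT:
    """Range add / point query: a Fenwick tree over the difference array (1-indexed)."""

    def __init__(self, n):
        self.n = n
        self.tree = [0] * (n + 1)

    def _add(self, i, delta):
        while i <= self.n:
            self.tree[i] += delta
            i += i & -i

    def range_add(self, l, r, val):
        # Same two writes as diff[l] += val; diff[r+1] -= val
        self._add(l, val)
        self._add(r + 1, -val)  # no-op when r == n, since index n + 1 is past the end

    def point_query(self, i):
        # Value at position i = prefix sum of the difference array up to i
        s = 0
        while i > 0:
            s += self.tree[i]
            i -= i & -i
        return s


bit = RangeAddBIT(5)
bit.range_add(2, 4, 3)
bit.range_add(1, 2, 10)
print([bit.point_query(i) for i in range(1, 6)])  # [10, 13, 3, 3, 0]
```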

When to Use What

| Scenario | Use |
|---|---|
| Prefix sums only (no updates) | Prefix sum array |
| Point update + prefix sum | BIT |
| Range update + range query | Segment Tree + lazy propagation |
| Min/max range query | Segment Tree |
| 2D range sum | 2D BIT |
| Range add, point query | BIT over a difference array |
| Count inversions | BIT + coordinate compression |

Problem Index

| # | Problem | Key Technique |
|---|---|---|
| 01 | Range Sum Query Mutable | BIT or Segment Tree |
| 02 | Count Smaller Numbers | BIT + coordinate compression |
| 03 | Reverse Pairs | Merge sort counting |
| 04 | My Calendar I | Sorted intervals |
| 05 | Range Module | Disjoint interval set |
| 07 | Count of Range Sum | Merge sort on prefix sums |
| 08 | Number of LIS | DP with Segment Tree |
| 09 | Shifting Letters II | Difference array |
| 10 | Range Sum 2D Immutable | 2D prefix sum |
| 11 | My Calendar III | Difference array, running max |
| 12 | Max Sum Rectangle ≤ K | BIT + prefix sum |
| 14 | Product Except Self | Prefix + suffix products |
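Problem 02 and the inversion-count row both pair a BIT with coordinate compression: values are replaced by their rank among the distinct values, so the BIT only needs as many slots as there are distinct values. A sketch, assuming the task is to count, for every index i, the later elements strictly smaller than `nums[i]`; `count_smaller` is a hypothetical name. Summing its output gives the total number of inversions.

```python
def count_smaller(nums):
    """For each i, count j > i with nums[j] < nums[i], using a BIT over compressed values."""
    # Coordinate compression: map each distinct value to a rank 1..k.
    ranks = {v: i + 1 for i, v in enumerate(sorted(set(nums)))}
    k = len(ranks)
    tree = [0] * (k + 1)

    def add(i):
        while i <= k:
            tree[i] += 1
            i += i & -i

    def prefix(i):
        # How many already-inserted values have rank <= i
        s = 0
        while i > 0:
            s += tree[i]
            i -= i & -i
        return s

    result = [0] * len(nums)
    # Scan right to left so the BIT holds exactly the elements after index i.
    for i in range(len(nums) - 1, -1, -1):
        r = ranks[nums[i]]
        result[i] = prefix(r - 1)  # rank r - 1 and below: strictly smaller values
        add(r)
    return result


print(count_smaller([5, 2, 6, 1]))    # [2, 1, 1, 0]
print(sum(count_smaller([3, 1, 2])))  # 2 inversions: (3, 1) and (3, 2)
```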


Written by

Sanjeev Sharma

Full Stack Engineer · E-mopro