Google — Count Inversions (Modified Merge Sort)

Sanjeev Sharma
2 min read

Advertisement

Problem (Google Deep Algo)

Count the number of inversions in an array — pairs (i, j) where i < j but nums[i] > nums[j].

Example:

nums = [2, 4, 1, 3, 5]
Inversions: (2,1), (4,1), (4,3) → 3

Key Insight — Modified Merge Sort

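Both halves are already sorted when they are merged. So whenever the merge picks an element from the right half (because it is strictly smaller than the current left element), every element still remaining in the left half is larger than it and appears before it in the array — each of them forms an inversion with it.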

Merging [2,4] and [1,3]:
  Pick 1 from right → 2 elements remain in left → count += 2
  Pick 2 from left
  Pick 3 from right → 1 element remains in left → count += 1
  Pick 4 from left
Total additions: 3

Solutions

Python

def countInversions(nums):
    """Return the number of inversions in ``nums``.

    An inversion is a pair of positions (i, j) with i < j and
    nums[i] > nums[j]. The count is obtained with a merge sort that works
    on slices, so ``nums`` itself is never modified.
    """

    def sort_and_count(seq):
        # Returns (sorted copy of seq, inversions inside seq).
        size = len(seq)
        if size <= 1:
            return seq, 0

        half = size // 2
        lower, lower_inv = sort_and_count(seq[:half])
        upper, upper_inv = sort_and_count(seq[half:])

        merged = []
        cross_inv = 0
        a = 0
        b = 0
        lower_len = len(lower)
        upper_len = len(upper)
        while a < lower_len and b < upper_len:
            if lower[a] <= upper[b]:
                merged.append(lower[a])
                a += 1
                continue
            # upper[b] goes first: it is out of order with every element
            # of the lower half that has not been emitted yet.
            merged.append(upper[b])
            cross_inv += lower_len - a
            b += 1

        # At most one of these tails is non-empty.
        merged += lower[a:]
        merged += upper[b:]
        return merged, lower_inv + upper_inv + cross_inv

    return sort_and_count(nums)[1]

JavaScript

/**
 * Counts the inversions in `nums`: pairs (i, j) with i < j and
 * nums[i] > nums[j].
 *
 * Uses a merge sort on sliced copies, so `nums` is left untouched.
 *
 * @param {number[]} nums - values to inspect
 * @returns {number} the number of inversions
 */
function countInversions(nums) {
    let inversions = 0;

    // Returns a sorted copy of `list`, adding to `inversions` as a side effect.
    function sortCounting(list) {
        if (list.length <= 1) {
            return list;
        }

        const half = list.length >> 1;
        const lower = sortCounting(list.slice(0, half));
        const upper = sortCounting(list.slice(half));

        const merged = [];
        let a = 0;
        let b = 0;
        while (a < lower.length && b < upper.length) {
            if (lower[a] <= upper[b]) {
                merged.push(lower[a]);
                a += 1;
            } else {
                // upper[b] precedes every lower element not yet emitted.
                merged.push(upper[b]);
                b += 1;
                inversions += lower.length - a;
            }
        }

        // At most one of the two tails is non-empty.
        return merged.concat(lower.slice(a), upper.slice(b));
    }

    sortCounting(nums);
    return inversions;
}

Java

/**
 * Sorts {@code nums[l..r]} (both bounds inclusive) in place and returns the
 * number of inversions in that range, i.e. pairs (i, j) with i &lt; j and
 * nums[i] &gt; nums[j].
 *
 * @param nums array to sort; the range l..r is overwritten with its sorted order
 * @param tmp  scratch buffer; indices l..r are overwritten, so it must be at
 *             least {@code r + 1} elements long
 * @param l    first index of the range (inclusive)
 * @param r    last index of the range (inclusive)
 * @return the inversion count of the range, as a long because it can exceed
 *         the int range for large inputs
 */
long mergeSort(int[] nums, int[] tmp, int l, int r) {
    if (l >= r) {
        return 0; // zero or one element: nothing to sort or count
    }

    // l + (r - l) / 2 cannot overflow, whereas (l + r) / 2 wraps negative
    // once l + r exceeds Integer.MAX_VALUE.
    int mid = l + (r - l) / 2;
    long cnt = mergeSort(nums, tmp, l, mid) + mergeSort(nums, tmp, mid + 1, r);

    int i = l;
    int j = mid + 1;
    int k = l;
    while (i <= mid && j <= r) {
        if (nums[i] <= nums[j]) {
            tmp[k++] = nums[i++];
        } else {
            // nums[j] is smaller than every remaining left element nums[i..mid].
            cnt += mid - i + 1;
            tmp[k++] = nums[j++];
        }
    }
    while (i <= mid) {
        tmp[k++] = nums[i++];
    }
    while (j <= r) {
        tmp[k++] = nums[j++];
    }

    System.arraycopy(tmp, l, nums, l, r - l + 1);
    return cnt;
}

C++

#include <algorithm>
#include <vector>
using namespace std;
/// Sorts the half-open range a[l, r) in place and returns its inversion count.
///
/// An inversion is a pair of positions (i, j) with i < j and a[i] > a[j].
///
/// @param a  vector to sort; only elements in [l, r) are touched
/// @param l  first index of the range (inclusive)
/// @param r  one past the last index of the range (exclusive)
/// @return   number of inversions in a[l, r); long long because the count can
///           exceed the int range for large inputs
long long mergeCount(std::vector<int>& a, int l, int r) {
    if (r - l <= 1) return 0;  // zero or one element: nothing to sort or count

    // l + (r - l) / 2 cannot overflow, whereas (l + r) / 2 is undefined
    // behaviour once l + r exceeds INT_MAX.
    const int mid = l + (r - l) / 2;
    long long cnt = mergeCount(a, l, mid) + mergeCount(a, mid, r);

    std::vector<int> tmp;
    // The merged size is known up front; reserve once instead of regrowing.
    tmp.reserve(static_cast<std::vector<int>::size_type>(r - l));

    int i = l;
    int j = mid;
    while (i < mid && j < r) {
        if (a[i] <= a[j]) {
            tmp.push_back(a[i++]);
        } else {
            // a[j] is smaller than every remaining left element a[i..mid).
            cnt += mid - i;
            tmp.push_back(a[j++]);
        }
    }
    // At most one of the two tails is non-empty.
    tmp.insert(tmp.end(), a.begin() + i, a.begin() + mid);
    tmp.insert(tmp.end(), a.begin() + j, a.begin() + r);

    std::copy(tmp.begin(), tmp.end(), a.begin() + l);
    return cnt;
}

C

/*
 * Sorts the half-open range arr[l, r) in place and returns the number of
 * inversions in it, i.e. pairs (i, j) with i < j and arr[i] > arr[j].
 *
 * arr: array to sort; only elements in [l, r) are touched.
 * tmp: scratch buffer; indices [l, r) are overwritten, so it must hold at
 *      least r elements.
 * l:   first index of the range (inclusive).
 * r:   one past the last index of the range (exclusive).
 *
 * Returns the inversion count as long long, since it can exceed the int range
 * for large inputs.
 */
long long merge_count(int* arr, int* tmp, int l, int r) {
    int mid;
    int i, j, k;
    long long cnt;

    if (r - l <= 1) {
        return 0; /* zero or one element: nothing to sort or count */
    }

    /* l + (r - l) / 2 cannot overflow, whereas (l + r) / 2 is undefined
     * behaviour once l + r exceeds INT_MAX. */
    mid = l + (r - l) / 2;
    cnt = merge_count(arr, tmp, l, mid) + merge_count(arr, tmp, mid, r);

    i = l;
    j = mid;
    k = l;
    while (i < mid && j < r) {
        if (arr[i] <= arr[j]) {
            tmp[k++] = arr[i++];
        } else {
            /* arr[j] is smaller than every remaining left element arr[i..mid). */
            cnt += mid - i;
            tmp[k++] = arr[j++];
        }
    }
    while (i < mid) {
        tmp[k++] = arr[i++];
    }
    while (j < r) {
        tmp[k++] = arr[j++];
    }

    /* Copy the merged run back into place. */
    for (k = l; k < r; k++) {
        arr[k] = tmp[k];
    }
    return cnt;
}

Complexity

| Approach | Time | Space |
|---|---|---|
| Merge sort | O(n log n) | O(n) |
| Naive brute force | O(n²) | O(1) |

Advertisement

Sanjeev Sharma

Written by

Sanjeev Sharma

Full Stack Engineer · E-mopro