Merge Sort Algorithm

1. Algorithm Description

Merge Sort is a classic divide-and-conquer sorting algorithm. It works by:

  1. Divide: Split the array into two roughly equal halves.
  2. Conquer: Recursively sort each half.
  3. Merge: Merge the two sorted halves into a single sorted array.

Pseudocode:


function mergeSort(A, left, right):
    if left >= right:
        return
    mid = (left + right) / 2
    mergeSort(A, left, mid)
    mergeSort(A, mid+1, right)
    merge(A, left, mid, right)
      

Time Complexity

At each level of recursion, we do O(n) work to merge. There are O(log n) levels. Therefore, the overall time complexity is O(n log n).

Memory Complexity

The merge step requires an auxiliary array of size n. The recursion uses O(log n) stack space. Overall additional memory is O(n).

2. Easy Problems

2.1 Problem: Sort an Array of Integers

Given an unsorted array of integers, return it sorted in non-decreasing order.


// Merge Sort for array of ints
#include <bits/stdc++.h>
using namespace std;

// Merges the two adjacent sorted runs A[l..m] and A[m+1..r] (all bounds
// inclusive) into one sorted run that is written back to A[l..r].
// On ties the element from the left run is taken first, so equal values
// keep their relative order.
void merge(vector<int>& A, int l, int m, int r) {
    vector<int> buffer(r - l + 1);
    int left = l;
    int right = m + 1;
    int out = 0;

    // Take the smaller front element until one of the runs is exhausted.
    while (left <= m && right <= r) {
        if (A[right] < A[left]) {
            buffer[out++] = A[right++];
        } else {
            buffer[out++] = A[left++];
        }
    }

    // Append whatever remains of the run that still has elements.
    for (; left <= m; ++left) {
        buffer[out++] = A[left];
    }
    for (; right <= r; ++right) {
        buffer[out++] = A[right];
    }

    // Write the merged run back over the original range.
    int dest = l;
    for (int value : buffer) {
        A[dest++] = value;
    }
}

// Sorts A[l..r] (inclusive bounds) in non-decreasing order: split the range
// at its midpoint, sort each half recursively, then merge the two halves.
// A range with fewer than two elements (l >= r) is left untouched.
void mergeSort(vector<int>& A, int l, int r) {
    if (l < r) {
        // l + (r - l) / 2 rather than (l + r) / 2 so the sum cannot overflow.
        const int half = (r - l) / 2;
        const int middle = l + half;
        mergeSort(A, l, middle);
        mergeSort(A, middle + 1, r);
        merge(A, l, middle, r);
    }
}

// Demo driver: sorts a small sample array and prints it space-separated.
int main() {
    vector<int> values{5, 2, 8, 1, 3};
    const int lastIndex = static_cast<int>(values.size()) - 1;
    mergeSort(values, 0, lastIndex);
    for (size_t idx = 0; idx < values.size(); ++idx) {
        cout << values[idx] << " ";
    }
    return 0;
}
    

2.2 Problem: Find the k-th Smallest Element in an Array

Given an array of n integers and an integer k (1 ≤ k ≤ n), return the k-th smallest element by first sorting the array.


// Find k-th smallest by sorting with merge sort
#include <bits/stdc++.h>
using namespace std;

// (reuse merge and mergeSort from above)

// Returns the k-th smallest element of A, where k is 1-based
// (k == 1 yields the minimum).
//
// Side effect: A is sorted in place (non-decreasing order) by mergeSort.
//
// Throws std::out_of_range if k is not in [1, A.size()]; this includes every
// k when A is empty. The check runs before sorting, so A is left unmodified
// when the call fails.
int kthSmallest(vector<int>& A, int k) {
    if (k < 1 || static_cast<std::size_t>(k) > A.size()) {
        throw std::out_of_range("kthSmallest: k must be in [1, A.size()]");
    }
    // A is non-empty here, so size() - 1 cannot wrap; the cast makes the
    // size_t -> int conversion explicit.
    mergeSort(A, 0, static_cast<int>(A.size()) - 1);
    return A[k - 1];
}

// Demo driver: prints the 3rd smallest value of a fixed sample array.
int main() {
    const int rank = 3;
    vector<int> sample{7, 10, 4, 3, 20, 15};
    const int answer = kthSmallest(sample, rank);
    cout << "3rd smallest is " << answer;
    return 0;
}
    

3. Intermediate Problems

3.1 Problem: Count Inversions in an Array

An inversion is a pair (i, j) with i < j and A[i] > A[j]. Count total inversions in O(n log n).


// Count inversions using modified merge sort
#include <bits/stdc++.h>
using namespace std;

// Merges the adjacent sorted runs A[l..m] and A[m+1..r] (inclusive bounds)
// back into A[l..r] and returns the number of cross inversions, i.e. pairs
// (i, j) with i in [l, m], j in [m+1, r] and A[i] > A[j].
//
// Both runs must already be sorted; the counting step relies on that.
long long mergeCount(vector<int>& A, int l, int m, int r) {
    vector<int> merged;
    // The output size is known up front, so allocate once instead of letting
    // push_back grow the buffer repeatedly.
    if (r >= l) {
        merged.reserve(static_cast<size_t>(r - l) + 1);
    }

    int i = l;
    int j = m + 1;
    long long inversions = 0;
    while (i <= m && j <= r) {
        if (A[i] <= A[j]) {
            merged.push_back(A[i++]);
        } else {
            // The left run is sorted, so A[j] is smaller than every remaining
            // left element A[i..m]; each of those forms one inversion with it.
            inversions += (m - i + 1);
            merged.push_back(A[j++]);
        }
    }
    while (i <= m) merged.push_back(A[i++]);
    while (j <= r) merged.push_back(A[j++]);

    // Copy the merged run back; range-for avoids a signed/unsigned comparison.
    int dest = l;
    for (int value : merged) {
        A[dest++] = value;
    }
    return inversions;
}

// Sorts A[l..r] (inclusive bounds) and returns the number of inversions that
// range contained: those inside the left half, those inside the right half,
// and those counted by mergeCount while merging the two halves.
long long sortCount(vector<int>& A, int l, int r) {
    if (r <= l) {
        return 0;  // fewer than two elements: nothing to count
    }
    const int middle = l + (r - l) / 2;
    const long long inLeft = sortCount(A, l, middle);
    const long long inRight = sortCount(A, middle + 1, r);
    const long long across = mergeCount(A, l, middle, r);
    return inLeft + inRight + across;
}

// Demo driver: prints the inversion count of a fixed sample array.
int main() {
    vector<int> sample{2, 4, 1, 3, 5};
    const int lastIndex = static_cast<int>(sample.size()) - 1;
    const long long total = sortCount(sample, 0, lastIndex);
    cout << "Inversions: " << total;
    return 0;
}
    

3.2 Problem: Sort a Singly Linked List

Given the head of a singly linked list, sort it in ascending order in O(n log n) time and O(log n) space.


// Merge Sort on linked list
#include <bits/stdc++.h>
using namespace std;

// Node of a singly linked list of ints.
struct ListNode {
    int val;                   // payload stored in this node
    ListNode* next = nullptr;  // following node; nullptr marks the end of the list

    // Creates a detached node holding x.
    ListNode(int x) : val(x) {}
};

// Merges two sorted lists by relinking their existing nodes (no nodes are
// allocated or copied) and returns the head of the combined sorted list.
// On equal values the node from `a` is linked first.
ListNode* mergeList(ListNode* a, ListNode* b) {
    ListNode* head = nullptr;
    ListNode** link = &head;  // the pointer slot that receives the next node

    while (a != nullptr && b != nullptr) {
        // Reference to whichever cursor currently holds the smaller value.
        ListNode*& smaller = (a->val <= b->val) ? a : b;
        *link = smaller;           // splice that node onto the result
        link = &smaller->next;     // the next splice goes after it
        smaller = smaller->next;   // advance the cursor the node came from
    }

    // At most one list still has nodes; attach its remainder as-is.
    *link = (a != nullptr) ? a : b;
    return head;
}

// Sorts a singly linked list in ascending order with merge sort and returns
// the new head. Nodes are relinked in place; none are allocated.
ListNode* sortList(ListNode* head) {
    const bool fewerThanTwo = (head == nullptr) || (head->next == nullptr);
    if (fewerThanTwo) {
        return head;
    }

    // Walk one pointer at single speed and one at double speed. The fast
    // pointer starts one node ahead, so when it runs off the end the slow
    // pointer rests on the last node of the first half.
    ListNode* tailOfFirst = head;
    for (ListNode* runner = head->next;
         runner != nullptr && runner->next != nullptr;
         runner = runner->next->next) {
        tailOfFirst = tailOfFirst->next;
    }

    // Cut the list into two independent halves.
    ListNode* secondHalf = tailOfFirst->next;
    tailOfFirst->next = nullptr;

    ListNode* sortedFirst = sortList(head);
    ListNode* sortedSecond = sortList(secondHalf);
    return mergeList(sortedFirst, sortedSecond);
}

int main() {
    // Example: 4->2->1->3
    ListNode* head = new ListNode(4);
    head->next = new ListNode(2);
    head->next->next = new ListNode(1);
    head->next->next->next = new ListNode(3);

    head = sortList(head);
    for (ListNode* p = head; p; p = p->next) {
        cout << p->val << " ";
    }
    return 0;
}
    

4. Hard Problem

4.1 Problem: Merge k Sorted Lists

Given k sorted linked lists containing n nodes in total, merge them into one sorted list in O(n log k) time by divide-and-conquer (generalized merge sort).


// Merge k sorted lists using divide and conquer
#include <bits/stdc++.h>
using namespace std;

// Singly linked list node carrying one int.
struct ListNode {
    int val;         // value held by the node
    ListNode* next;  // successor node, or nullptr for the last node

    // Builds a node with value x that is not yet linked to anything.
    ListNode(int x)
        : val{x},
          next{nullptr} {}
};

// Merges two sorted lists into one sorted list by relinking the existing
// nodes, and returns its head. When values are equal the node from `a`
// comes first.
ListNode* mergeTwo(ListNode* a, ListNode* b) {
    ListNode* merged = nullptr;
    ListNode** slot = &merged;  // where the next chosen node gets attached

    while (a != nullptr && b != nullptr) {
        if (b->val < a->val) {
            *slot = b;
            b = b->next;
        } else {
            *slot = a;
            a = a->next;
        }
        slot = &(*slot)->next;
    }

    // One of the lists may still have nodes; hook the remainder on unchanged.
    *slot = (a != nullptr) ? a : b;
    return merged;
}

// Merges the sorted lists stored at lists[l..r] (inclusive bounds) into one
// sorted list and returns its head. The index range is halved recursively
// and the two partial results are combined with mergeTwo.
ListNode* mergeK(vector<ListNode*>& lists, int l, int r) {
    if (l > r) {
        return nullptr;   // empty index range
    }
    if (l == r) {
        return lists[l];  // a single list is already merged
    }

    const int middle = l + (r - l) / 2;
    ListNode* lower = mergeK(lists, l, middle);
    ListNode* upper = mergeK(lists, middle + 1, r);
    return mergeTwo(lower, upper);
}

// Merges every sorted list in `lists` into one sorted list and returns its
// head. Returns nullptr when `lists` is empty.
ListNode* mergeKLists(vector<ListNode*>& lists) {
    // Handle the empty case explicitly instead of relying on size() - 1
    // wrapping around and then narrowing to -1.
    if (lists.empty()) {
        return nullptr;
    }
    return mergeK(lists, 0, static_cast<int>(lists.size()) - 1);
}

int main() {
    // Example: [[1->4->5], [1->3->4], [2->6]]
    vector<ListNode*> lists;
    lists.push_back(new ListNode(1));
    lists[0]->next = new ListNode(4);
    lists[0]->next->next = new ListNode(5);
    lists.push_back(new ListNode(1));
    lists[1]->next = new ListNode(3);
    lists[1]->next->next = new ListNode(4);
    lists.push_back(new ListNode(2));
    lists[2]->next = new ListNode(6);

    ListNode* head = mergeKLists(lists);
    for (ListNode* p = head; p; p = p->next) {
        cout << p->val << " ";
    }
    return 0;
}