Merge Sort Algorithm
1. Algorithm Description
Merge Sort is a classic divide-and-conquer sorting algorithm. It works by:
- Divide: Split the array into two roughly equal halves.
- Conquer: Recursively sort each half.
- Merge: Merge the two sorted halves into a single sorted array.
Pseudocode:
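function mergeSort(A, left, right):        // sorts A[left..right], inclusive
    if left >= right:                      // zero or one element: already sorted
        return
    mid = left + (right - left) / 2        // integer division; equals floor((left + right) / 2) without overflow
    mergeSort(A, left, mid)
    mergeSort(A, mid + 1, right)
    merge(A, left, mid, right)             // merge sorted runs A[left..mid] and A[mid+1..right]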
Time Complexity
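Halving the range at every call produces a recursion tree with about log₂ n levels. On any one level, the merges together touch each of the n elements once, so each level costs O(n) work. Therefore the overall time complexity is O(n log n), in the best, average, and worst case alike.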
Memory Complexity
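Each merge uses an auxiliary buffer as large as the range it merges. Merges never run at the same time, so at most n extra elements are needed at once (for the final, top-level merge). The recursion depth is about log₂ n, so the call stack uses O(log n) space. Overall additional memory is O(n).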
2. Easy Problems
2.1 Problem: Sort an Array of Integers
Given an unsorted array of integers, return it sorted in non-decreasing order.
// Merge Sort for array of ints
#include <bits/stdc++.h>
using namespace std;
// Merge the two adjacent sorted runs A[l..m] and A[m+1..r] (inclusive bounds)
// so that A[l..r] ends up sorted in non-decreasing order.
// On ties the left-run element is taken first, which keeps the sort stable.
void merge(vector<int>& A, int l, int m, int r) {
    vector<int> buffer;
    buffer.reserve(r - l + 1);
    int left = l;       // cursor into the left run
    int right = m + 1;  // cursor into the right run
    while (left <= m && right <= r) {
        if (A[right] < A[left]) {
            buffer.push_back(A[right++]);
        } else {
            buffer.push_back(A[left++]);
        }
    }
    // At most one run still has elements left; append them in order.
    while (left <= m) buffer.push_back(A[left++]);
    while (right <= r) buffer.push_back(A[right++]);
    // Write the merged run back over the original range.
    copy(buffer.begin(), buffer.end(), A.begin() + l);
}
// Sort A[l..r] (inclusive bounds) in non-decreasing order with top-down merge sort.
// Ranges of zero or one element (l >= r) are left untouched.
void mergeSort(vector<int>& A, int l, int r) {
    if (l < r) {
        // Midpoint written as l + (r - l) / 2 so that l + r cannot overflow.
        const int mid = l + (r - l) / 2;
        mergeSort(A, l, mid);
        mergeSort(A, mid + 1, r);
        merge(A, l, mid, r);
    }
}
// Demo: sort a small array and print it space-separated.
int main() {
    vector<int> data{5, 2, 8, 1, 3};
    const int last = static_cast<int>(data.size()) - 1;
    mergeSort(data, 0, last);
    for (size_t idx = 0; idx < data.size(); ++idx) {
        cout << data[idx] << " ";
    }
    return 0;
}
2.2 Problem: Find the k-th Smallest Element in an Array
Given an array and an integer k (1-indexed, 1 ≤ k ≤ array size), return the k-th smallest element by first sorting the array.
// Find k-th smallest by sorting with merge sort
#include <bits/stdc++.h>
using namespace std;
// (reuse merge and mergeSort from above)
// Return the k-th smallest element of A, where k is 1-indexed (k = 1 is the minimum).
// Side effect: A is sorted in place, because mergeSort runs on the caller's vector.
// Throws std::out_of_range if k is outside [1, A.size()]. Without this check,
// A[k - 1] would read out of bounds, which is undefined behavior.
int kthSmallest(vector<int>& A, int k) {
    // Explicit cast instead of relying on size_t -> int narrowing of A.size() - 1.
    const int n = static_cast<int>(A.size());
    if (k < 1 || k > n) {
        throw out_of_range("kthSmallest: k must be in [1, A.size()]");
    }
    mergeSort(A, 0, n - 1);
    return A[k - 1];
}
// Demo: print the 3rd smallest value of a small array.
int main() {
    vector<int> values{7, 10, 4, 3, 20, 15};
    const int k = 3;
    const int answer = kthSmallest(values, k);
    cout << "3rd smallest is " << answer;
    return 0;
}
3. Intermediate Problems
3.1 Problem: Count Inversions in an Array
An inversion is a pair (i, j) with i < j and A[i] > A[j]. Count total inversions in O(n log n).
// Count inversions using modified merge sort
#include <bits/stdc++.h>
using namespace std;
// Merge the sorted runs A[l..m] and A[m+1..r] (inclusive) back into A[l..r].
// Returns the number of cross inversions: pairs (i, j) with i in the left run,
// j in the right run, and A[i] > A[j].
long long mergeCount(vector<int>& A, int l, int m, int r) {
    vector<int> merged;
    long long crossInversions = 0;
    int a = l;      // next unread index in the left run
    int b = m + 1;  // next unread index in the right run
    for (; a <= m && b <= r;) {
        if (A[b] < A[a]) {
            // A[b] is smaller than every remaining left element A[a..m],
            // so it forms an inversion with each of them.
            crossInversions += m - a + 1;
            merged.push_back(A[b++]);
        } else {
            merged.push_back(A[a++]);
        }
    }
    for (; a <= m; ++a) merged.push_back(A[a]);
    for (; b <= r; ++b) merged.push_back(A[b]);
    int dst = l;
    for (int value : merged) A[dst++] = value;
    return crossInversions;
}
// Sort A[l..r] (inclusive) with merge sort and return the number of inversions
// in that range. Inversions inside each half are counted recursively; pairs that
// span both halves are counted by the final merge.
long long sortCount(vector<int>& A, int l, int r) {
    if (l < r) {
        const int mid = l + (r - l) / 2;
        const long long leftInv = sortCount(A, l, mid);
        const long long rightInv = sortCount(A, mid + 1, r);
        return leftInv + rightInv + mergeCount(A, l, mid, r);
    }
    return 0;  // zero or one element: no inversions
}
// Demo: count and print the inversions of a small array.
int main() {
    vector<int> nums{2, 4, 1, 3, 5};
    const long long total = sortCount(nums, 0, static_cast<int>(nums.size()) - 1);
    cout << "Inversions: " << total;
    return 0;
}
3.2 Problem: Sort a Singly Linked List
Given the head of a singly linked list, sort it in ascending order in O(n log n) time and O(log n) extra space (the recursion stack; nodes are relinked, not copied).
// Merge Sort on linked list
#include <bits/stdc++.h>
using namespace std;
// Node of a singly linked list carrying an int value.
// A freshly constructed node has no successor.
struct ListNode {
    int val;
    ListNode* next = nullptr;
    ListNode(int x) : val(x) {}
};
// Merge two ascending lists into one ascending list by relinking their nodes
// (no nodes are allocated). On equal values the node from `a` comes first.
ListNode* mergeList(ListNode* a, ListNode* b) {
    ListNode* head = nullptr;
    ListNode** link = &head;  // the next-pointer that receives the next node
    while (a != nullptr && b != nullptr) {
        // Reference to whichever input cursor holds the smaller front node.
        ListNode*& smaller = (b->val < a->val) ? b : a;
        *link = smaller;
        link = &smaller->next;
        smaller = smaller->next;
    }
    *link = (a != nullptr) ? a : b;  // append whatever remains
    return head;
}
// Sort a singly linked list in ascending order with top-down merge sort and
// return the new head. The list is split in two, each half is sorted, and the
// halves are merged with mergeList.
ListNode* sortList(ListNode* head) {
    if (head == nullptr || head->next == nullptr) return head;
    // `fast` moves two nodes for each node `slow` moves. When `fast` runs out,
    // `slow` is the last node of the first half, which gets ceil(n/2) nodes.
    ListNode* slow = head;
    ListNode* fast = head->next;
    while (fast != nullptr && fast->next != nullptr) {
        fast = fast->next->next;
        slow = slow->next;
    }
    ListNode* secondHalf = slow->next;
    slow->next = nullptr;  // detach the first half
    ListNode* sortedFirst = sortList(head);
    ListNode* sortedSecond = sortList(secondHalf);
    return mergeList(sortedFirst, sortedSecond);
}
int main() {
// Example: 4->2->1->3
ListNode* head = new ListNode(4);
head->next = new ListNode(2);
head->next->next = new ListNode(1);
head->next->next->next = new ListNode(3);
head = sortList(head);
for (ListNode* p = head; p; p = p->next) {
cout << p->val << " ";
}
return 0;
}
4. Hard Problem
4.1 Problem: Merge k Sorted Lists
Given k sorted linked lists, merge them into one sorted list in O(n log k) time, where n is the total number of nodes across all lists, by divide-and-conquer (a generalized merge sort).
// Merge k sorted lists using divide and conquer
#include <bits/stdc++.h>
using namespace std;
// Node of a singly linked list carrying an int value.
// A freshly constructed node has no successor.
struct ListNode {
    int val;
    ListNode* next = nullptr;
    ListNode(int x) : val(x) {}
};
// Merge two ascending lists into one ascending list by relinking their nodes.
// On equal values the node from `a` is placed first.
ListNode* mergeTwo(ListNode* a, ListNode* b) {
    ListNode sentinel(0);       // placeholder head; the result starts at sentinel.next
    ListNode* last = &sentinel;  // tail of the merged list built so far
    for (; a != nullptr && b != nullptr; last = last->next) {
        if (b->val < a->val) {
            last->next = b;
            b = b->next;
        } else {
            last->next = a;
            a = a->next;
        }
    }
    last->next = (a == nullptr) ? b : a;  // append the leftover list
    return sentinel.next;
}
// Merge lists[l..r] (inclusive) into one sorted list by recursively merging the
// left and right halves of the index range. An empty range (l > r) yields nullptr.
ListNode* mergeK(vector<ListNode*>& lists, int l, int r) {
    if (l == r) return lists[l];
    if (l < r) {
        const int mid = l + (r - l) / 2;
        ListNode* leftMerged = mergeK(lists, l, mid);
        ListNode* rightMerged = mergeK(lists, mid + 1, r);
        return mergeTwo(leftMerged, rightMerged);
    }
    return nullptr;
}
// Merge every sorted list in `lists` into one sorted list by relinking nodes.
// Returns nullptr when `lists` is empty or when every list in it is empty.
ListNode* mergeKLists(vector<ListNode*>& lists) {
    // Handle the empty case explicitly. Otherwise lists.size() - 1 wraps around
    // as size_t and only becomes -1 through implementation-defined narrowing to int.
    if (lists.empty()) return nullptr;
    return mergeK(lists, 0, static_cast<int>(lists.size()) - 1);
}
int main() {
// Example: [[1->4->5], [1->3->4], [2->6]]
vector<ListNode*> lists;
lists.push_back(new ListNode(1));
lists[0]->next = new ListNode(4);
lists[0]->next->next = new ListNode(5);
lists.push_back(new ListNode(1));
lists[1]->next = new ListNode(3);
lists[1]->next->next = new ListNode(4);
lists.push_back(new ListNode(2));
lists[2]->next = new ListNode(6);
ListNode* head = mergeKLists(lists);
for (ListNode* p = head; p; p = p->next) {
cout << p->val << " ";
}
return 0;
}