Skip to content

堆结构常见题


合并 k 个有序链表

题目链接

https://leetcode.cn/problems/merge-k-sorted-lists/description/

思路分析

小根堆实现,小根堆可以维护最后合并的新链表中元素的大小是按升序排列的

初始先让所有链表的头节点入堆

弹出一个节点,作为合并的新链表的头节点,同时将弹出节点的 next 节点入堆,重新调整堆

弹出堆顶元素,挂接到上一次弹出节点的后面,将弹出节点的 next 节点入堆,重新调整堆

循环上述过程,直到堆为空,整个过程结束

代码实现

java
/**
 * Definition for singly-linked list.
 * public class ListNode {
 *     int val;
 *     ListNode next;
 *     ListNode() {}
 *     ListNode(int val) { this.val = val; }
 *     ListNode(int val, ListNode next) { this.val = val; this.next = next; }
 * }
 */
class Solution {
    public ListNode mergeKLists(ListNode[] lists) {
        // 优先队列默认是小根堆
        PriorityQueue<ListNode> heap = new PriorityQueue<>((a, b) -> a.val - b.val);
        for (ListNode h : lists) {
            // 将所有链表的头节点入堆,之后的节点肯定可以通过 next 找到
            if (h != null) {
                heap.add(h);
            }
        }
        // 遍历一遍之后,如果发现堆中没有元素,说明 k 个链表都是空链表
        if (heap.isEmpty()) {
            return null;
        }
        // 先弹出一个节点做合并后链表的头节点
        ListNode h = heap.poll();
        // 用 pre 来挂接节点
        ListNode pre = h;
        // 弹出一个节点,将节点的 next 节点加入堆中
        if (pre.next != null) {
            heap.add(pre.next);
        }

        /*
            此时已经有合并的新链表已经有一个节点了,
            开始挂接节点

            只要堆不为空,重复下述过程
            (1)弹出堆顶元素,挂接到新链表中
            (2)将弹出元素的 next 节点元素加入堆中,重新调整堆
         */
        while (!heap.isEmpty()) {
            // 弹出堆顶元素,挂接到新链表中
            ListNode cur = heap.poll();
            pre.next = cur;
            pre = cur;
            // 将弹出元素的 next 节点元素加入堆中,重新调整堆
            if (cur.next != null) {
                heap.add(cur.next);
            }
        }
        // 最后返回新链表的头节点
        return h;
    }
}

线段最多重合问题

题目链接

https://www.nowcoder.com/practice/1ae8d0b6bb4e4bcdbf64ec491f63fc37

最大重叠数指的是在所有区间中,某一时刻最多有多少个区间是重叠的

思路分析

核心要点:一段重合区间的左端点一定是某个线段(区间)的左端点

将区间按照左端点从小到大排序,维护一个小顶堆,存放区间的结束位置,堆的大小就是重叠区间的大小

当新来到一个区间 [x,y] 时,将堆中 <= x 的元素全部弹出,弹出的元素说明以弹出元素作为区间结尾的区间和当前区间是不重合的,将 y 入堆,此时堆中元素的个数就是重合的线段数,因为和当前区间不重合的区间在前面已经弹出了,堆中维护的是重合区间的右端点,不断更新,求最大值

原先已经按照区间的左边界按照升序排序,之前区间的左边界一定在当前区间的左边界之前,现在看是否重合只需要判断之前区间的右边界和当前区间的左边界的位置关系

堆中维护的是区间的结束位置,将堆中 <= x 的元素全部弹出(x 是新区间的左边界),说明之前区间的右边界无法到达当前区间的左边界,也就说明两个区间不是重合的,不会影响最终的结果

将区间按照左边界从小到大排序的意义:当来到一个区间的时候,只需要看该区间之前包括自己有多少个区间和当前区间是重合的,而左区间都是从小到大排序的,也就是说前面区间的左边界一定会在当前区间的左边界之前,如果之前区间的右边界可以超过当前区间的左边界,那就认为区间是重合的

PriorityQueue

java
import java.io.*;
import java.util.*;
// 1:无需package
// 2: 类名必须Main, 不可修改

public class Main {
    public static int MAXN = 10001;

    public static int[][] line = new int[MAXN][2];

    public static int n;

    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StreamTokenizer in = new StreamTokenizer(br);
        PrintWriter out = new PrintWriter(new OutputStreamWriter(System.out));
        while (in.nextToken() != StreamTokenizer.TT_EOF) {
            n = (int) in.nval;
            for (int i = 0; i < n; i++) {
                in.nextToken();
                line[i][0] = (int) in.nval;
                in.nextToken();
                line[i][1] = (int) in.nval;
            }
            out.println(compute(line));
        }
        out.flush();
        out.close();
        br.close();
    }

    public static int compute(int[][] line){
        int n = line.length;
        Arrays.sort(line,(a,b)-> a[0]- b[0]);
        // 默认是小根堆
        PriorityQueue<Integer> heap = new PriorityQueue<>();
        int ans = 0;
        for (int i = 0; i < n; i++) {
            while (!heap.isEmpty() && heap.peek() <= line[i][0]){
                heap.poll();
            }
            heap.add(line[i][1]);
            ans = Math.max(ans, heap.size());
        }
        return ans;
    }
}

手写堆实现

java
import java.io.*;
import java.util.*;
// 1:无需package
// 2: 类名必须Main, 不可修改

public class Main {
    public static int MAXN = 10001;

    public static int[][] line = new int[MAXN][2];

    public static int n;

    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StreamTokenizer in = new StreamTokenizer(br);
        PrintWriter out = new PrintWriter(new OutputStreamWriter(System.out));
        while (in.nextToken() != StreamTokenizer.TT_EOF) {
            n = (int) in.nval;
            for (int i = 0; i < n; i++) {
                in.nextToken();
                line[i][0] = (int) in.nval;
                in.nextToken();
                line[i][1] = (int) in.nval;
            }
            out.println(compute());
        }
        out.flush();
        out.close();
        br.close();
    }

    public static int compute() {
        // size 是 static 的,先清空堆
        size = 0;

        // n 条线段,line[0...n-1][2] : line[i][0] line[i][1], 左闭右闭
        // 对 line 数组在 0 - n-1 范围上排序,所有线段按照左区间升序排序
        Arrays.sort(line, 0, n, (a, b) -> a[0] - b[0]);

        int ans = 0;
        for (int i = 0; i < n; i++) {
            // 之前区间的右边界比当前区间的左区间小,区间不重合
            while (size > 0 && heap[0] <= line[i][0]) {
                pop();
            }
            // 将区间的右区间入堆
            add(line[i][1]);
            ans = Math.max(ans, size);
        }
        return ans;
    }

    // 小根堆,堆顶 0 位置
    public static int[] heap = new int[MAXN];

    // 堆的大小
    public static int size;

    public static void add(int x) {
        heap[size] = x;
        /*
         * 分析
         * 原先的堆:1 ..... 5
         *                  i
         *
         * 现在的堆:1 ..... 5 6
         *                  i size
         * 让 i 来到新增加元素的位置
         */
        int i = size++;
        // heapInsert 的过程,构造小根堆
        while (heap[i] < heap[(i - 1) / 2]) {
            swap(i, (i - 1) / 2);
            i = (i - 1) / 2;
        }
    }

    public static void pop() {
        swap(0, size - 1);
        size--;
        // 初始 i = 0,l = 2 * i + 1 = 2 * 0 + 1 = 1
        int i = 0;
        int l = 1;
        // heapify 的过程,需要构造小根堆
        while (l < size) {
            // 找到孩子中最小的那个,记录下标
            int best = l + 1 < size && heap[l + 1] < heap[l] ? l + 1 : l;
            // 最小的孩子跟父节点比较
            best = heap[best] < heap[i] ? best : i;
            // 如果最小的是父节点自己,则无需调整堆
            if (best == i) {
                break;
            }
            swap(i, best);
            i = best;
            l = i * 2 + 1;
        }
    }

    public static void swap(int i, int j) {
        int tmp = heap[i];
        heap[i] = heap[j];
        heap[j] = tmp;
    }
}

复杂度分析

时间复杂度:O(nlogn)

n 个线段,每个线段的右区间进入堆,出堆,堆调整的时间复杂度都是 logn,所以总时间复杂度是 O(nlogn)

空间复杂度:O(n)

堆的大小最多是 n

⭐ 求总重叠区间数

区别于最大重叠数,总重叠区间数指的是所有区间对中,有多少对区间是互相重叠的,这个数是所有重叠对的累计数量

堆中存储的是区间的右端点,同时堆的大小也是重叠的区间数量,做统计即可

java
import java.util.*;

class Solution {
    public int countOverlappingIntervals(int[][] intervals) {
        int n = intervals.length;
        Arrays.sort(intervals, (a, b) -> a[0] - b[0]);  // 按照左端点排序
        PriorityQueue<Integer> heap = new PriorityQueue<>();
        int overlapCount = 0;  // 记录重叠的区间数

        for (int i = 0; i < n; i++) {
            // 这里认为 [1, 5] 和 [5, 8] 是重叠的区间
            // 移除掉所有结束位置小于当前区间左端点的区间
            while (!heap.isEmpty() && heap.peek() < intervals[i][0]) {
                heap.poll();
            }

            // 计算当前区间与堆中的所有区间的重叠
            overlapCount += heap.size();

            // 将当前区间的右端点加入堆
            heap.add(intervals[i][1]);
        }

        return overlapCount;
    }
}

线段重合应用场景

题目链接

https://leetcode.cn/problems/divide-intervals-into-minimum-number-of-groups/

⭐ 思路分析

本题要求的是最少可以分出多少组不重合的区间

求出最大的重合区间数,这里暂且记为 n,而题目要求的是每组不能有重合的区间,即让这些区间去不同的组,也就是说起码(至少)要有 n 组才能使得每组区间都不重合

至少也就是最少的意思,那就可以复用求线段重合问题的思路

注意这里的边界判断,和上题有所不同

代码实现

java
class Solution {
    public int minGroups(int[][] intervals) {
        int n = intervals.length;
        Arrays.sort(intervals, (a, b) -> a[0] - b[0]);
        PriorityQueue<Integer> heap = new PriorityQueue<>();
        int ans = 0;
        for (int i = 0; i < n; i++) {
            // 注意这里的判断,题目说了 [1, 5] 和 [5, 8] 是重叠的区间
            while (!heap.isEmpty() && heap.peek() < intervals[i][0]) {
                heap.poll();
            }
            heap.add(intervals[i][1]);
            ans = Math.max(ans, heap.size());
        }
        return ans;
    }
}

累加和减半的最少操作次数

题目链接

https://leetcode.cn/problems/minimum-operations-to-halve-array-sum/description/

思路分析

本题采用的思路:贪心 + 大根堆

贪心策略:每次都将当前数组中的最大值减半,尽可能让每次减半的价值最大化,则减半总和达到目标时,减半的次数才能相对较小

大根堆的作用:维持数组中最大的元素在堆顶,使得贪心策略成立,每次都取最大的元素减半

总和 / 2 就是需要元素减半总和需要达到的目标值,元素 / 2 就是单次操作减半的价值,做累加,同时计算操作次数,如果达到目标值,累加过程结束,返回操作次数即可

double 实现

java
class Solution {
    public int halveArray(int[] nums) {
        // 大根堆
        PriorityQueue<Double> heap = new PriorityQueue<>((a, b) -> b.compareTo(a));
        double sum = 0;
        for (int num : nums) {
            heap.add((double) num);
            sum += num;
        }
        // 元素减半累加的目标值
        sum /= 2;
        int ans = 0;
        for (double minus = 0, cur; minus < sum; ans++, minus += cur) {
            cur = heap.poll() / 2;
            heap.add(cur);
        }
        return ans;
    }
}

⭐ 非浮点数实现

每个数都乘以 2 的 20 次方,为了是给一个缓冲区,用来除 2,这样就不会出现浮点数

java
class Solution {
    public static int MAXN = 100001;

    public static long[] heap = new long[MAXN];

    public static int size;

    public int halveArray(int[] nums) {
        size = nums.length;
        long sum = 0;
        // 从底到顶建大根堆,时间复杂度 O(n)
        for (int i = size - 1; i >= 0; i--) {
            heap[i] = (long) nums[i] << 20;
            sum += heap[i];
            heapify(i);
        }
        sum /= 2;
        int ans = 0;
        for (long minus = 0; minus < sum; ans++) {
            heap[0] /= 2;
            minus += heap[0];
            // 堆顶元素减半后,需要重新调整堆,维持大根堆的结构
            heapify(0);
        }
        return ans;
    }

    public static void heapify(int i) {
        int l = 2 * i + 1;
        while (l < size) {
            // 大根堆,两个孩子中找最大的
            int best = l + 1 < size && heap[l + 1] > heap[l] ? l + 1 : l;
            // 孩子跟父节点比较一下谁大
            best = heap[best] > heap[i] ? best : i;
            // 最大的就是父节点自己,无需调整堆,退出
            if (best == i) {
                break;
            }
            swap(best,i);
            i = best;
            l = 2 * i + 1;
        }
    }

    public static void swap(int i, int j) {
        long tmp = heap[i];
        heap[i] = heap[j];
        heap[j] = tmp;
    }
}