算法基础

算法基础

框架概括

整体框架

  • 数据结构(增删查改)
    • 数组(顺序存储)
      • 动态数组
      • 字符串
      • 哈希表
    • 链表(链式存储)
      • 单/双链表
  • 算法(穷举)
    • 如何避免遗漏
      • 回溯算法
      • 动态规划
      • DFS
      • BFS
    • 如何避免冗余
      • 二分
      • 滑动窗口
      • 贪心

各类数据结构的遍历

  1. 数组的遍历,线性迭代结构:
    1
    2
    3
    
    def traverse(arr: List[int]):
        for i in range(len(arr)):
            # 迭代访问 arr[i]
    
  2. 链表的遍历,兼具迭代和递归结构:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    
    # 基本的单链表节点
    class ListNode:
        def __init__(self, val):
            self.val = val
            self.next = None
    
    def traverse(head: ListNode) -> None:
        p = head
        while p is not None:
            # 迭代访问 p.val
            p = p.next
    
    def traverse(head: ListNode) -> None:
        # 递归访问 head.val
        traverse(head.next)
    
  3. 二叉树的遍历,典型的非线性递归遍历结构:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    
    # 基本的二叉树节点
    class TreeNode:
        def __init__(self, val=0, left=None, right=None):
            self.val = val
            self.left = left
            self.right = right
    
    def traverse(root: TreeNode):
        traverse(root.left)
        traverse(root.right)
    

算法复杂度

主定理(Master Theorem)

假设有递归关系式:

$T(N) = aT(N/b) + f(N), f(N) = N^{\log_b(a)} \log^k(N)$

其中,$N$为问题规模,$a$为递归的子问题数量,$N/b$为每个子问题的规模(假设每个子问题的规模基本一样),$f(N)$为递归以外进行的计算工作。

则其算法复杂度为

$T(N) = O(N^{\log_b(a)} \log^{(k+1)}N)$

常见算法复杂度

算法 递归关系式 复杂度
二分查找 $T(N) = T(N/2) + O(1)$ $O(\log(N))$
二叉树遍历 $T(N) = 2T(N/2) + O(1)$ $O(N)$
归并排序 $T(N) = 2T(N/2) + O(N)$ $O(N\log(N))$

双指针

使用两个指针变量在数组或链表等线性结构上协同移动,避免嵌套循环,将部分 $O(N^2)$ 的算法优化为 $O(N)$。主要分为:

  • 同向双指针(快慢指针):一个快指针先行,慢指针跟进,常用于滑动窗口(去重)、链表操作(找中点、判断环、环入口)等。
  • 相向双指针(对撞指针):从两端向中间移动,常用于有序数组求和、回文判断、反转数组、数组合并等。
  • 背向双指针:从中间向两边扩展,常用于回文串、最长子回文等问题。

算法复杂度

通常情况下,时间复杂度 $O(N)$(与最内层循环主体的执行次数有关),空间复杂度:$O(1)$。

使用场景

  • 滑动窗口 (90%)
  • 时间复杂度要求 $O(N)$ (80%是双指针)
  • 要求原地操作,只可以使用交换,不能使用额外空间 (80%)
  • 有子数组 subarray / 子字符串 substring 的关键词 (50%)
  • 有回文 Palindrome 关键词(50%)

代码模板

  • 初始化指针:left, right根据方向设置起点
  • 循环控制:whilefor控制移动(比如right扩展,left收缩)
  • 状态更新:维护当前窗口或配对状态,根据条件分类讨论
  • 结果记录:更新答案(相等时、满足条件时)
  • 边界处理:空数组、单元素、去重跳过等
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# 通用双指针框架(适用于数组/列表)
def two_pointers(arr):
    n = len(arr)
    if n == 0:
        return 0  # 或其他默认值

    # Step 1: 初始化指针
    left = 0                    # 左指针 / 慢指针
    # right = 0 或 n - 1,根据方向选择

    # Step 2: 根据类型选择遍历结构
    for right in range(n):      # 同向:快慢指针;滑动窗口
    # while left < right:       # 相向:对撞指针(常用于有序数组)
    # while left < n:           # 其他控制条件

        # Step 3: 扩展或移动右指针后,处理当前窗口/状态
        # ... 更新状态

        # Step 4: 判断是否需要收缩左指针(滑动窗口类)
        while left <= right and need_to_move_left(arr, left, right):
            # ... 更新或记录结果
            left += 1

        # 或:根据条件移动双指针(对撞类)
        # if condition:
        #     left += 1
        # else:
        #     right -= 1

    return result

例题

88. 合并两个有序数组

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""
📖描述:给你两个按 非递减顺序 排列的整数数组 `nums1` 和 `nums2`,另有两个整数 `m` 和 `n`,分别表示 `nums1` 和 `nums2` 中的元素数目。
    请你 合并 `nums2` 到 `nums1` 中,使合并后的数组同样按 非递减顺序 排列。
🧪样例:输入:nums1 = [1,2,3,0,0,0], m = 3, nums2 = [2,5,6], n = 3;输出:[1,2,2,3,5,6]
💡难点:从后往前操作以便直接覆盖。
"""

def merge(nums1: List[int], m: int, nums2: List[int], n: int) -> None:
    """
    Do not return anything, modify nums1 in-place instead.
    """
    # 逆向双指针,从后往前操作可以直接覆盖
    p1, p2 = m - 1, n - 1  # 同向,但是从后往前
    tail = m + n - 1  # 需要维护的状态:当前需要处理的索引
    while True:
        if p1 < 0 or p2 < 0:
            break
        if nums1[p1] <= nums2[p2]:
            nums1[tail] = nums2[p2]
            p2 -= 1
            tail -= 1
        else:
            nums1[tail] = nums1[p1]
            p1 -= 1
            tail -= 1
    # 由于比较,总会有一个数组先结束,对于后结束的一个数组:这里肯定是p2
    if p2 >= 0:
        nums1[: p2 + 1] = nums2[: p2 + 1]

def merge(nums1: List[int], nums2: List[int]) -> List[int]:
    """ 合并双指针,非原地操作。
    🧪样例:输入:nums1 = [1,2,3], nums2 = [2,5,6];输出:[1,2,2,3,5,6]
    """
    m, n = len(num1), len(nums2)
    new_list = []
    i, j = 0, 0
    # 合并的过程只能操作 i, j 的移动,不要去用 list1.pop(0) 之类的操作
    # 因为 pop(0) 是 O(n) 的时间复杂度,而且会改变序号
    while True:
        if i >= m or j >= n:
            break
        if nums[i] < nums[j]:
            new_list.append(nums[i])
            i += 1
        else:
            new_list.append(nums[j])
            j += 1
    # 合并剩下的数到 new_list 里
    while i < m:
        new_list.append(nums[i])
        i += 1
    while j < n:
        new_list.append(nums[j])
        j += 1
    return new_list

21. 合并两个有序链表

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Definition for singly-linked list.
# class ListNode:
#     def __init__(self, val=0, next=None):
#         self.val = val
#         self.next = next
def mergeTwoLists(list1: ListNode, list2: ListNode) -> ListNode:
    dummy = ListNode()  # 虚拟头结点,它的唯一作用就是提供一个起始点,让 p 可以不断向后连接节点
    p = dummy  # p 指向虚拟链表的末尾
    p1 = l1
    p2 = l2

    while p1 is not None and p2 is not None:
        # 比较 p1 和 p2 两个指针
        # 将值较小的的节点接到 p 指针
        if p1.val > p2.val:
            p.next = p2
            p2 = p2.next
        else:
            p.next = p1
            p1 = p1.next
        # p 指针不断前进
        p = p.next

    if p1 is not None:
        p.next = p1

    if p2 is not None:
        p.next = p2

    return dummy.next  # 注意:不是返回指针 p,而是返回链表的头部,也就是 dummy.next

虚拟头结点
代码中用到了一个链表的算法题中是很常见的「虚拟头结点」技巧,也就是 dummy 节点。
当遇到需要创造一条新链表的情况,可以使用虚拟头结点简化边界情况的处理。
如果不使用 dummy 虚拟节点,代码会复杂一些,需要额外处理指针 p 为空的情况。而有了 dummy 节点这个占位符,可以避免处理空指针的情况,降低代码的复杂性。


5. 最长回文子串

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
"""
📖描述:给你一个字符串 `s`,找到 `s` 中最长的 回文 子串。
🧪样例:输入s = "babad";输出"bab""aba"。输入:s = "cbbd";输出:"bb"。
💡重点:
- 需要同时考虑奇数和偶数长的回文串
- 中心扩散
- 这题还可以用动态规划解:
    - 状态定义:dp[i][j]表示s[i:j+1]是否为回文
    - 初始化:dp = [[False for _ in range(size)] for _ in range(size)]
    - 转移方程:dp[i][j] = dp[i-1][j-1] and s[i] == s[j]
"""
def longestPalindrome(s: str) -> str:
    n = len(s)
    if n <= 1:
        return s
    max_s, max_len = "", 0
    for i in range(n):
        if n - 1 - i < (max_len - 1) / 2:
            break  # 提前终止
        # 处理奇数长度的回文子串,以i为中心向两边移动
        left, right = i, i
        while True:
            if left < 0 or right > n - 1:
                break
            if s[left] == s[right]:
                left -= 1
                right += 1
            else:
                break  # 注意所有break的情况
        cur_len = right - left - 1
        if cur_len > max_len:
            max_s = s[left + 1 : right]
            max_len = cur_len
        # 处理偶数长度的回文子串
        left, right = i, i + 1
        while True:
            if left < 0 or right > n - 1:
                break
            if s[left] == s[right]:
                left -= 1
                right += 1
            else:
                break
        cur_len = right - left - 1
        if cur_len > max_len:
            max_s = s[left + 1 : right]
            max_len = cur_len
    return max_s

930. 和相同的二元子数组

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
"""
📖描述:给你一个二元数组 `nums`,和一个整数 `goal`,请你统计并返回有多少个和为 `goal` 的非空子数组。
🧪样例:
    输入:nums = [1,0,1,0,1], goal = 2
    输出:4
    解释:有 4 个满足题目要求的子数组:[1,0,1]、[1,0,1,0]、[0,1,0,1]、[1,0,1]
💡重点:
    1. 可以用前缀和 + 哈希表,类似两数之和
    2. 也可以用滑动窗口,因为元素都是非负的(只有0和1)
"""

# 方法1:前缀和 + 哈希表,时间 O(N),空间 O(N)
def numSubarraysWithSum(nums: List[int], goal: int) -> int:
    from collections import defaultdict
    prefix_sum = defaultdict(int)
    prefix_sum[0] = 1  # 前缀和为0的有1个(空前缀)
    cur_sum = 0
    count = 0
    for num in nums:
        cur_sum += num
        # 需要找之前的前缀和 = cur_sum - goal
        count += prefix_sum[cur_sum - goal]
        prefix_sum[cur_sum] += 1
    return count

# 方法2:滑动窗口,时间 O(N),空间 O(1)
# 由于元素非负,可以利用滑动窗口
# atMost(goal) 返回和 <= goal 的子数组个数
# 答案 = atMost(goal) - atMost(goal - 1)
def numSubarraysWithSum(nums: List[int], goal: int) -> int:
    def atMost(goal):
        if goal < 0:
            return 0
        left = 0
        cur_sum = 0
        count = 0
        for right in range(len(nums)):
            cur_sum += nums[right]
            while cur_sum > goal:
                cur_sum -= nums[left]
                left += 1
            count += right - left + 1  # 以 right 结尾的子数组个数
        return count

    return atMost(goal) - atMost(goal - 1)

滑动窗口

滑动窗口可以归为快慢双指针,一快一慢两个指针前后相随,中间的部分就是窗口。滑动窗口算法技巧主要用来解决子数组问题,比如让你寻找符合某个条件的最长/最短子数组。

与普通的快慢指针(嵌套循环,$O(N^2)$)不同的是,滑动窗口(队列)维护的元素只进入/移出一次(指针 left, right 只增不减),所以复杂度为$O(N)$。算法的重点在于判断是否要把 left 移动。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# 滑动窗口模板
def sliding_window(s: str):
    # 用合适的数据结构记录窗口中的数据,根据具体场景变通
    # 比如说,我想记录窗口中元素出现的次数,就用 map
    # 如果我想记录窗口中的元素和,就可以只用一个 int
    window = set()

    left, right = 0, 0
    while right < len(s):
        # c 是将移入窗口的字符
        c = s[right]
        window.add(c)
        # 增大窗口
        right += 1
        # 进行窗口内数据的一系列更新
        ...

        # 判断左侧窗口是否要收缩
        while left < right and window_needs_shrink(s, left, right):
            # 把 s[left] 移出窗口
            window.remove(s[left])
            # 缩小窗口
            left += 1
            # 进行窗口内数据的一系列更新
            ...

基于这个框架,遇到子串/子数组相关的题目,你只需要回答以下三个问题:

  1. 什么时候应该移动 right 扩大窗口?窗口加入字符时,应该更新哪些数据?
  2. 什么时候窗口应该暂停扩大,开始移动 left 缩小窗口?从窗口移出字符时,应该更新哪些数据?
  3. 什么时候应该更新结果?

例题

3. 无重复字符的最长子串

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
"""
📖描述:给定一个字符串 s ,请你找出其中不含有重复字符的 最长 子串 的长度。
🧪样例:
    输入: s = "abcabcbb"
    输出: 3
    解释: 因为无重复字符的最长子串是 "abc",所以其长度为 3。注意 "bca""cab" 也是正确答案。
💡重点:滑动窗口最重要的是指针只增不减
"""
# 错误用法
# 下面这种写法没有保证每个元素只处理一次(l递增,但r会回退),就是暴力的嵌套循环,复杂度为 O(N^2)
def lengthOfLongestSubstring(s: str) -> int:
    N = len(s)
    ans = 0
    for l in range(N):
        substr = {s[l]}
        for r in range(l + 1, N):
            if s[r] not in substr:
                substr.add(s[r])
            else:
                break
        ans = max(ans, len(substr))
    return ans

# 正确用法 1
# r只增不减:O(N), 23ms
def lengthOfLongestSubstring(s: str) -> int:
    N = len(s)
    right, ans = 0, 0
    for left in range(N):
        substr = s[left:right]
        while True and right < N:
            if s[right] in substr:
                break
            else:
                right += 1
                substr = s[left:right]
        ans = max(ans, len(substr))
    return ans

# 正确用法 2
# window维护每个字符出现的次数, 删除 s[left] 直至 win[s[right]] <= 1, O(N), 76ms
def lengthOfLongestSubstring(s: str) -> int:
    win, ans = dict(), 0
    left, right = 0, 0
    while right < len(s):
        win[s[right]] = win.get(s[right], 0) + 1
        # 滑动左指针直到win[s[right]]<=1
        while win[s[right]] > 1:
            win[s[left]] -= 1
            left += 1
        ans = max(ans, right - left + 1)
        right += 1
    return ans

567. 字符串的排列

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""
📖描述:给你两个字符串 s1 和 s2 ,写一个函数来判断 s2 是否包含 s1 的 排列。如果是,返回 true ;否则,返回 false 。
🧪样例:
    输入:s1 = "abb" s2 = "eidbabooo"
    输出:true
    解释:s2 包含 s1 的排列之一 ("bab").
💡重点:
"""

# 错误解法
# 暴力枚举,每次循环都重新计算Counter,时间复杂度 O(N^2)
def checkInclusion(s1: str, s2: str) -> bool:
    from collections import Counter

    N, M = len(s1), len(s2)
    ref = Counter(s1)
    # 窗长始终为N
    for i in range(M):
        substr = s2[i: i + N]
        win = Counter(substr)
        if win == ref:
            return True
    return False

# 正确解法
# 不需要每次都重新计算,只需要更新 left 和 right 的计数,时间复杂度 O(N)
def checkInclusion(s1: str, s2: str) -> bool:
    N, M = len(s1), len(s2)
    if N > M:
        return False

    ref = {}
    for c in s1:
        ref[c] = ref.get(c, 0) + 1

    window = {}
    for i in range(N):
        window[s2[i]] = window.get(s2[i], 0) + 1

    right = N
    while right < M:
        # print(window)
        if window == ref:
            return True
        new = s2[right]
        old = s2[right - N]
        # 维护滑动窗
        window[new] = window.get(new, 0) + 1
        window[old] = window.get(old, 0) - 1
        if window[old] == 0:
            del window[old]
        right += 1

    if window == ref:
        return True
    return False

查找

查找是最基础操作,其中最常用的是二分查找,即从有序数组array中直接寻找某个值query对应的index。一般解法:

  • 双指针:比较array[mid]query的大小(mid = low + (high-low)//2),从而更新左右指针lowhigh,终止条件:
    • (1) 找到了queryarray[mid] = query
    • (2) 左右指针相遇(low > high
  • 递归:分成左右两子数组,如果array[mid]不等于query则不断在左或者右子数组里面查找,直到找到了query或者子数组为空。

重点在于分类讨论,建议用双闭区间,仔细讨论array[low:mid], array[mid], array[mid+1:high+1]的情况,并注意三个数组是否为空。

算法复杂度

时间 $O(\log(N))$。每次只需要查一边,所以子问题数量为1。空间 $O(1)$。

使用场景

  • 当数组已经排好序 (30-40%是二分)
  • 当面试官要求你找一个比 $O(N)$ 更小的时间复杂度算法的时候(99%)
  • 找到数组中的一个分割位置,使得左半部分满足某个条件,右半部分不满足(100%)
  • 找到一个最大/最小的值使得某个条件被满足(90%)

代码模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def hash_search(arr, query):
    # 哈希查找,用于无序数组
    seen = {}
    for i, val in enumerate(arr):
        complement = query - val
        if complement in seen:
            return [seen[complement], i]  # 如两数之和
        seen[val] = i
    return -1

def binary_search(array, query):
    """ Two points. [low, high] will be splitted:
        (1) [low, mid - 1]
        (2) [mid]
        (3) [mid + 1, high]
    """
    low, high = 0, len(array) - 1  # 闭区间 [left, right]
    while low <= high:
        mid = low + (high - low) // 2  # 防溢出
        val = array[mid]
        # array[low:mid], array[mid], array[mid+1:high+1]
        if val == query:
            return mid
        if val < query:
            low = mid + 1
        else:
            high = mid - 1
    return None

def binary_search_recur(array, low, high, query):
    """ Recurrence. [low, high] will be splitted:
        (1) [low, mid - 1]
        (2) [mid]
        (3) [mid + 1, high]
    """
    if low > high:
        return -1
    mid = low + (high - low) // 2   # This mid will not break integer range
    if query < array[mid]:
        return binary_search_recur(array, low, mid - 1, query)  # Go search in the left subarray
    if query > array[mid]:
        return binary_search_recur(array, mid + 1, high, query)  # Go search in the right subarray
    return mid  # `array[mid] = query`, stop recurrence

例题

33. 搜索旋转排序数组

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
"""
📖描述:给定旋转后的数组 `nums` 和一个整数 `target`,如果 `nums` 中存在这个目标值 `target`,则返回它的下标,否则返回 `-1`。
🧪样例:输入:`nums = [2,3,4,5,6,7,0,1]`, `target = 0`;输出:`target`的下标为`6`。
💡重点:
1. 数组不是有序的,但是是局部有序的。有序的那端一定是最左边小于最右边,无序的那端一定是最左边大于最右边。
2. 目标是否在有序部分比较好判断`nums[left_] <= target and target < nums[right_]`,如果不满足则落在另一边。
<https://leetcode.cn/problems/search-in-rotated-sorted-array/solutions/2636954/javapython3cer-fen-cha-zhao-you-xu-de-ba-5g7e>
"""
def search(nums: List[int], target: int) -> int:
    low, high = 0, len(nums) - 1
    while low <= high:
        mid = low + (high - low) // 2
        val = nums[mid]
        # print(mid, nums[low:mid], nums[mid], nums[mid+1:high+1])
        if val == target:
            return mid
        if low < mid and nums[low] <= nums[mid - 1]:
            # 左边有序,先判断是否在左边
            if nums[low] <= target and target <= nums[mid - 1]:
                high = mid - 1
            else:
                low = mid + 1
        elif mid < high:
            # 右边有序,先判断是否在右边
            if nums[mid + 1] <= target and target <= nums[high]:
                low = mid + 1
            else:
                high = mid - 1
        else:
            return -1
    return -1

658. 找到 K 个最接近的元素

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
"""
📖描述:给定一个排序好的数组 `arr`,两个整数 `k` 和 `x`,从数组中找到最靠近 `x` 的 `k` 个数。返回的结果必须要是按升序排好的。
🧪样例:输入:`arr = [1,2,3,4,5]`, `k = 4`, `x = 3`;输出:`[1,2,3,4]`。
💡重点:
1. 反向思维,删除最边缘的`n - k`个,每次判断删最左边还是删最右边。
2. 返回结果要排好序,可以用双指针寻找最优子区间。
"""

def findClosestElements(arr: List[int], k: int, x: int) -> List[int]:
    # 排除法(双指针)
    N = len(arr)
    remove_nums = N - k
    left, right = 0, N - 1
    while remove_nums:
        # 注意:这里等于号的含义,题目中说,差值相等的时候取小的
        # 因此相等的时候,尽量缩小右边界
        if x - arr[left] <= arr[right] - x:
            right -= 1
        else:
            left += 1
        remove_nums -= 1
    return arr[left:left + k]

215. 数组中的第K个最大元素

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""
📖描述:
    给定整数数组 nums 和整数 k,请返回数组中第 k 个最大的元素。
    请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素。
    你必须设计并实现时间复杂度为 O(n) 的算法解决此问题。
🧪样例:
    输入: [3,2,1,5,6,4], k = 2
    输出: 5

    输入: [3,2,3,1,2,4,5,5,6], k = 4
    输出: 4
💡重点:
"""
class Solution:
    def partition(self, nums: List[int], left: int, right: int) -> int:
        """
        在子数组 [left, right] 中随机选择一个基准元素 pivot
        根据 pivot 重新排列子数组 [left, right]
        重新排列后,<= pivot 的元素都在 pivot 的左侧,>= pivot 的元素都在 pivot 的右侧
        返回 pivot 在重新排列后的 nums 中的下标
        特别地,如果子数组的所有元素都等于 pivot,我们会返回子数组的中心下标,避免退化
        """

        # 1. 在子数组 [left, right] 中随机选择一个基准元素 pivot
        i = random.randint(left, right)
        pivot = nums[i]
        # 把 pivot 与子数组第一个元素交换,避免 pivot 干扰后续划分,从而简化实现逻辑
        nums[i], nums[left] = nums[left], nums[i]

        # 2. 相向双指针遍历子数组 [left + 1, right]
        # 循环不变量:在循环过程中,子数组的数据分布始终如下图
        # [ pivot | <=pivot | 尚未遍历 | >=pivot ]
        #   ^                 ^     ^         ^
        #   left              i     j         right

        i, j = left + 1, right
        while True:
            while i <= j and nums[i] < pivot:
                i += 1
            # 此时 nums[i] >= pivot

            while i <= j and nums[j] > pivot:
                j -= 1
            # 此时 nums[j] <= pivot

            if i >= j:
                break

            # 维持循环不变量
            nums[i], nums[j] = nums[j], nums[i]
            i += 1
            j -= 1

        # 循环结束后
        # [ pivot | <=pivot | >=pivot ]
        #   ^             ^   ^     ^
        #   left          j   i     right

        # 3. 把 pivot 与 nums[j] 交换,完成划分(partition)
        # 为什么与 j 交换?
        # 如果与 i 交换,可能会出现 i = right + 1 的情况,已经下标越界了,无法交换
        # 另一个原因是如果 nums[i] > pivot,交换会导致一个大于 pivot 的数出现在子数组最左边,不是有效划分
        # 与 j 交换,即使 j = left,交换也不会出错
        nums[left], nums[j] = nums[j], nums[left]

        # 交换后
        # [ <=pivot | pivot | >=pivot ]
        #               ^
        #               j

        # 返回 pivot 的下标
        return j

    def findKthLargest(self, nums: list[int], k: int) -> int:
        n = len(nums)
        target_index = n - k  # 第 k 大元素在升序数组中的下标是 n - k
        left, right = 0, n - 1  # 闭区间
        while True:
            i = self.partition(nums, left, right)
            if i == target_index:
                # 找到第 k 大元素
                return nums[i]
            if i > target_index:
                # 第 k 大元素在 [left, i - 1] 中
                right = i - 1
            else:
                # 第 k 大元素在 [i + 1, right] 中
                left = i + 1

排序

算法复杂度

时间复杂度:

  • 快速排序:期望 $O(N\log(N))$
  • 归并排序:期望 $O(N\log(N))$

空间复杂度:

  • 快速排序:期望 $O(1)$
  • 归并排序:期望 $O(N)$

使用场景

  • 需要将数据排序后再处理(如二分查找、合并区间等)(90%)
  • 快速选择/Top K问题(快速排序的变种)(80%)
  • 逆序对数量统计(归并排序的变种)(90%)
  • 合并多个有序数组/链表 (归并排序思想)(80%)
  • 需要稳定排序时使用归并排序 (100%)

代码模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# 快速排序
def quick_sort(nums: List[int], left: int, right: int) -> None:
    """原地排序,平均时间 O(NlogN),最坏 O(N^2),空间 O(1)"""
    if left >= right:
        return
    # 随机选择 pivot 避免最坏情况
    pivot_idx = random.randint(left, right)
    nums[left], nums[pivot_idx] = nums[pivot_idx], nums[left]
    pivot = nums[left]

    i, j = left + 1, right
    while True:
        while i <= j and nums[i] < pivot:
            i += 1
        while i <= j and nums[j] > pivot:
            j -= 1
        if i >= j:
            break
        nums[i], nums[j] = nums[j], nums[i]
        i += 1
        j -= 1
    nums[left], nums[j] = nums[j], nums[left]

    quick_sort(nums, left, j - 1)
    quick_sort(nums, j + 1, right)

# 归并排序
def merge_sort(nums: List[int], left: int, right: int) -> None:
    """稳定排序,时间 O(NlogN),空间 O(N)"""
    if left >= right:
        return
    mid = left + (right - left) // 2
    merge_sort(nums, left, mid)
    merge_sort(nums, mid + 1, right)
    merge(nums, left, mid, right)

def merge(nums: List[int], left: int, mid: int, right: int) -> None:
    temp = []
    i, j = left, mid + 1
    while i <= mid and j <= right:
        if nums[i] <= nums[j]:
            temp.append(nums[i])
            i += 1
        else:
            temp.append(nums[j])
            j += 1
    while i <= mid:
        temp.append(nums[i])
        i += 1
    while j <= right:
        temp.append(nums[j])
        j += 1
    nums[left:right+1] = temp

# 堆排序
def heap_sort(nums: List[int]) -> None:
    """原地排序,时间 O(NlogN),空间 O(1)"""
    n = len(nums)
    # 建堆
    for i in range(n // 2 - 1, -1, -1):
        heapify(nums, n, i)
    # 逐个取出堆顶
    for i in range(n - 1, 0, -1):
        nums[0], nums[i] = nums[i], nums[0]
        heapify(nums, i, 0)

def heapify(nums: List[int], n: int, i: int) -> None:
    largest = i
    left, right = 2 * i + 1, 2 * i + 2
    if left < n and nums[left] > nums[largest]:
        largest = left
    if right < n and nums[right] > nums[largest]:
        largest = right
    if largest != i:
        nums[i], nums[largest] = nums[largest], nums[i]
        heapify(nums, n, largest)

例题

148. 排序链表

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""
📖描述:给你链表的头结点 `head`,请将其按升序排列并返回排序后的链表。
🧪样例:
    输入:head = [4,2,1,3]
    输出:[1,2,3,4]
💡重点:
    1. 要求 O(NlogN) 时间复杂度和常数级空间复杂度
    2. 使用归并排序:找到中点 -> 递归排序 -> 合并
"""
def sortList(head: ListNode) -> ListNode:
    if not head or not head.next:
        return head

    # 找到中点(快慢指针)
    slow, fast = head, head.next
    while fast and fast.next:
        slow = slow.next
        fast = fast.next.next

    # 断开
    mid = slow.next
    slow.next = None

    # 递归排序
    left = sortList(head)
    right = sortList(mid)

    # 合并两个有序链表
    dummy = ListNode(0)
    p = dummy
    while left and right:
        if left.val < right.val:
            p.next = left
            left = left.next
        else:
            p.next = right
            right = right.next
        p = p.next
    p.next = left if left else right
    return dummy.next

剑指 Offer 51. 数组中的逆序对

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""
📖描述:在数组中的两个数字,如果前面一个数字大于后面的数字,则这两个数字组成一个逆序对。
    输入一个数组,求出这个数组中的逆序对的总数。
🧪样例:
    输入:[7,5,6,4]
    输出:5
    解释:逆序对 (7,5), (7,6), (7,4), (5,4), (6,4)
💡重点:
    1. 利用归并排序的过程统计逆序对
    2. 当左边元素 > 右边元素时,左边剩余元素都与右边当前元素构成逆序对
"""
def reversePairs(nums: List[int]) -> int:
    def merge_sort(nums, left, right):
        if left >= right:
            return 0
        mid = left + (right - left) // 2
        count = merge_sort(nums, left, mid) + merge_sort(nums, mid + 1, right)

        # 合并并统计逆序对
        temp = []
        i, j = left, mid + 1
        while i <= mid and j <= right:
            if nums[i] <= nums[j]:
                temp.append(nums[i])
                i += 1
            else:
                temp.append(nums[j])
                count += mid - i + 1  # 关键:左边剩余元素都与 nums[j] 构成逆序对
                j += 1
        while i <= mid:
            temp.append(nums[i])
            i += 1
        while j <= right:
            temp.append(nums[j])
            j += 1
        nums[left:right+1] = temp
        return count

    return merge_sort(nums, 0, len(nums) - 1)

动态规划

动态规划四要素:

  • 状态 (State) – 递归的定义
  • 方程 (Function) – 递归的拆解
  • 初始化 (Initialization) – 递归的出口
  • 答案 (Answer) – 递归的调用

常见的动态规划类型:

  • 背包型:给出 n 个物品及其大小,能否挑选出一些物品装满大小为 m 的背包
    • 通常用二维的状态数组dp[i][j],表示???
  • 区间型:题目中有 subarray / substring 的信息,通常大区间依赖小区间
    • dp[i][j] 表示数组/字符串中 i, j 这一段区间的最优值/可行性/方案总数
  • 匹配型:通常两个字符串的匹配值依赖于两个字符串前缀的匹配值
    • dp[i][j] 表示第一个字符串的前 i 个字符与第二个字符串的前 j 个字符的状态(max/min/sum/or)
  • 接龙型:给一个接龙规则,求最长的龙有多长
    • dp[i] 表示以坐标为 i 的元素结尾的最长龙的长度

算法复杂度

时间复杂度:O(状态总数 * 每个状态的处理耗费)

空间复杂度:O(状态总数)

使用场景

  • 求方案总数(90%)
  • 求最值(80%)
  • 求可行性(80%)

不适用的场景:

  • 找所有具体的方案(准确率 99%)
  • 输入数据无序(除了背包问题外,准确率 60%~70%)
  • 暴力算法已经是多项式时间复杂度(准确率 80%)

代码模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# 一维动态规划模板
def dp_1d(n):
    # 1. 定义状态数组
    dp = [0] * (n + 1)
    # 2. 初始化
    dp[0] = base_case
    # 3. 状态转移
    for i in range(1, n + 1):
        dp[i] = 状态转移方程
    # 4. 返回答案
    return dp[n]

# 二维动态规划模板
def dp_2d(m, n):
    # 1. 定义状态数组
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    # 2. 初始化
    for i in range(m + 1):
        dp[i][0] = ...
    for j in range(n + 1):
        dp[0][j] = ...
    # 3. 状态转移
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            dp[i][j] = 状态转移方程
    # 4. 返回答案
    return dp[m][n]

# 背包问题模板(0-1背包)
def knapsack(n, W, weights, values):
    # dp[i][j] 表示前i个物品,容量为j时的最大价值
    dp = [[0] * (W + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        for j in range(1, W + 1):
            if j >= weights[i-1]:
                dp[i][j] = max(dp[i-1][j], dp[i-1][j-weights[i-1]] + values[i-1])
            else:
                dp[i][j] = dp[i-1][j]
    return dp[n][W]

# 空间优化版本(一维数组)
def knapsack_optimized(n, W, weights, values):
    dp = [0] * (W + 1)
    for i in range(n):
        for j in range(W, weights[i] - 1, -1):  # 逆序遍历
            dp[j] = max(dp[j], dp[j - weights[i]] + values[i])
    return dp[W]

例题

322. 零钱兑换

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
"""
📖描述:给你一个整数数组 `coins`,表示不同面额的硬币;以及一个整数 `amount`,表示总金额。
    计算并返回可以凑成总金额所需的**最少**的硬币个数。如果没有任何一种硬币组合能组成总金额,返回 `-1`。
🧪样例:
    输入:coins = [1, 2, 5], amount = 11
    输出:3
    解释:11 = 5 + 5 + 1
💡重点:
    1. 完全背包问题,每种硬币可以使用多次
    2. dp[i] 表示凑成金额 i 需要的最少硬币数
"""
def coinChange(coins: List[int], amount: int) -> int:
    dp = [float('inf')] * (amount + 1)
    dp[0] = 0
    for i in range(1, amount + 1):
        for coin in coins:
            if i >= coin:
                dp[i] = min(dp[i], dp[i - coin] + 1)
    return dp[amount] if dp[amount] != float('inf') else -1

300. 最长递增子序列

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
"""
📖描述:给你一个整数数组 `nums`,找到其中最长严格递增子序列的长度。
🧪样例:
    输入:nums = [10,9,2,5,3,7,101,18]
    输出:4
    解释:最长递增子序列是 [2,3,7,101],长度为 4
💡重点:
    1. dp[i] 表示以 nums[i] 结尾的最长递增子序列长度
    2. 时间 O(N^2),可以用二分优化到 O(NlogN)
"""
def lengthOfLIS(nums: List[int]) -> int:
    n = len(nums)
    dp = [1] * n
    for i in range(n):
        for j in range(i):
            if nums[i] > nums[j]:
                dp[i] = max(dp[i], dp[j] + 1)
    return max(dp)

# 二分优化版本 O(NlogN)
def lengthOfLIS(nums: List[int]) -> int:
    tails = []  # tails[i] 表示长度为 i+1 的子序列的最小末尾
    for num in nums:
        # 二分查找第一个 >= num 的位置
        left, right = 0, len(tails)
        while left < right:
            mid = left + (right - left) // 2
            if tails[mid] < num:
                left = mid + 1
            else:
                right = mid
        if left == len(tails):
            tails.append(num)
        else:
            tails[left] = num
    return len(tails)

1143. 最长公共子序列

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
"""
📖描述:给定两个字符串 `text1` 和 `text2`,返回这两个字符串的最长公共子序列的长度。
🧪样例:
    输入:text1 = "abcde", text2 = "ace"
    输出:3
    解释:最长公共子序列是 "ace"
💡重点:
    1. dp[i][j] 表示 text1 前 i 个字符和 text2 前 j 个字符的 LCS 长度
    2. 匹配型动态规划
"""
def longestCommonSubsequence(text1: str, text2: str) -> int:
    m, n = len(text1), len(text2)
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if text1[i-1] == text2[j-1]:
                dp[i][j] = dp[i-1][j-1] + 1
            else:
                dp[i][j] = max(dp[i-1][j], dp[i][j-1])
    return dp[m][n]

贪心算法

贪心算法是一种在每一步选择中都采取在当前状态下最好/最优的选择,从而希望导致结果是全局最好/最优的算法。

使用场景

  • 区间调度问题(按结束时间排序)(90%)
  • 跳跃游戏类问题 (80%)
  • 分发糖果/分配问题 (70%)
  • 股票买卖问题(只能买卖一次或多次)(80%)
  • Huffman编码 (100%)

例题

55. 跳跃游戏

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
"""
📖描述:给你一个非负整数数组 `nums`,你最初位于数组的第一个下标。数组中的每个元素代表你在该位置可以跳跃的最大长度。
    判断你是否能够到达最后一个下标。
🧪样例:
    输入:nums = [2,3,1,1,4]
    输出:true
    解释:可以先跳 1 步,从下标 0 到达下标 1,然后再从下标 1 跳 3 步到达最后一个下标。
💡重点:
    1. 贪心维护最远可达位置
    2. 如果当前位置超过了最远可达位置,则无法继续
"""
def canJump(nums: List[int]) -> bool:
    max_reach = 0
    for i, jump in enumerate(nums):
        if i > max_reach:  # 当前位置不可达
            return False
        max_reach = max(max_reach, i + jump)
        if max_reach >= len(nums) - 1:
            return True
    return True

435. 无重叠区间

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
"""
📖描述:给定一个区间的集合,找到需要移除区间的最小数量,使剩余区间互不重叠。
🧪样例:
    输入:intervals = [[1,2],[2,3],[3,4],[1,3]]
    输出:1
    解释:移除 [1,3] 后,剩下的区间没有重叠。
💡重点:
    1. 按结束时间排序,贪心选择结束最早的区间
    2. 等价于求最多能保留多少个不重叠区间
"""
def eraseOverlapIntervals(intervals: List[List[int]]) -> int:
    if not intervals:
        return 0
    intervals.sort(key=lambda x: x[1])  # 按结束时间排序
    count = 1  # 保留的区间数
    end = intervals[0][1]
    for i in range(1, len(intervals)):
        if intervals[i][0] >= end:  # 不重叠
            count += 1
            end = intervals[i][1]
    return len(intervals) - count

455. 分发饼干

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
"""
📖描述:假设你是一位很棒的家长,想要给你的孩子们一些小饼干。但是,每个孩子最多只能给一块饼干。
    对每个孩子 `i`,都有一个胃口值 `g[i]`,每块饼干 `j`,都有一个尺寸 `s[j]`。
    只有当 `s[j] >= g[i]` 时,我们才可以将这个饼干 `j` 分配给孩子 `i`。
    目标是尽可能满足越多数量的孩子。
🧪样例:
    输入:g = [1,2,3], s = [1,1]
    输出:1
    解释:你有三个孩子和两块小饼干,3 个孩子的胃口值分别是:1, 2, 3。虽然你有两块小饼干,但只能满足胃口值为 1 的孩子。
💡重点:
    1. 排序后双指针贪心匹配
    2. 小饼干优先满足小胃口的孩子
"""
def findContentChildren(g: List[int], s: List[int]) -> int:
    g.sort()
    s.sort()
    i, j = 0, 0
    while i < len(g) and j < len(s):
        if s[j] >= g[i]:
            i += 1  # 满足一个孩子
        j += 1  # 尝试下一块饼干
    return i

宽度优先搜索 BFS

算法复杂度

时间复杂度:$O(n + m)$, n 是点数, m 是边数

空间复杂度:$O(n)$

使用场景

  • 拓扑排序(100%)
  • 出现连通块的关键词(100%)
  • 分层遍历(100%)
  • 简单图最短路径(100%)
  • 给定一个变换规则,从初始状态变到终止状态最少几步(100%)

代码模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from collections import deque

# 基本BFS模板
def bfs(start, target):
    queue = deque([start])
    visited = set([start])
    step = 0

    while queue:
        size = len(queue)
        for _ in range(size):  # 分层遍历
            cur = queue.popleft()
            if cur == target:
                return step
            for next_node in get_neighbors(cur):
                if next_node not in visited:
                    visited.add(next_node)
                    queue.append(next_node)
        step += 1
    return -1

# 网格BFS模板
def bfs_grid(grid, start, end):
    m, n = len(grid), len(grid[0])
    queue = deque([(start[0], start[1])])
    visited = set([(start[0], start[1])])
    directions = [(0, 1), (0, -1), (1, 0), (-1, 0)]
    step = 0

    while queue:
        size = len(queue)
        for _ in range(size):
            x, y = queue.popleft()
            if (x, y) == end:
                return step
            for dx, dy in directions:
                nx, ny = x + dx, y + dy
                if 0 <= nx < m and 0 <= ny < n and (nx, ny) not in visited and grid[nx][ny] != '#':
                    visited.add((nx, ny))
                    queue.append((nx, ny))
        step += 1
    return -1

例题

102. 二叉树的层序遍历

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
"""
📖描述:给你二叉树的根节点 `root`,返回其节点值的层序遍历(即逐层地,从左到右访问所有节点)。
🧪样例:
    输入:root = [3,9,20,null,null,15,7]
    输出:[[3],[9,20],[15,7]]
💡重点:
    1. 使用队列进行 BFS
    2. 需要记录每层的节点数量
"""
def levelOrder(root: TreeNode) -> List[List[int]]:
    if not root:
        return []
    from collections import deque
    queue = deque([root])
    result = []
    while queue:
        level = []
        size = len(queue)
        for _ in range(size):
            node = queue.popleft()
            level.append(node.val)
            if node.left:
                queue.append(node.left)
            if node.right:
                queue.append(node.right)
        result.append(level)
    return result

127. 单词接龙

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
"""
📖描述:给定两个单词 `beginWord` 和 `endWord`,以及一个字典 `wordList`,
    找到从 `beginWord` 到 `endWord` 的最短转换序列的长度。
    转换规则:每次只能改变一个字母。
🧪样例:
    输入:beginWord = "hit", endWord = "cog", wordList = ["hot","dot","dog","lot","log","cog"]
    输出:5
    解释:"hit" -> "hot" -> "dot" -> "dog" -> "cog"
💡重点:
    1. BFS求最短路径
    2. 双向BFS可以优化效率
"""
def ladderLength(beginWord: str, endWord: str, wordList: List[str]) -> int:
    if endWord not in wordList:
        return 0

    word_set = set(wordList)
    from collections import deque
    queue = deque([(beginWord, 1)])
    visited = set([beginWord])

    while queue:
        word, level = queue.popleft()
        if word == endWord:
            return level
        # 尝试所有可能的变换
        for i in range(len(word)):
            for c in 'abcdefghijklmnopqrstuvwxyz':
                new_word = word[:i] + c + word[i+1:]
                if new_word in word_set and new_word not in visited:
                    visited.add(new_word)
                    queue.append((new_word, level + 1))
    return 0

200. 岛屿数量

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""
📖描述:给你一个由 '1'(陆地)和 '0'(水)组成的的二维网格,请你计算网格中岛屿的数量。
🧪样例:
    输入:grid = [
      ["1","1","1","1","0"],
      ["1","1","0","1","0"],
      ["1","1","0","0","0"],
      ["0","0","0","0","0"]
    ]
    输出:1
💡重点:
    1. 遍历网格,遇到 '1' 就 BFS/DFS 标记整个岛屿
    2. 标记过的格子改为 '0' 避免重复访问
"""
def numIslands(grid: List[List[str]]) -> int:
    if not grid:
        return 0
    m, n = len(grid), len(grid[0])
    count = 0

    def bfs(i, j):
        from collections import deque
        queue = deque([(i, j)])
        grid[i][j] = '0'
        directions = [(0, 1), (0, -1), (1, 0), (-1, 0)]
        while queue:
            x, y = queue.popleft()
            for dx, dy in directions:
                nx, ny = x + dx, y + dy
                if 0 <= nx < m and 0 <= ny < n and grid[nx][ny] == '1':
                    grid[nx][ny] = '0'
                    queue.append((nx, ny))

    for i in range(m):
        for j in range(n):
            if grid[i][j] == '1':
                bfs(i, j)
                count += 1
    return count

深度优先搜索 DFS

DFS使用递归或栈实现,沿着一条路径走到底再回溯,常用于遍历所有方案。

算法复杂度

时间复杂度:O(方案个数 * 构造每个方案的时间)

  • 树的遍历: $O(N)$
  • 排列问题: $O(N! * N)$
  • 组合问题: $O(2^N * N)$

使用场景

  • 找满足某个条件的所有方案 (99%)
  • 二叉树 Binary Tree 的问题 (90%)
  • 组合问题(95%)
    • 问题模型:求出所有满足条件的”组合”
    • 判断条件:组合中的元素是顺序无关的
  • 排列问题 (95%)
    • 问题模型:求出所有满足条件的”排列”
    • 判断条件:组合中的元素是顺序”相关”的。

不要用 DFS 的场景:

  • 连通块问题(一定要用 BFS,否则 StackOverflow)
  • 拓扑排序(一定要用 BFS,否则 StackOverflow)
  • 一切 BFS 可以解决的问题

代码模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# 基本DFS模板(回溯)
result = []
def dfs(参数列表):
    if 递归出口:
        result.append(当前方案)
        return
    for 所有的拆解可能性:
        做选择
        dfs(参数列表)
        撤销选择

# 组合问题模板
def combine(n: int, k: int) -> List[List[int]]:
    result = []
    def backtrack(start, path):
        if len(path) == k:
            result.append(path[:])
            return
        for i in range(start, n + 1):
            path.append(i)
            backtrack(i + 1, path)
            path.pop()
    backtrack(1, [])
    return result

# 排列问题模板
def permute(nums: List[int]) -> List[List[int]]:
    result = []
    def backtrack(path, used):
        if len(path) == len(nums):
            result.append(path[:])
            return
        for i in range(len(nums)):
            if used[i]:
                continue
            used[i] = True
            path.append(nums[i])
            backtrack(path, used)
            path.pop()
            used[i] = False
    backtrack([], [False] * len(nums))
    return result

例题

46. 全排列

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
"""
📖描述:给定一个不含重复数字的数组 `nums`,返回其所有可能的全排列。
🧪样例:
    输入:nums = [1,2,3]
    输出:[[1,2,3],[1,3,2],[2,1,3],[2,3,1],[3,1,2],[3,2,1]]
💡重点:
    1. 排列问题,需要使用 used 数组记录已使用的元素
    2. 回溯时记得撤销选择
"""
def permute(nums: List[int]) -> List[List[int]]:
    result = []

    def backtrack(path, used):
        if len(path) == len(nums):
            result.append(path[:])
            return
        for i in range(len(nums)):
            if used[i]:
                continue
            used[i] = True
            path.append(nums[i])
            backtrack(path, used)
            path.pop()
            used[i] = False

    backtrack([], [False] * len(nums))
    return result

78. 子集

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
"""
📖描述:给你一个整数数组 `nums`,数组中的元素互不相同。返回该数组所有可能的子集(幂集)。
🧪样例:
    输入:nums = [1,2,3]
    输出:[[],[1],[2],[1,2],[3],[1,3],[2,3],[1,2,3]]
💡重点:
    1. 组合问题,每个元素选或不选
    2. 每个节点都是答案的一部分
"""
def subsets(nums: List[int]) -> List[List[int]]:
    result = []

    def backtrack(start, path):
        result.append(path[:])  # 每个节点都是一个子集
        for i in range(start, len(nums)):
            path.append(nums[i])
            backtrack(i + 1, path)
            path.pop()

    backtrack(0, [])
    return result

22. 括号生成

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
"""
📖描述:数字 `n` 代表生成括号的对数,请你设计一个函数,用于能够生成所有可能的并且有效的括号组合。
🧪样例:
    输入:n = 3
    输出:["((()))","(()())","(())()","()(())","()()()"]
💡重点:
    1. 左括号数量必须小于 n,右括号数量必须小于左括号数量
    2. 剪枝优化:不满足条件提前终止
"""
def generateParenthesis(n: int) -> List[str]:
    result = []

    def backtrack(s, left, right):
        if len(s) == 2 * n:
            result.append(s)
            return
        if left < n:
            backtrack(s + '(', left + 1, right)
        if right < left:
            backtrack(s + ')', left, right + 1)

    backtrack('', 0, 0)
    return result

79. 单词搜索

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
"""
📖描述:给定一个 `m x n` 二维字符网格 `board` 和一个字符串单词 `word`。
    如果 `word` 存在于网格中,返回 `true`;否则,返回 `false`。
🧪样例:
    输入:board = [["A","B","C","E"],["S","F","C","S"],["A","D","E","E"]], word = "ABCCED"
    输出:true
💡重点:
    1. DFS + 回溯,注意标记已访问的格子
    2. 找到一条路径即可返回 true
"""
def exist(board: List[List[str]], word: str) -> bool:
    m, n = len(board), len(board[0])

    def dfs(i, j, k):
        if k == len(word):
            return True
        if i < 0 or i >= m or j < 0 or j >= n or board[i][j] != word[k]:
            return False

        temp = board[i][j]
        board[i][j] = '#'  # 标记已访问
        found = (dfs(i+1, j, k+1) or dfs(i-1, j, k+1) or
                 dfs(i, j+1, k+1) or dfs(i, j-1, k+1))
        board[i][j] = temp  # 恢复
        return found

    for i in range(m):
        for j in range(n):
            if dfs(i, j, 0):
                return True
    return False

参考资源

This post is licensed under CC BY 4.0 by the author.