归并排序的优化

归并排序的优化

参考文章:

第 2.5 节 归并排序的优化

重写归并排序

在之前「归并排序」的写法中,一直依靠「python中列表分割」,所以可以很快并很简单的实现归并排序

def merge_sort(arr):
    mid = len(arr) >> 1
    if len(arr) <= 1:
        return arr
    # 左边数组
    left = merge_sort(arr[:mid])
    # 右边数组
    right = merge_sort(arr[mid:])
    res = []
    cur_l, cur_r = 0, 0
    while cur_l < len(left) and cur_r < len(right):
        if left[cur_l] <= right[cur_r]:
            res.append(left[cur_l])
            cur_l += 1
        else:
            res.append(right[cur_r])
            cur_r += 1
    res += left[cur_l:]
    res += right[cur_r:]
    return res

上方的写法实实在在的进行了数组分割并重建,可以使我们快速的理解「归并排序」,并且在保持不改动原数组的情况下,进行归并排序

缺点就是,浪费了许多空间,每一次递归都新创建了 leftright 两个数组

接下来,重写「归并排序」,在原数组上进行归并排序

  • 写出最顶层函数

因为是在原数组上进行“逻辑分割”, 所以需要传入索引值

left: 数组的起始元素索引

right:数组的结束元素索引

def main():
    nums = [1,3,7,2,0]
  	length = len(nums)
    merge_sort(nums, 0, length - 1)
  • 写递归函数

不同于之前的写法,这里我们将写两个函数

merge_sort: 归并排序的主函数,主要用于“分割”数组

merge_two_sorted_array: 合并两个有序数组

def merge_sort(arr, left, right):
    if left >= right:
        return 
    mid = (left + right) >> 1
    merge_sort(arr, left, mid)
    merge_sort(arr, mid + 1, right)
    merge_two_sorted_array(arr, left, mid, right)

注意1:

if left >= right:
    return 

递归的结束条件,举例说明

[1,2] 这个数组进行归并

left = 0, right = 1

mid计算后得 0

进入下一层递归merge_sort(arr, left, mid)

此时,left = 0 , mid = 0,表示只有一个元素不可再分,所以应该直接返回

注意2:

merge_sort(arr, left, mid) ,逻辑分割原数组

注意3:

merge_sort(arr, left, mid)
merge_sort(arr, mid + 1, right)
merge_two_sorted_array(arr, left, mid, right)

在执行完左右两次 merge_sort之后,[left, mid] 和 [mid+1, right] 分别有序,所以调用 merge_two_sorted_array 将[left, right]区间内的元素进行排序

  • 编写merge_two_sorted_array方法
def merge_two_sorted_array(arr, left, mid, right):
    # 左数组 arr[left:mid]
    # 右数组 arr[mid:right]
    # 将这个需要合并的区间复制一份出来,用于比较
    temp = arr[left:right + 1]
    # p1, p2 分别是左数组和右数组的起始位置索引值
    p1, p2 = 0, mid - left + 1
    length = len(temp)
    for i in range(length):
        if p1 > mid - left:
            # 左边已走完
            arr[i + left] = temp[p2]
            p2 += 1
        elif p2 > length - 1:
            # 右边已走完
            arr[i + left] = temp[p1]
            p1 += 1
        elif temp[p1] <= temp[p2]:
            arr[i + left] = temp[p1]
            p1 += 1
        elif temp[p1] > temp[p2]:
            arr[i + left] = temp[p2]
            p2 += 1

一点一点解释上面merge_two_sorted_array的代码

注意1:temp = arr[left:right + 1]

将原数组[left, right]区间复制出来,其重要就是方便比较。用temp数组比较,然后将正确的值赋值到原数组对应的位置

注意2:p1, p2 = 0, mid - left + 1

p1, p2 是两个有序数组的起始位置索引,p1 和 p2的索引是对于 temp 来说的

再解释下为什么p1 p2 会是 0 , mid - left + 1

首先在原数组中有

左数组 arr[left:mid], 区间为 [left, mid - 1]

右数组 arr[mid:right], 区间为 [mid, right]

所以,此时左数组的起始元素位置索引是 left, 而右数组的是 mid

因为 此区间被赋值到 临时数组temp上,所以对于临时数组而言:

左数组的起始元素位置索引是 left - left ,即 0

右数组的起始元素位置索引是 mid - left + 1

循环内的代码

    for i in range(length):
        if p1 > mid - left:
            # 左边已走完
            arr[i + left] = temp[p2]
            p2 += 1
        elif p2 > length - 1:
            # 右边已走完
            arr[i + left] = temp[p1]
            p1 += 1
        elif temp[p1] <= temp[p2]:
            arr[i + left] = temp[p1]
            p1 += 1
        elif temp[p1] > temp[p2]:
            arr[i + left] = temp[p2]
            p2 += 1

前两个判断: 判断左右两数组中,其中一个走完后,所应该做的操作

后两个判断:判断大小关系

完整代码:

def merge_two_sorted_array(arr, left, mid, right):
    # arr[left:mid]
    # arr[mid:right]
    p1, p2 = 0, mid - left + 1
    # 将这个需要合并的区间复制一份出来,用于比较
    temp = arr[left:right + 1]
    length = len(temp)
    for i in range(length):
        if p1 > mid - left:
            # 左边已走完
            arr[i + left] = temp[p2]
            p2 += 1
        elif p2 > length - 1:
            # 右边已走完
            arr[i + left] = temp[p1]
            p1 += 1
        elif temp[p1] <= temp[p2]:
            arr[i + left] = temp[p1]
            p1 += 1
        elif temp[p1] > temp[p2]:
            arr[i + left] = temp[p2]
            p2 += 1


def merge_sort(arr, left, right):
    if left >= right:
        return
    mid = (left + right) >> 1
    merge_sort(arr, left, mid)
    merge_sort(arr, mid + 1, right)
    merge_two_sorted_array(arr, left, mid, right)


def main():
	arr = [1,0,6]
    merge_sort(arr, 0, len(arr)-1)
    # merge_sort(arr)
    print(arr)

三个优化方向

  • 如果两个数组,直接拼起来就是有序的,就无须 merge。即当 arr[mid] <= arr[mid + 1] 的时候是不用 merge 的;
  • 因为「插入排序」在小规模的排序任务上表现出色,所以我们可以在小区间里使用「插入排序」;
  • 我们每次做归并的时候,都 new 了辅助的空间,用完之后就丢弃了。事实上,我们可以全程使用 1 个和待排序数组一样长度的数组作为辅助归并两个排序数组的临时空间,这样就避免了频繁 new 和 delete 数组空间的操作。

优化1:如果两个数组有序则无需归并

def merge_sort(arr, left, right):
    if left >= right:
        return
    mid = left + (right - left) // 2
    merge_sort(arr, left, mid)
    merge_sort(arr, mid + 1, right)
    # 优化1
    if arr[mid] <= arr[mid + 1]:
        return
    merge_of_two_sorted_array(arr, left, mid, right)

优化2:小区间使用「插入排序」

定义一个新函数 insert_sort_for_merge 在原数组[left, right]区间内通过「插入排序」变得有序

def insert_sort_for_merge_1(arr, left, right):
    """
    逐个向前交换的插入排序
    """
    # n = right - left + 1
    for i in range(left + 1, right + 1):
        for j in range(i, left, -1):  # 这里是 left
            if arr[j - 1] > arr[j]:
                arr[j], arr[j - 1] = arr[j - 1], arr[j]
            else:
                break

关于「插入排序」还有一种「多次赋值」法更加高效

def insert_sort_for_merge_2(arr, left, right):
    """
    多次赋值
    """
    for i in range(left + 1, right + 1):
        # 临时保存
        temp = arr[i]
        # 表示当前元素的前一个位置
        p = i - 1
        while p >= left and arr[p] > temp:
            # 将p所在位置元素向后移动一位,留出插入位置
            arr[p+1] = arr[p]
            # ”指针“ 左移
            p -= 1
        # 退出循环时,空出了一个位置,这个位置可以将 我们之前临时保存的值插入
        arr[p + 1] = temp

「插入排序」在小规模的排序任务上表现出色,所以我们这里让长度小于 15 的区间使用插入排序

def merge_sort(arr, left, right):
    # 优化2
    if right - left <= 15:
        insert_sort_for_merge_2(arr, left, right)
        return
    
    mid = left + (right - left) // 2  # 这是一个陷阱
    merge_sort(arr, left, mid)
    merge_sort(arr, mid + 1, right)
    # 优化1
    if arr[mid] <= arr[mid + 1]:
        return
    merge_of_two_sorted_array(arr, left, mid, right)

优化3:全局使用一个临时数组用于归并

def main(arr):
    # 优化3: 定义一个同等长度的全局辅助数组
    global nums_for_compare
    nums_for_compare = list(range(len(nums)))
    merge_sort(nums, 0, len(nums) - 1)

之前我们在 merge_of_two_sorted_array 函数中,有不断的新建 temp数组用于比较,此时我们可以替换成这个全局数组 nums_for_compare,同样使用left 和 right 限制其区间

def merge_two_sorted_array(arr, left, mid, right):
    # 因为nums_for_compare 默认是空的,所以我们需要在对应索引位置赋值
    for i in range(left, right + 1):
        nums_for_compare[i] = arr[i]
    p1 = left # 左边数组起始位置索引
    p2 = mid + 1 # 右边数组起始位置索引
    for k in range(left, right + 1):
        if p1 >= mid + 1:
            # 左边走完了
            arr[k] = nums_for_compare[p2]
            p2 += 1
        elif p2 >= lengh - 1:
            # 右边走完了
            arr[k] = nums_for_compare[p1]
            p1 += 1
       	elif nums_for_compare[p1] <= nums_for_compare[p2]:
            arr[k] = nums_for_compare[p1]
            p1 += 1
		else:
            arr[k] = nums_for_compare[p2]
            p2 += 1

完整代码

def insert_sort_for_merge_2(arr, left, right):
    """
    多次赋值
    """
    for i in range(left + 1, right + 1):
        # 临时保存
        temp = arr[i]
        p = i - 1
        while p >= left and arr[p] > temp:
            # if arr[p] > temp:
            arr[p + 1] = arr[p]
            # ”指针“ 左移
            p -= 1
        # 退出循环时,空出了一个位置,这个位置可以将 我们之前临时保存的值插入
        arr[p + 1] = temp


def merge_two_sorted_array(arr, left, mid, right):
    # 因为nums_for_compare 默认是空的,所以我们需要在对应索引位置赋值
    for i in range(left, right + 1):
        nums_for_compare[i] = arr[i]
    p1 = left  # 左边数组起始位置索引
    p2 = mid + 1  # 右边数组起始位置索引
    for k in range(left, right + 1):
        if p1 >= mid + 1:
            # 左边走完了
            arr[k] = nums_for_compare[p2]
            p2 += 1
        elif p2 > right:
            # 右边走完了
            arr[k] = nums_for_compare[p1]
            p1 += 1
        elif nums_for_compare[p1] < nums_for_compare[p2]:
            arr[k] = nums_for_compare[p1]
            p1 += 1
        else:
            arr[k] = nums_for_compare[p2]
            p2 += 1


def merge_sort(arr, left, right):
    # 优化2
    if right - left <= 15:
        insert_sort_for_merge_2(arr, left, right)
        return

    mid = left + (right - left) // 2  # 这是一个陷阱
    merge_sort(arr, left, mid)
    merge_sort(arr, mid + 1, right)
    # 优化1
    if arr[mid] <= arr[mid + 1]:
        return
    merge_two_sorted_array(arr, left, mid, right)


def main():
    # 优化3
    global nums_for_compare
    arr = [1,2,3,4,0]
    nums_for_compare = list(range(len(arr)))
    merge_sort(arr, 0, len(arr) - 1)


if __name__ == '__main__':
    main()