分治算法之合并排序

时间:2022-03-06 11:04:03

分治算法的基本思想是将一个规模为n的问题分解成k个规模较小的子问题,这些子问题相互独立并且与原问题相同。先递归的解决这些子问题,然后再将各个子问题的解合并到原问题的解当中。

合并排序算法是用分治策略实现对n个元素进行排序的算法。其基本思想是将待排序元素分成大小大致相同的两个子集合,分别对两个子集合进行排序,最终将排好序的两个子集合合并成一个排好序的集合。合并排序算法可递归的伪代码表达如下:

void mergeSort(int *a, int left, int right)
{
    int i;
    if (left < right) {
        i = (left + right) / 2;
        mergeSort(a, left, i);
        mergeSort(a, i + 1, right);
        merge(a, b, left, i, right);
        copy(a, b, left, right);
    }
}

上述的mergeSort函数中,递归调用是整个算法的关键步骤。mergeSort函数不断的平分待排序的数列集合,如果当前数列集合两端的序号分别为left和right,那么平分后的两个数列集合序号分别是left到i和i+1到right。这种不断平分并递归的结果是使得当前的待排序集合中只剩下一个元素,接着再进行两两子序列集合的合并。在mergeSort函数内,不断的进行递归调用以缩小问题规模;merge函数用于将两个排好序的子序列合并成一个大的有序序列,并存储在数组b中;最后利用copy函数将b数组的序列再重新拷贝到数组a中,完成合并排序。

正如上面所说的,这些待合并的子序列都已排好序。并且最初一批待合并的子序列集合中只有一个元素,合并后元素数量为2,再次合并为4,依次类推。

根据上述的描述,我们可以将递归形式的合并排序算法改进成非递归的形式。在上述递归形式中,我们是从整个序列出发,逐渐平分再递归。而非递归形式的排序算法则先让整个序列中相邻的元素两两进行排序,形成n/2个长度为2的已排好序的子序列。接着再将它们排成长度为4的子序列,以此类推。该算法结束结束时是将两个已经排好序的子序列排成一个有序序列。

根据上面的思路,非递归形式的合并算法可参考如下代码。mergeSort函数依次对整个待排序的序列中长度为1、2、4、8的子序列进行排序。s即为当前正进行排序的子序列集合的元素个数。

void mergeSort(int a[], int n)
{
    int *b = NULL;
    int s = 1;
    int count = 1;
    b = (int *)malloc(sizeof(int) * n);
    while (s < n) {
        printf("sort %d:\n", count++);
        mergePass(a, b, s, n);
        s += s;
        printf("sort %d:\n", count++);
        mergePass(b, a, s, n);
        s += s;
    }
    free(b);
}

mergePass函数的作用是大小为s的相邻子序列。通过while循环将整个序列分成n/s个大小为s的子序列,由于这写子序列内部已经排好序,则调用merge函数直接进行合并即可。

void mergePass(int x[], int y[], int s, int n)
{
    int i = 0;
    int j;
    while (i < n - 2 * s) {
        merge(x, y, i, i + s - 1, i + 2 * s -1);
        i = i + 2 * s;
    }
    if (i + s < n)
        merge(x, y, i, i + s - 1, n - 1);
    else
        for (j = i; j <= n - 1; j++)
            y[j] = x[j];
    for (i = 0; i < n; i++)
        printf("%d ", y[i]);
    printf("\n");
}

merge函数的作用是合并两个相邻的子序列,这两个子序列的序号分别为l到m和m+1到r。
void merge(int c[], int d[], int l, int m, int r)
{
    int i, j, k;
    i = l;
    j = m + 1;
    k = l;
    while ((i <= m) && (j <= r))
        if (c[i] <= c[j])
            d[k++] = c[i++];
        else
            d[k++] = c[j++];
    int q;
    if (i > m)
        for (q = j; q <= r; q++)
            d[k++] = c[q];
    else
        for (q = i; q <= m; q++)
            d[k++] = c[q];
}