STL 源码分析《2》----nth_element() 使用与源码分析

时间:2024-11-28 22:07:13

Select 问题: 在一个无序的数组中 找到第 n 大的元素。

思路 1: 排序,O(NlgN)

思路 2: 利用快排的 RandomizedPartition(), 平均复杂度是 O(N)

思路 3:    同样是利用快排的 Partition(), 但是选择 pivot 的时候不是采用随机,而是通过一种特殊的方法。从而使复杂度最坏情况下是 O(N)。

本文介绍 STL 算法库中 nth_elemnt 的实现代码。

STL 采用的算法是: 当数组长度 <= 3时, 采用插入排序。

当长度 > 3时, 采用快排 Partition 的思想;

一、使用说明

void

nth_element (RandomAccessIteratorbeg,

RandomAccessIterator
nth
,

RandomAccessIterator
end
)

void

nth_element (RandomAccessIterator beg,

RandomAccessIterator nth,

RandomAccessIterator end,

                         BinaryPredicate op)

1. 两个函数都是让 第 n 个位置上的元素就位,

所有在位置 n 之前的元素都小于或等于它,

所有在位置 n 之后的元素都大于或等于它。

2. 复杂度: 平均复杂度是 O(N)

以下例子是使用范例:

// copyright @ L.J.SHOU Feb.23, 2014
#include <iostream>
#include <algorithm>
#include <iterator>
using namespace std; int main(void)
{
int a[]={3,5,2,6,1,4}; nth_element(a, a+3, a+sizeof(a)/sizeof(int));
cout << "The fourth element is: " << a[3] << endl; // output array a[]
copy(a, a+sizeof(a)/sizeof(int),
ostream_iterator<int>(cout, " "));
return 0;
}

程序输出结果:

The fourth element is: 4

     2 1 3 4 6 5

二、源码分析

// nth_element() and its auxiliary functions.  

template <class _RandomAccessIter, class _Tp>
void __nth_element(_RandomAccessIter __first, _RandomAccessIter __nth,
_RandomAccessIter __last, _Tp*) {
while (__last - __first > 3) {
_RandomAccessIter __cut =
__unguarded_partition(__first, __last,
_Tp(__median(*__first,
*(__first + (__last - __first)/2),
*(__last - 1))));
if (__cut <= __nth)
__first = __cut;
else
__last = __cut;
}
__insertion_sort(__first, __last);
} template <class _RandomAccessIter>
inline void nth_element(_RandomAccessIter __first, _RandomAccessIter __nth,
_RandomAccessIter __last) {
__STL_REQUIRES(_RandomAccessIter, _Mutable_RandomAccessIterator);
__STL_REQUIRES(typename iterator_traits<_RandomAccessIter>::value_type,
_LessThanComparable);
__nth_element(__first, __nth, __last, __VALUE_TYPE(__first));
}
template <class _RandomAccessIter, class _Tp>
_RandomAccessIter __unguarded_partition(_RandomAccessIter __first, 
                                        _RandomAccessIter __last, 
                                        _Tp __pivot) 
{
  while (true) {
    while (*__first < __pivot)
      ++__first;
    --__last;
    while (__pivot < *__last)
      --__last;
    if (!(__first < __last))
      return __first;
    iter_swap(__first, __last);
    ++__first;
  }

_unguarded_partition 就是快排的 partition, 将数组分成两部分,左边的元素都小于或者等于 pivot, 右边的元素都大于或者等于 pivot.

从上述代码可以看出, nth_element 采用的 pivot 是 首元素,尾元素,中间元素,三个数的median.

通过_unguarded_partition 将数组分成两部分,

如果 nth 这个迭代器在左半边,则继续在左半边搜索;

若   nth 在右半边, 则在右半边搜索;

直到数组的长度 <= 3,时, 采用插入排序。这时 nth 迭代器所指向的数就归位了,而且它的左边元素都小于或者等于它, 右边元素都大于或者等于它。