STL sort函数的用法

时间:2024-08-18 09:38:14

sort 在 STL 库中是排序函数,有时冒泡、选择等 $\mathcal O(n^2)$ 算法会超时时,我们可以使用 STL 中的快速排序函数 $\mathcal O(n \ log \ n)$ 完成排序

sort 在 algorithm 库里面,原型如下:

template <class RandomAccessIterator>
void sort ( RandomAccessIterator first, RandomAccessIterator last );
template <class RandomAccessIterator, class Compare>
void sort ( RandomAccessIterator first, RandomAccessIterator last, Compare comp );

我们会发现 sort 有两种形式一个有三个参数,一个有两个参数,我们先讲讲两个参数的吧!

sort 的前两个参数是起始地址和中止地址

如:sort(a,a+n) 表示对 a[0] ... a[n-1] 排序

代码如下:

#include <algorithm>
#include <cstdio>
using namespace std;
int main() {
int n,a[1001];
scanf("%d",&n);
for (int i = 1;i <= n;i++) scanf("%d",&a[i]);
sort(a+1,a+n+1); //对a[1] ... a[n] 排序
for (int i = 1;i <= n;i++) printf("%d",a[i]);
return 0'
}

这样是默认升序的,那如果是降序呢?

这样,我们就要用到第三个参数,第三个参数是一个比较函数

bool cmp(int a,int b) { return a > b; }

这个就是降序排序的比较函数,意思是:

是 a > b 时为true,就不交换,a < b 时为 false,交换

然后我们调用 sort(a+1,a+n+1,cmp) 就可以对 a 数组进行排序了

还可以调用 greater 和 less 进行升/降序排序,其实就是一个帮你写好的函数

int a[11],n;
scanf("%d",&n);
for (int i = 1;i <= n;i++) scanf("%d",&a[i]);
sort(a+1,a+n+1,greater<int>()); //升序
sort(a+1,a+n+1,less<int>()); //降序,注意尖括号内写的是排序的数组类型

sort 也能对结构体排序,如:

#include <algorithm>
#include <cstdio>
using namespace std;
struct Node {
int x,y;
} p[1001];
int n;
bool cmp(Node a,Node b) {
if (a.x != b.x) return a.x < b.x;
return a.y < b.y;
}
int main() {
scanf("%d",&n);
for (int i = 1;i <= n;i++) scanf("%d%d",&p[i].x,&p[i].y);
sort(p+1,p+n+1,cmp);
for (int i = 1;i <= n;i++) scanf("%d %d\n",p[i].x,p[i].y);
return 0;
}

以上代码的意思是,对 p 数组按 x 升序排序,若两个数的 x 相等则按 y 升序排序

结构体还可以重载运算符(greater 和 less 都是重载运算符的),使 sort 只用两个参数就可以按自己的规则排序,如:

#include <algorithm>
#include <cstdio>
using namespace std;
struct Node {
int x,y;
bool operator < (Node cmp) const {
if (a.x != cmp.x) return a.x < cmp.x;
return a.y < cmp.y;
}
}p[1001];
int n;
/*bool cmp(Node a,Node b) {
* if (a.x != b.x) return a.x < b.x;
* return a.y < b.y;
*}
*/
int main() {
scanf("%d",&n);
for (int i = 1;i <= n;i++) scanf("%d%d",&p[i].x,&p[i].y);
sort(p+1,p+n+1);
for (int i = 1;i <= n;i++) scanf("%d %d\n",p[i].x,p[i].y);
return 0;
}