数据结构 ——— 用堆解决TOP-K问题

时间:2024-11-02 14:02:45

目录

何为TOP-K问题

用堆解决TOP-K问题

代码实现 


何为TOP-K问题

比如:整个专业的前10名,世界500强,富豪榜,游戏中前100的活跃玩家等

对于 TOP-K 问题,能想到的最简单直接的方式就是排序
但是,如果数据量非常大,排序就不太可取了(可能数据都不能一下子全部加载到内存中)
最佳就是用堆来解决


用堆解决TOP-K问题

1. 用数据集合中前K个元素来建堆

  • 找前K个最大的元素时,就建小堆
  • 找前K个最小的元素时,就建大堆

2. 用剩余的N-K个元素依次与堆顶元素来比较,不满足则替换堆顶元素

  • 将剩余N-K个元素依次与堆顶元素比完之后
  • 堆中剩余的K个元素就是所求的前K个最大或者最小的元素

代码实现

生成一个数组,并且随机放入 10 个最大的值:

void TestTopK()
{
	// 动态申请 10000 个 int 类型的数组
	int n = 10000;
	int* a = (int*)malloc(sizeof(int) * n);

	// 判断是否申请成功
	if (a == NULL)
	{
		perror("malloc fail");
		return;
	}
	
	// 随机值生成器
	srand((unsigned int)time(NULL));

	// 在数组中依次存放小于 100000 的值
	for (int i = 0; i < n; i++)
	{
		a[i] = rand() % 100000;
	}

	// 将数组中的 10 个元素改为大于或者等于 100000 的值
	a[5] = 100000 + 1;
	a[123] = 100000 + 9;
	a[531] = 100000 + 2;
	a[4121] = 100000 + 8;
	a[115] = 100000 + 3;
	a[2335] = 100000 + 7;
	a[9999] = 100000 + 4;
	a[76] = 100000 + 6;
	a[423] = 100000 + 5;
	a[3144] = 100000 + 0;

	// 找出数组 a 中前 10 个最大的值,并打印
	PrintTopK(a, n, 10);
}

找出前 10 个最大的值:

void PrintTopK(int* a, int size, int k)
{
	for (int i = (k - 1 - 1) / 2; i >= 0; i--)
	{
		// 向下调整建堆(建小堆)
		AdjustDown(a, k, i);
	}

	for (int i = k; i < size; i++)
	{
		// 当前堆顶元素小于当前数组元素时就交换
		if (a[0] < a[i])
		{
			Swap(&a[0], &a[i]);

			// 向下调整堆
			AdjustDown(a, k, 0);
		}
	}

	ArrPrint(a, k);
}

代码解析(代码中的函数实现会放在最后,先讲解思路):

想要找到数组 a 中的前 10 个最大的数,那么就先建立大小为 10 的小堆
利用向下调整算法对数组 a 中的前 10 个数进行建堆,注意是建小堆
再把数组中剩余的元素依次与堆顶元素比较,当堆顶元素小于数组当前元素时,就交换
因为小堆的特点是:堆顶的元素是整个堆中元素最小的
那么堆中最小元素和数组当前元素比较时,还要小,那么堆顶元素必然不是前 10 个最大的数
交换后再利用向下调整算法调整堆
遍历完数组 a 中的所有元素后,堆中的元素就是前 10 个最大的数

代码验证:

                                               100000
                                      /                               \
                         100002                                   100001
                        /             \                                /            \
            100003              100005         100004              100006
            /         \                 /
   100007 100009      100008 

代码中的函数实现:

// 向下调整(默认小堆)
void AdjustDown(HPDataType* a, int size, int parent)
{
	int child = parent * 2 + 1;

	while (child < size)
	{
		// 先找到左右孩子中小的那个
		if ((child + 1 < size) && (a[child + 1] < a[child]))
			child++;

		if (a[parent] > a[child])
		{
			// 交换
			Swap(&a[parent], &a[child]);
			
			// 迭代
			parent = child;
			child = parent * 2 + 1;
		}
		else
		{
			break;
		}
	}
}

// 交换
void Swap(HPDataType* p1, HPDataType* p2)
{
	HPDataType tmp = *p1;
	*p1 = *p2;
	*p2 = tmp;
}

// 打印
void ArrPrint(int* a, int size)
{
	for (int i = 0; i < size; i++)
	{
		printf("%d ", a[i]);
	}
	printf("\n");
}