在二叉搜索树(BST)中查找第K个大的结点之非递归实现

时间:2021-01-20 07:21:21

一个被广泛使用的面试题: 给定一个二叉搜索树,请找出其中的第K个大的结点。

PS:我第一次在面试的时候被问到这个问题而且让我直接在白纸上写的时候,直接蒙圈了,因为没有刷题准备,所以就会有伤害。知耻而后勇,于是我回家花了两个半小时(在不参考任何书本和网路上的源码的前提下),从构建BST开始,到实现中序遍历,最后用递归方法写出bst_findKthNode()并用gdb调试成功。 不过,使用递归实现这个实在是比较low,所以这个周末我决定用非递归方法实现。

先贴一下我的递归实现 (个人觉得比较low, 虽然实现了,但是不满意)

 /*
* Find the Kth Node in BST, K = 1, 2, ...
*/
int
bst_findKthNode(bst_node_t *root, key_t *key, unsigned int k)
{
if (root == NULL)
return -; if (root->left != NULL && k > )
k = bst_findKthNode(root->left, key, k); if (--k == ) {
*key = root->key;
return ;
} if (root->right != NULL && k > )
k = bst_findKthNode(root->right, key, k); return k;
}

下面的代码是我写的非递归实现。

 /*
* Find the Kth Node in BST, K = 1, 2, ...
*/
bst_node_t *
bst_findKthNode(bst_node_t *root, unsigned int k)
{
bst_node_t *kp = NULL; if (root == NULL)
return NULL; (void) stack_init(STACK_SIZE); while (root != NULL || !stack_isEmpty()) {
if (root != NULL) {
push((uintptr_t)root);
root = root->left;
continue;
} pop((uintptr_t *)(&root));
if (--k == ) {
kp = root;
break;
} root = root->right;
} stack_fini(); return kp;
}

使用Meld进行diff后的截图,

在二叉搜索树(BST)中查找第K个大的结点之非递归实现

注意: 题目请参见《剑指Offer》(何海涛著)面试题63: 二叉搜索树的第k个结点, 其cpp答案在这里

最后,贴出完整的代码和测试运行结果。

o libstack.h 和 libstack.c (参见 将递归函数非递归化的一般方法(cont) 一文)
o libbst.h

 #ifndef _LIBBST_H
#define _LIBBST_H #ifdef __cplusplus
extern "C" {
#endif #define STACK_SIZE 16 typedef int key_t; typedef struct bst_node_s {
key_t key;
struct bst_node_s *left;
struct bst_node_s *right;
} bst_node_t; int bst_init(bst_node_t **root, key_t a[], size_t n);
void bst_fini(bst_node_t *root);
void bst_walk(bst_node_t *root);
bst_node_t *bst_findKthNode(bst_node_t *root, unsigned int k); #ifdef __cplusplus
}
#endif #endif /* _LIBBST_H */

o libbst.c

 #include <stdio.h>
#include <stdlib.h>
#include "libbst.h"
#include "libstack.h" static int bst_add_node(bst_node_t **root, key_t key); int
bst_init(bst_node_t **root, key_t a[], size_t n)
{
*root = NULL;
for (int i = ; i < n; i++) {
if (bst_add_node(root, a[i]) != )
return -;
} return ;
} #define UMEM_FREE_PATTERN 0xdeadbeefdeadbeefULL
static inline void
BST_DESTROY_NODE(bst_node_t *p)
{
p->left = NULL;
p->right = NULL;
*(unsigned long long *)p = UMEM_FREE_PATTERN;
} void
bst_fini(bst_node_t *root)
{
if (root == NULL)
return; bst_fini(root->left);
bst_fini(root->right); BST_DESTROY_NODE(root);
free(root);
} static int
bst_add_node(bst_node_t **root, key_t key)
{
bst_node_t *leaf = NULL;
leaf = (bst_node_t *)malloc(sizeof (bst_node_t));
if (leaf == NULL) {
fprintf(stderr, "failed to malloc\n");
return -;
} /* init leaf node */
leaf->key = key;
leaf->left = NULL;
leaf->right = NULL; /* add leaf node to root */
if (*root == NULL) { /* root node does not exit */
*root = leaf;
} else {
bst_node_t **pp = NULL;
while () {
if (leaf->key < (*root)->key)
pp = &((*root)->left);
else
pp = &((*root)->right); if (*pp == NULL) {
*pp = leaf;
break;
} root = pp;
}
} return ;
} void
bst_walk(bst_node_t *root)
{
if (root == NULL)
return; (void) stack_init(STACK_SIZE); while (root != NULL || !stack_isEmpty()) {
if (root != NULL) {
push((uintptr_t)root);
root = root->left;
continue;
} pop((uintptr_t *)(&root));
printf("%d\n", root->key); root = root->right;
} stack_fini();
} /*
* Find the Kth Node in BST, K = 1, 2, ...
*/
bst_node_t *
bst_findKthNode(bst_node_t *root, unsigned int k)
{
bst_node_t *kp = NULL; if (root == NULL)
return NULL; (void) stack_init(STACK_SIZE); while (root != NULL || !stack_isEmpty()) {
if (root != NULL) {
push((uintptr_t)root);
root = root->left;
continue;
} pop((uintptr_t *)(&root));
if (--k == ) {
kp = root;
break;
} root = root->right;
} stack_fini(); return kp;
}

o foo.c (简单测试)

 #include <stdio.h>
#include <stdlib.h>
#include "libbst.h" int
main(int argc, char *argv[])
{
if (argc != ) {
fprintf(stderr, "Usage: %s <Kth>\n", argv[]);
return -;
} int a[] = {, , , , , , , , };
int n = sizeof (a) / sizeof (int); bst_node_t *root = NULL;
bst_init(&root, a, n); bst_walk(root); unsigned int k = atoi(argv[]);
bst_node_t *p = NULL;
if ((p = bst_findKthNode(root, k)) == NULL) {
printf("\nOops, the %dth node not found\n", k);
goto done;
}
printf("\nWell, the %dth node found, its key is %d\n", k, p->key); done:
bst_fini(root); return ;
}

o Makefile

 CC    = gcc
CFLAGS = -g -Wall -std=gnu99 -m32
INCS = TARGET = foo all: ${TARGET} foo: foo.o libstack.o libbst.o
${CC} ${CFLAGS} -o $@ $^ foo.o: foo.c
${CC} ${CFLAGS} -c $< ${INCS} libstack.o: libstack.c libstack.h
${CC} ${CFLAGS} -c $< libbst.o: libbst.c libbst.h
${CC} ${CFLAGS} -c $< clean:
rm -f *.o
clobber: clean
rm -f ${TARGET}

o 编译并测试运行

$ make
gcc -g -Wall -std=gnu99 -m32 -c foo.c
gcc -g -Wall -std=gnu99 -m32 -c libstack.c
gcc -g -Wall -std=gnu99 -m32 -c libbst.c
gcc -g -Wall -std=gnu99 -m32 -o foo foo.o libstack.o libbst.o $ ./foo Well, the 6th node found, its key is $ ./foo | egrep 'Oops,'
Oops, the 16th node not found
$

扩展题目: "寻找两个数组的中位数"。 题目描述如下:

有两个数组, 第一个数组a里的元素按照升序排列, e.g. int a[] = {10, 30, 40, 70, 80, 90};

第二个数组b里的元素按照降序排列, e.g. int b[] = {60, 50, 30, 20, 10};

请寻找数组a和b的合集的中位数,e.g. 50。

解决方案:

  • 使用数组a构建一个无重复key的BST
  • 将数组b里的元素加入BST (若某个元素已经在BST中存在,不予加入)
  • 设BST中的所有结点总数为N (a) 若N为偶数, 查找第K, K+1个元素 (K=N/2) 并求其平均值; (b) 若N为奇数, 查找第K+1个元素(K=N/2)。

关于此题目的详细描述和解决方案请参见 《剑指Offer》(何海涛著)面试题64: 数据流中的中位数