自己实现一个简单的AVL树

时间:2021-12-22 12:56:20

需求

自己实现一个简单的AVL-Tree。

包含功能

  • insert
  • remove
  • findMax
  • findMin
  • toString

  基本功能与之前实现的 BinarySearchTree 相同，不过 AVL 树需要保证树的平衡：每个节点的左右子树高度差不超过 1。因此，每次插入或删除之后都要（通过单旋转或双旋转）对树进行调整。

代码

import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.StringJoiner;

/**
 * A simple self-balancing AVL binary search tree.
 *
 * <p>For every node, the heights of its left and right subtrees differ by at most one. The tree
 * is rebalanced with single or double rotations after each insertion and removal. Duplicate
 * elements are ignored and {@code null} elements are rejected.
 *
 * <p>This class performs no synchronization and is not safe for concurrent mutation.
 *
 * @param <T> element type, ordered by its natural ordering
 */
public class AvlTree<T extends Comparable<? super T>> {
    /** Maximum permitted height difference between the two subtrees of any node. */
    private static final int ALLOWED_IMBALANCE = 1;

    private AvlNode<T> root;

    /** Creates an empty tree. */
    public AvlTree() {
    }

    /**
     * Creates a tree containing the elements of {@code arr}; duplicates are stored once.
     *
     * @param arr elements to insert, in order
     * @throws NullPointerException if {@code arr} or any of its elements is {@code null}
     */
    public AvlTree(T[] arr) {
        for (T t : arr) {
            // Call the private helper, not the overridable public insert(T), so a subclass
            // override never runs against a partially constructed object.
            root = insert(Objects.requireNonNull(t, "element"), root);
        }
    }

    /** Removes every element from the tree. */
    public void makeEmpty() {
        root = null;
    }

    /**
     * Returns whether the tree holds no elements.
     *
     * @return {@code true} if the tree is empty
     */
    public boolean isEmpty() {
        return root == null;
    }

    /**
     * Inserts {@code t}; does nothing if an equal element is already present.
     *
     * @param t element to insert
     * @throws NullPointerException if {@code t} is {@code null}
     */
    public void insert(T t) {
        // Reject null up front: on an empty tree it would otherwise be stored silently as the
        // root and break every later comparison and toString().
        root = insert(Objects.requireNonNull(t, "element"), root);
    }

    /**
     * Returns the largest element.
     *
     * @return the maximum element
     * @throws NoSuchElementException if the tree is empty
     */
    public T findMax() {
        return findMax(root).element;
    }

    /**
     * Returns the smallest element.
     *
     * @return the minimum element
     * @throws NoSuchElementException if the tree is empty
     */
    public T findMin() {
        return findMin(root).element;
    }

    /**
     * Removes the element equal to {@code t}, if present; otherwise the tree is unchanged.
     *
     * @param t element to remove
     */
    public void remove(T t) {
        root = remove(t, root);
    }

    /** Returns the right-most (largest) node of the subtree rooted at {@code node}. */
    private AvlNode<T> findMax(AvlNode<T> node) {
        if (node == null) {
            throw new NoSuchElementException("tree is empty");
        }
        while (node.right != null) {
            node = node.right;
        }
        return node;
    }

    /** Returns the left-most (smallest) node of the subtree rooted at {@code node}. */
    private AvlNode<T> findMin(AvlNode<T> node) {
        if (node == null) {
            throw new NoSuchElementException("tree is empty");
        }
        while (node.left != null) {
            node = node.left;
        }
        return node;
    }

    /** Height of {@code node}; an empty subtree has height -1 and a leaf has height 0. */
    private int height(AvlNode<T> node) {
        return node == null ? -1 : node.height;
    }

    /** Returns {@code height(a) - height(b)}. */
    private int heightDiff(AvlNode<T> a, AvlNode<T> b) {
        return height(a) - height(b);
    }

    /** Recomputes {@code node.height} from its children's (already correct) heights. */
    private void updateHeight(AvlNode<T> node) {
        node.height = Math.max(height(node.left), height(node.right)) + 1;
    }

    /** Inserts {@code t} into the subtree at {@code node} and returns the rebalanced root. */
    private AvlNode<T> insert(T t, AvlNode<T> node) {
        if (node == null) {
            return new AvlNode<>(t);
        }
        int cmp = t.compareTo(node.element);
        if (cmp < 0) {
            node.left = insert(t, node.left);
        } else if (cmp > 0) {
            node.right = insert(t, node.right);
        }
        // cmp == 0: duplicate, nothing to insert.
        return balance(node);
    }

    /** Removes {@code t} from the subtree at {@code node} and returns the rebalanced root. */
    private AvlNode<T> remove(T t, AvlNode<T> node) {
        if (node == null) {
            return null;
        }
        int cmp = t.compareTo(node.element);
        if (cmp < 0) {
            node.left = remove(t, node.left);
        } else if (cmp > 0) {
            node.right = remove(t, node.right);
        } else if (node.left != null && node.right != null) {
            // Two children: copy in the in-order successor (minimum of the right subtree),
            // then delete that successor from the right subtree.
            node.element = findMin(node.right).element;
            node.right = remove(node.element, node.right);
        } else {
            // Zero or one child: splice the node out.
            node = node.left != null ? node.left : node.right;
        }
        return balance(node);
    }

    /**
     * Restores the AVL property at {@code node}, assuming both subtrees are already balanced,
     * and updates its height.
     *
     * @return the new root of this subtree (may differ from {@code node} after a rotation)
     */
    private AvlNode<T> balance(AvlNode<T> node) {
        if (node == null) {
            return null;
        }
        if (heightDiff(node.left, node.right) > ALLOWED_IMBALANCE) {
            // Left-heavy. Use ">=" so that equal-height grandchildren (possible after a removal)
            // take the single rotation; a double rotation would not rebalance that case.
            if (heightDiff(node.left.left, node.left.right) >= 0) {
                node = rotateWithLeftChild(node);
            } else {
                node = doubleWithLeftChild(node);
            }
        } else if (heightDiff(node.right, node.left) > ALLOWED_IMBALANCE) {
            // Right-heavy; mirror image of the case above.
            if (heightDiff(node.right.right, node.right.left) >= 0) {
                node = rotateWithRightChild(node);
            } else {
                node = doubleWithRightChild(node);
            }
        }
        updateHeight(node);
        return node;
    }

    /** Single right rotation: the left child {@code n2} becomes the root of this subtree. */
    private AvlNode<T> rotateWithLeftChild(AvlNode<T> n1) {
        AvlNode<T> n2 = n1.left;
        n1.left = n2.right;
        n2.right = n1;
        // n1 is now below n2, so its height must be fixed first.
        updateHeight(n1);
        updateHeight(n2);
        return n2;
    }

    /** Single left rotation: the right child {@code n2} becomes the root of this subtree. */
    private AvlNode<T> rotateWithRightChild(AvlNode<T> n1) {
        AvlNode<T> n2 = n1.right;
        n1.right = n2.left;
        n2.left = n1;
        updateHeight(n1);
        updateHeight(n2);
        return n2;
    }

    /** Left-right double rotation: rotate the left child left, then {@code node} right. */
    private AvlNode<T> doubleWithLeftChild(AvlNode<T> node) {
        node.left = rotateWithRightChild(node.left);
        return rotateWithLeftChild(node);
    }

    /** Right-left double rotation: rotate the right child right, then {@code node} left. */
    private AvlNode<T> doubleWithRightChild(AvlNode<T> node) {
        node.right = rotateWithLeftChild(node.right);
        return rotateWithRightChild(node);
    }

    /**
     * Returns the elements in ascending order, space-separated and enclosed in brackets, e.g.
     * {@code [1 2 3]}; an empty tree yields {@code []}. Has no side effects.
     */
    @Override
    public String toString() {
        StringJoiner joiner = new StringJoiner(" ", "[", "]");
        appendInOrder(root, joiner);
        return joiner.toString();
    }

    /** Appends the elements of the subtree at {@code node} to {@code joiner} in order. */
    private void appendInOrder(AvlNode<T> node, StringJoiner joiner) {
        if (node == null) {
            return;
        }
        appendInOrder(node.left, joiner);
        joiner.add(String.valueOf(node.element));
        appendInOrder(node.right, joiner);
    }

    /** Tree node; {@code element} is mutable because removal copies in the successor. */
    private static class AvlNode<T> {
        T element;
        AvlNode<T> left;
        AvlNode<T> right;
        /** Height of this node's subtree; a new leaf starts at 0. */
        int height;

        AvlNode(T element) {
            this.element = element;
        }
    }
}