// based on: https://algs4.cs.princeton.edu/33balanced/RedBlackBST.java.html
// TODO: implement NavigableSet
// Sorted set backed by a left-leaning red-black BST (the 2-3 variant of
// Sedgewick & Wayne's RedBlackBST). Elements are ordered by compare(), which
// defaults to cmp() and may be overridden by subclasses; elements comparing
// equal (compare() == 0) are treated as duplicates.
//
// A node's color is not stored in a field: it is encoded by the node's class
// (RedNode vs. BlackNode) - presumably to save per-node memory. Recoloring a
// node to the other color therefore allocates a replacement node carrying the
// same val/left/right, which is why every helper that can recolor returns the
// node the caller must store back into the tree.
sclass UltraCompactTreeSet extends AbstractSet {
// A set implemented using a left-leaning red-black BST.
// This is the 2-3 version.
private static final boolean RED = true;
private static final boolean BLACK = false;
private Node root; // root of the BST; null when the set is empty
int size; // number of elements; incremented in put(), decremented in delete(), reset by clear()
// BST helper node data type. Color is implied by the concrete subclass.
abstract sclass Node {
A val; // the stored element
Node left, right; // links to left and right subtrees
// RED for RedNode, BLACK for BlackNode.
abstract bool color();
// Return a node of the requested color holding the same val/left/right:
// `this` if it already has that color, otherwise a newly allocated node.
abstract BlackNode convertToBlack();
abstract RedNode convertToRed();
// Return a node of the opposite color (always a newly allocated node).
abstract Node invertColor();
// Dispatch to convertToRed()/convertToBlack() based on the requested color.
Node convertToColor(bool color) { ret color == RED ? convertToRed() : convertToBlack(); }
}
// Black node: convertToBlack() returns this; converting to red copies the node.
sclass BlackNode extends Node {
*(A *val) {}
*(A *val, Node *left, Node *right) {}
bool color() { ret BLACK; }
BlackNode convertToBlack() { this; }
RedNode convertToRed() { ret new RedNode(val, left, right); }
Node invertColor() { ret convertToRed(); }
}
// Red node: convertToRed() returns this; converting to black copies the node.
sclass RedNode extends Node {
*(A *val) {}
*(A *val, Node *left, Node *right) {}
bool color() { ret RED; }
BlackNode convertToBlack() { ret new BlackNode(val, left, right); }
RedNode convertToRed() { this; }
Node invertColor() { ret convertToBlack(); }
}
// Create an empty set.
*() {}
// Create a set holding every element of cl (elements comparing equal collapse to one).
*(Cl extends A> cl) { addAll(cl); }
// true iff x is a RedNode; a null link counts as black (instanceof is false for null).
static bool isRed(Node x) {
ret x instanceof RedNode;
}
// Number of elements currently in the set.
public int size() {
ret size;
}
// The set is empty iff there is no root node.
public bool isEmpty() {
ret root == null;
}
// Insert val unless an element comparing equal to it is already present;
// an existing element is never replaced.
// Returns true iff the set grew.
public bool add(A val) {
int oldSize = size;
root = put(root, val);
// the root of a red-black tree is always black
root = root.convertToBlack();
// invariant check, compiled in only with the CompactTreeSet_debug flag
// NOTE(review): flag name says CompactTreeSet, not UltraCompactTreeSet - confirm intended
ifdef CompactTreeSet_debug assertTrue(check()); endifdef
ret size > oldSize;
}
// insert the value in the subtree rooted at h and return the (possibly new)
// subtree root; size is incremented only when a new red leaf is created
private Node put(Node h, A val) {
if (h == null) { ++size; ret new RedNode(val); }
int cmp = compare(val, h.val);
if (cmp < 0) h.left = put(h.left, val);
else if (cmp > 0) h.right = put(h.right, val);
else { /*h.val = val;*/ } // no overwriting
// fix-up any right-leaning links
if (isRed(h.right) && !isRed(h.left)) h = rotateLeft(h);
if (isRed(h.left) && isRed(h.left.left)) h = rotateRight(h);
if (isRed(h.left) && isRed(h.right)) h = flipColors(h);
ret h;
}
// Ordering used by every search, insert and delete; defaults to cmp(a, b).
// override me if you wish
int compare(A a, A b) {
ret cmp(a, b);
}
// Remove the element comparing equal to key, if present.
// Returns true iff an element was removed.
// NOTE(review): key is cast to A unchecked; what happens for a key of an
// incompatible type depends on compare()/cmp().
public bool remove(O key) {
// delete() requires the key to be present, so bail out early otherwise
if (!contains(key)) false;
// if both children of root are black, set root to red
if (!isRed(root.left) && !isRed(root.right))
root = root.convertToRed();
root = delete(root, (A) key);
// restore a black root unless the set became empty
if (!isEmpty()) root = root.convertToBlack();
// NOTE(review): unlike add(), no ifdef'd check() runs here
// assert check();
true;
}
// delete the element equal to key from the subtree rooted at h and return the
// (possibly new) subtree root.
// Precondition: key is present in this subtree (remove() ensures this via
// contains()); child links are dereferenced without null checks.
// size is decremented exactly once, where the matching element is removed.
private Node delete(Node h, A key) {
// assert get(h, key) != null;
if (compare(key, h.val) < 0) {
// key lies to the left: avoid descending into a 2-node
if (!isRed(h.left) && !isRed(h.left.left))
h = moveRedLeft(h);
h.left = delete(h.left, key);
}
else {
// lean a red left link to the right so the right side can give up a node
if (isRed(h.left))
h = rotateRight(h);
// matching node without a right child: by the LLRB invariants it has no
// left child either here, so drop it. Otherwise (separate `if` below)
// make sure h.right is not a 2-node before working in the right subtree.
if (compare(key, h.val) == 0 && (h.right == null)) {
--size; null;
} if (!isRed(h.right) && !isRed(h.right.left))
h = moveRedRight(h);
if (compare(key, h.val) == 0) {
// replace h's element with its in-order successor, then remove the
// successor's node from the right subtree (deleteMin leaves size alone)
--size;
Node x = min(h.right);
h.val = x.val;
// h.val = get(h.right, min(h.right).val);
// h.val = min(h.right).val;
h.right = deleteMin(h.right);
}
else h.right = delete(h.right, key);
}
return balance(h);
}
// make a left-leaning link lean to the right; returns the new subtree root,
// which takes h's old color while h becomes red (either may be reallocated)
private Node rotateRight(Node h) {
// assert (h != null) && isRed(h.left);
Node x = h.left;
h.left = x.right;
x.right = h;
x = x.convertToColor(x.right.color());
x.right = x.right.convertToRed();
ret x;
}
// make a right-leaning link lean to the left; returns the new subtree root,
// which takes h's old color while h becomes red (either may be reallocated)
private Node rotateLeft(Node h) {
// assert (h != null) && isRed(h.right);
Node x = h.right;
h.right = x.left;
x.left = h;
x = x.convertToColor(x.left.color());
x.left = x.left.convertToRed();
ret x;
}
// flip the colors of a node and its two children; the children are replaced
// in place and the (newly allocated) replacement for h is returned
private Node flipColors(Node h) {
// h must have opposite color of its two children
// assert (h != null) && (h.left != null) && (h.right != null);
// assert (!isRed(h) && isRed(h.left) && isRed(h.right))
// || (isRed(h) && !isRed(h.left) && !isRed(h.right));
h.left = h.left.invertColor();
h.right = h.right.invertColor();
ret h.invertColor();
}
// Assuming that h is red and both h.left and h.left.left
// are black, make h.left or one of its children red.
// Returns the (possibly new) subtree root.
private Node moveRedLeft(Node h) {
// assert (h != null);
// assert isRed(h) && !isRed(h.left) && !isRed(h.left.left);
h = flipColors(h);
if (isRed(h.right.left)) {
h.right = rotateRight(h.right);
h = rotateLeft(h);
h = flipColors(h);
}
ret h;
}
// Assuming that h is red and both h.right and h.right.left
// are black, make h.right or one of its children red.
// Returns the (possibly new) subtree root.
private Node moveRedRight(Node h) {
// assert (h != null);
// assert isRed(h) && !isRed(h.right) && !isRed(h.right.left);
h = flipColors(h);
if (isRed(h.left.left)) {
h = rotateRight(h);
h = flipColors(h);
}
ret h;
}
// restore red-black tree invariant on the way back up after a deletion;
// returns the (possibly new) subtree root
private Node balance(Node h) {
// assert (h != null);
if (isRed(h.right)) h = rotateLeft(h);
if (isRed(h.left) && isRed(h.left.left)) h = rotateRight(h);
if (isRed(h.left) && isRed(h.right)) h = flipColors(h);
ret h;
}
/**
* Returns the height of the BST (for debugging).
* @return the height of the BST (a 1-node tree has height 0, an empty tree -1)
*/
public int height() {
ret height(root);
}
// height of the subtree rooted at x; -1 for an empty subtree
private int height(Node x) {
if (x == null) return -1;
return 1 + Math.max(height(x.left), height(x.right));
}
// true iff an element comparing equal to val is present (val is cast to A unchecked)
public bool contains(O val) {
ret find(root, (A) val) != null;
}
// Return the stored element comparing equal to probeVal (possibly a different
// object than probeVal), or null if there is none.
public A find(A probeVal) {
Node n = find(root, probeVal);
ret n == null ? null : n.val;
}
// value associated with the given key in subtree rooted at x; null if no such key
// (not called anywhere else in this class)
private A get(Node x, A key) {
x = find(x, key);
ret x == null ? null : x.val;
}
// iterative search: the node in the subtree rooted at x whose element compares
// equal to key, or null if there is none
Node find(Node x, A key) {
while (x != null) {
int cmp = compare(key, x.val);
if (cmp < 0) x = x.left;
else if (cmp > 0) x = x.right;
else ret x;
}
null;
}
// debugging aid: verify the 2-3 and black-balance invariants, printing which failed
private boolean check() {
if (!is23()) println("Not a 2-3 tree");
if (!isBalanced()) println("Not balanced");
return is23() && isBalanced();
}
// Does the tree have no red right links, and at most one (left)
// red link in a row on any path?
private boolean is23() { return is23(root); }
private boolean is23(Node x) {
if (x == null) true;
if (isRed(x.right)) false;
if (x != root && isRed(x) && isRed(x.left)) false;
ret is23(x.left) && is23(x.right);
}
// do all paths from root to leaf have same number of black edges?
private bool isBalanced() {
int black = 0; // number of black links on path from root to min
Node x = root;
while (x != null) {
if (!isRed(x)) black++;
x = x.left;
}
ret isBalanced(root, black);
}
// does every path from the root to a leaf have the given number of black links?
private boolean isBalanced(Node x, int black) {
if (x == null) return black == 0;
if (!isRed(x)) black--;
return isBalanced(x.left, black) && isBalanced(x.right, black);
}
// Remove all elements.
public void clear() { root = null; size = 0; }
// the node holding the smallest element in the subtree rooted at x;
// x must not be null
private Node min(Node x) {
// assert x != null;
while (x.left != null) x = x.left;
ret x;
}
// remove the node holding the smallest element from the subtree rooted at h
// (h must not be null) and return the (possibly new) subtree root.
// Does not change size; delete() accounts for the removal itself.
private Node deleteMin(Node h) {
if (h.left == null)
return null;
if (!isRed(h.left) && !isRed(h.left.left))
h = moveRedLeft(h);
h.left = deleteMin(h.left);
ret balance(h);
}
// Iterate over the elements in ascending compare() order.
public Iterator iterator() {
ret new MyIterator;
}
// In-order iterator using an explicit stack (`path`) of nodes whose elements
// have not been returned yet; the top of the stack is the next element.
// NOTE(review): remove() is not overridden here, and modifications during
// iteration are not detected - rotations/recoloring replace node objects, so
// the stack could hold stale nodes after the set is modified.
class MyIterator extends ItIt {
new L> path;
*() {
fetch(root);
}
// push node and its whole chain of left descendants
void fetch(Node node) {
while (node != null) {
path.add(node);
node = node.left;
}
}
public bool hasNext() { ret !path.isEmpty(); }
public A next() {
if (path.isEmpty()) fail("no more elements");
Node node = popLast(path);
// the popped node is the smallest element not returned yet: its left
// subtree (if any) has already been returned, so continue with the
// right subtree
fetch(node.right);
ret node.val;
}
}
// Returns the smallest element greater than or equal to {@code key}, or null if there is none.
public A ceiling(A key) {
Node x = ceiling(root, key);
ret x == null ? null : x.val;
}
// the smallest key in the subtree rooted at x greater than or equal to the given key
Node ceiling(Node x, A key) {
if (x == null) null;
int cmp = compare(key, x.val);
if (cmp == 0) ret x;
if (cmp > 0) ret ceiling(x.right, key);
Node t = ceiling(x.left, key);
if (t != null) ret t;
else ret x;
}
// Returns the largest element less than or equal to {@code key}, or null if there is none.
public A floor(A key) {
Node x = floor(root, key);
ret x == null ? null : x.val;
}
// the largest key in the subtree rooted at x less than or equal to the given key
Node floor(Node x, A key) {
if (x == null) null;
int cmp = compare(key, x.val);
if (cmp == 0) ret x;
if (cmp < 0) ret floor(x.left, key);
Node t = floor(x.right, key);
if (t != null) ret t;
else ret x;
}
}