// based on: https://algs4.cs.princeton.edu/33balanced/RedBlackBST.java.html
// TODO: implement NavigableSet
/**
 * Sorted set backed by a left-leaning red-black BST (the 2-3 variant of
 * Sedgewick & Wayne's RedBlackBST).
 *
 * To keep nodes small, a node's color and whether it has children are encoded
 * in its class (BlackLeaf, RedLeaf, BlackNode, RedNode) instead of in fields;
 * leaves carry only their value. As a consequence every structural mutator
 * (setLeft, setRight, convertToRed, convertToBlack, invertColor, convertToColor)
 * may return a different object than the one it was invoked on, and callers
 * must always continue with the returned node.
 *
 * Element order is defined by compare(), which subclasses may override.
 */
sclass MegaCompactTreeSet<A> extends AbstractSet<A> {
  // A symbol table implemented using a left-leaning red-black BST.
  // This is the 2-3 version.

  private static final boolean RED = true;
  private static final boolean BLACK = false;

  private Node<A> root; // root of the BST; null when the set is empty
  int size; // number of elements; maintained by put(), delete() and clear()

  // BST helper node data type.
  // Setters and color converters return the (possibly replaced) node.
  abstract sclass Node<A> {
    A val; // associated data

    Node<A> left() { null; } // get left subtree (a leaf has none)
    abstract Node<A> setLeft(Node<A> left); // set left subtree - return potentially replaced node
    Node<A> right() { null; } // get right subtree (a leaf has none)
    abstract Node<A> setRight(Node<A> right); // set right subtree - return potentially replaced node

    abstract bool color();
    abstract Node<A> convertToBlack();
    abstract Node<A> convertToRed();
    abstract Node<A> invertColor();

    Node<A> convertToColor(bool color) { ret color == RED ? convertToRed() : convertToBlack(); }

    bool isLeaf() { ret left() == null && right() == null; }
  }

  // A node that stores child links. When both links become null it
  // downgrades itself to a leaf of the same color.
  asclass NonLeaf<A> extends Node<A> {
    Node<A> left, right;

    Node<A> left() { ret left; }

    Node<A> setLeft(Node<A> left) {
      this.left = left;
      if (left == null && right() == null) ret newLeaf(color(), val);
      this;
    }

    Node<A> right() { ret right; }

    Node<A> setRight(Node<A> right) {
      this.right = right;
      if (right == null && left() == null) ret newLeaf(color(), val);
      this;
    }
  }

  sclass BlackNode<A> extends NonLeaf<A> {
    *(A *val, Node<A> *left, Node<A> *right) {}
    bool color() { ret BLACK; }
    Node<A> convertToBlack() { this; }
    Node<A> convertToRed() { ret new RedNode<A>(val, left, right); }
    Node<A> invertColor() { ret convertToRed(); }
  }

  sclass RedNode<A> extends NonLeaf<A> {
    *(A *val, Node<A> *left, Node<A> *right) {}
    bool color() { ret RED; }
    Node<A> convertToBlack() { ret new BlackNode<A>(val, left, right); }
    Node<A> convertToRed() { this; }
    Node<A> invertColor() { ret convertToBlack(); }
  }

  // Childless black node; attaching a child upgrades it to a BlackNode.
  sclass BlackLeaf<A> extends Node<A> {
    *(A *val) {}

    Node<A> setLeft(Node<A> left) {
      ret new BlackNode<A>(val, left, null);
    }

    Node<A> setRight(Node<A> right) {
      ret new BlackNode<A>(val, null, right);
    }

    bool color() { ret BLACK; }
    Node<A> convertToBlack() { this; }
    Node<A> convertToRed() { ret new RedLeaf<A>(val); }
    Node<A> invertColor() { ret convertToRed(); }
  }

  // Childless red node; attaching a child upgrades it to a RedNode.
  sclass RedLeaf<A> extends Node<A> {
    *(A *val) {}

    Node<A> setLeft(Node<A> left) {
      ret new RedNode<A>(val, left, null);
    }

    Node<A> setRight(Node<A> right) {
      ret new RedNode<A>(val, null, right);
    }

    bool color() { ret RED; }
    Node<A> convertToBlack() { ret new BlackLeaf<A>(val); }
    Node<A> convertToRed() { this; }
    Node<A> invertColor() { ret convertToBlack(); }
  }

  *() {}

  // Note: addAll() -> add() -> compare() runs during construction, before
  // any subclass constructor body, so an overriding compare() must not
  // depend on subclass fields initialized there.
  *(Cl<? extends A> cl) { addAll(cl); }

  // returns false on null (algorithm needs this: null links count as black)
  static bool isRed(Node<?> x) {
    ret x != null && x.color() == RED;
  }

  // Creates a childless node of the given color.
  static <A> Node<A> newLeaf(bool color, A val) {
    ret color == RED ? new RedLeaf<A>(val) : new BlackLeaf<A>(val);
  }

  public int size() {
    ret size;
  }

  public bool isEmpty() {
    ret root == null;
  }

  // Inserts val unless an element comparing equal is already present.
  // Returns true iff the set grew.
  public bool add(A val) {
    int oldSize = size;
    root = put(root, val);
    root = root.convertToBlack(); // the root is always black
    // NOTE(review): flag is named after CompactTreeSet, not this class - confirm intended
    ifdef CompactTreeSet_debug assertTrue(check()); endifdef
    ret size > oldSize;
  }

  // insert the value in the subtree rooted at h; returns the new subtree root
  private Node<A> put(Node<A> h, A val) {
    if (h == null) { ++size; ret new RedLeaf<A>(val); }

    int cmp = compare(val, h.val);
    if (cmp < 0) h = h.setLeft(put(h.left(), val));
    else if (cmp > 0) h = h.setRight(put(h.right(), val));
    else { /*h.val = val;*/ } // no overwriting: the stored element is kept

    // fix-up any right-leaning links
    if (isRed(h.right()) && !isRed(h.left())) h = rotateLeft(h);
    if (isRed(h.left()) && isRed(h.left().left())) h = rotateRight(h);
    if (isRed(h.left()) && isRed(h.right())) h = flipColors(h);
    ret h;
  }

  // override me if you wish
  int compare(A a, A b) {
    ret cmp(a, b);
  }

  // Removes the element comparing equal to key, if any.
  // Returns true iff an element was removed.
  public bool remove(O key) {
    if (!contains(key)) false;

    // if both children of root are black, set root to red
    if (!isRed(root.left()) && !isRed(root.right()))
      root = root.convertToRed();

    root = delete(root, (A) key);
    if (!isEmpty()) root = root.convertToBlack();
    ifdef CompactTreeSet_debug assertTrue(check()); endifdef
    true;
  }

  // delete the element equal to key from the subtree rooted at h and return
  // the new subtree root. Requires key to be present (remove() checks this);
  // otherwise the child lookups below can hit null.
  private Node<A> delete(Node<A> h, A key) {
    if (compare(key, h.val) < 0) {
      if (!isRed(h.left()) && !isRed(h.left().left()))
        h = moveRedLeft(h);
      h = h.setLeft(delete(h.left(), key));
    } else {
      if (isRed(h.left()))
        h = rotateRight(h);

      // key found at a node with no right subtree: unlink it
      if (compare(key, h.val) == 0 && h.right() == null) {
        --size;
        null;
      }

      if (!isRed(h.right()) && !isRed(h.right().left()))
        h = moveRedRight(h);

      if (compare(key, h.val) == 0) {
        --size;
        // overwrite h with its in-order successor, then delete the successor
        Node<A> x = min(h.right());
        h.val = x.val;
        h = h.setRight(deleteMin(h.right()));
      } else
        h = h.setRight(delete(h.right(), key));
    }
    ret balance(h);
  }

  // make a left-leaning link lean to the right
  private Node<A> rotateRight(Node<A> h) {
    // assert (h != null) && isRed(h.left());
    Node<A> x = h.left();
    h = h.setLeft(x.right());
    x = x.setRight(h);
    // x inherits h's color (h is now x.right()), then h turns red
    x = x.convertToColor(x.right().color());
    x = x.setRight(x.right().convertToRed());
    ret x;
  }

  // make a right-leaning link lean to the left
  private Node<A> rotateLeft(Node<A> h) {
    // assert (h != null) && isRed(h.right());
    Node<A> x = h.right();
    h = h.setRight(x.left());
    x = x.setLeft(h);
    // x inherits h's color (h is now x.left()), then h turns red
    x = x.convertToColor(x.left().color());
    x = x.setLeft(x.left().convertToRed());
    ret x;
  }

  // flip the colors of a node and its two children
  private Node<A> flipColors(Node<A> h) {
    // h must have opposite color of its two children
    // assert (h != null) && (h.left() != null) && (h.right() != null);
    // assert (!isRed(h) && isRed(h.left()) && isRed(h.right()))
    //    || (isRed(h) && !isRed(h.left()) && !isRed(h.right()));
    h = h.setLeft(h.left().invertColor());
    h = h.setRight(h.right().invertColor());
    ret h.invertColor();
  }

  // Assuming that h is red and both h.left() and h.left().left()
  // are black, make h.left() or one of its children red.
  private Node<A> moveRedLeft(Node<A> h) {
    // assert (h != null);
    // assert isRed(h) && !isRed(h.left()) && !isRed(h.left().left());
    h = flipColors(h);
    if (isRed(h.right().left())) {
      h = h.setRight(rotateRight(h.right()));
      h = rotateLeft(h);
      h = flipColors(h);
    }
    ret h;
  }

  // Assuming that h is red and both h.right() and h.right().left()
  // are black, make h.right() or one of its children red.
  private Node<A> moveRedRight(Node<A> h) {
    // assert (h != null);
    // assert isRed(h) && !isRed(h.right()) && !isRed(h.right().left());
    h = flipColors(h);
    if (isRed(h.left().left())) {
      h = rotateRight(h);
      h = flipColors(h);
    }
    ret h;
  }

  // restore red-black tree invariant on the way back up after a deletion
  private Node<A> balance(Node<A> h) {
    // assert (h != null);
    if (isRed(h.right())) h = rotateLeft(h);
    if (isRed(h.left()) && isRed(h.left().left())) h = rotateRight(h);
    if (isRed(h.left()) && isRed(h.right())) h = flipColors(h);
    ret h;
  }

  /**
   * Returns the height of the BST (for debugging).
   * @return the height of the BST (a 1-node tree has height 0, an empty tree -1)
   */
  public int height() {
    ret height(root);
  }

  private int height(Node<A> x) {
    if (x == null) ret -1;
    ret 1 + Math.max(height(x.left()), height(x.right()));
  }

  // Note: val is cast to A unchecked; an incomparable argument fails inside compare().
  public bool contains(O val) {
    ret find(root, (A) val) != null;
  }

  // Returns the stored element that compares equal to probeVal, or null.
  public A find(A probeVal) {
    Node<A> n = find(root, probeVal);
    ret n == null ? null : n.val;
  }

  // value associated with the given key in subtree rooted at x; null if no such key
  // (currently unused within this class)
  private A get(Node<A> x, A key) {
    x = find(x, key);
    ret x == null ? null : x.val;
  }

  // Iterative BST search in the subtree rooted at x; returns the matching node or null.
  Node<A> find(Node<A> x, A key) {
    while (x != null) {
      int cmp = compare(key, x.val);
      if (cmp < 0) x = x.left();
      else if (cmp > 0) x = x.right();
      else ret x;
    }
    null;
  }

  // Verifies the LLRB invariants, printing which one fails (debug aid).
  private boolean check() {
    bool ok23 = is23();
    bool okBalanced = isBalanced();
    if (!ok23) println("Not a 2-3 tree");
    if (!okBalanced) println("Not balanced");
    ret ok23 && okBalanced;
  }

  // Does the tree have no red right links, and at most one (left)
  // red links in a row on any path?
  private boolean is23() { ret is23(root); }

  private boolean is23(Node<A> x) {
    if (x == null) true;
    if (isRed(x.right())) false;
    if (x != root && isRed(x) && isRed(x.left())) false;
    ret is23(x.left()) && is23(x.right());
  }

  // do all paths from root to leaf have same number of black edges?
  private bool isBalanced() {
    int black = 0; // number of black links on path from root to min
    Node<A> x = root;
    while (x != null) {
      if (!isRed(x)) black++;
      x = x.left();
    }
    ret isBalanced(root, black);
  }

  // does every path from the root to a leaf have the given number of black links?
  private boolean isBalanced(Node<A> x, int black) {
    if (x == null) ret black == 0;
    if (!isRed(x)) black--;
    ret isBalanced(x.left(), black) && isBalanced(x.right(), black);
  }

  public void clear() { root = null; size = 0; }

  // the node holding the smallest key in the subtree rooted at x (x must not be null)
  private Node<A> min(Node<A> x) {
    // assert x != null;
    while (x.left() != null) x = x.left();
    ret x;
  }

  // delete the smallest key in the subtree rooted at h; returns the new subtree root.
  // In a left-leaning tree a node without a left child has no right child either.
  private Node<A> deleteMin(Node<A> h) {
    if (h.left() == null)
      null;

    if (!isRed(h.left()) && !isRed(h.left().left()))
      h = moveRedLeft(h);

    h = h.setLeft(deleteMin(h.left()));
    ret balance(h);
  }

  // Returns an iterator over the elements in ascending compare() order.
  public Iterator<A> iterator() {
    ret new MyIterator;
  }

  // In-order iterator. path is a stack holding the nodes whose values are
  // still to be returned along the current leftmost spine; the top is the
  // next smallest element.
  class MyIterator extends ItIt<A> {
    new L<Node<A>> path;

    *() {
      fetch(root);
    }

    // push node and its chain of left descendants onto the stack
    void fetch(Node<A> node) {
      while (node != null) {
        path.add(node);
        node = node.left();
      }
    }

    public bool hasNext() { ret !path.isEmpty(); }

    public A next() {
      if (path.isEmpty()) throw new NoSuchElementException("no more elements");
      Node<A> node = popLast(path);
      // node's left subtree has already been returned,
      // so the following values come from its right subtree
      fetch(node.right());
      ret node.val;
    }
  }

  // Returns the smallest key in the symbol table greater than or equal to {@code key}.
  public A ceiling(A key) {
    Node<A> x = ceiling(root, key);
    ret x == null ? null : x.val;
  }

  // the smallest key in the subtree rooted at x greater than or equal to the given key
  Node<A> ceiling(Node<A> x, A key) {
    if (x == null) null;
    int cmp = compare(key, x.val);
    if (cmp == 0) ret x;
    if (cmp > 0) ret ceiling(x.right(), key);
    Node<A> t = ceiling(x.left(), key);
    if (t != null) ret t;
    else ret x;
  }

  // Returns the largest key in the symbol table less than or equal to {@code key}.
  public A floor(A key) {
    Node<A> x = floor(root, key);
    ret x == null ? null : x.val;
  }

  // the largest key in the subtree rooted at x less than or equal to the given key
  Node<A> floor(Node<A> x, A key) {
    if (x == null) null;
    int cmp = compare(key, x.val);
    if (cmp == 0) ret x;
    if (cmp < 0) ret floor(x.left(), key);
    Node<A> t = floor(x.right(), key);
    if (t != null) ret t;
    else ret x;
  }

  // Debug check: a node is a NonLeaf exactly when it has at least one child.
  void testInternalStructure(Node<A> node default root) {
    if (node == null) ret;
    assertTrue(className(node), !node.isLeaf() == node instanceof NonLeaf);
    testInternalStructure(node.left());
    testInternalStructure(node.right());
  }
}