// 2D point index: a binary space-partitioning tree that alternately
// splits along X or Y (whichever divides a leaf's points more evenly).
sclass PtTree {
  // Root node; starts as an empty leaf and is replaced by a Split once it overflows.
  Node root = new Leaf(null);

  // Common base of Leaf and Split. parent is null for the root.
  abstract class Node {
    Node parent;

    // Adds p to this subtree. Returns false if the target leaf already held p.
    abstract bool add(Pt p);

    // Appends to out every point of this subtree lying inside r,
    // where r covers [x, x+w) × [y, y+h).
    abstract void collectPointsIn(Rect r, L<Pt> out);

    // Replaces direct child 'from' with 'to'.
    abstract void replaceChild(Node from, Node to);

    // Adds each point of l via add() (fOr is JavaX's null-tolerant for-each).
    void addAll(Iterable<Pt> l) { fOr (Pt p : l) add(p); }
  }

  // Leaf node: holds points directly until it exceeds maxPointsPerNode,
  // at which point it replaces itself with a Split.
  class Leaf extends Node {
    Cl<Pt> points; // null until the first point arrives

    *(Node *parent) {}

    bool add(Pt p) {
      if (main contains(points, p)) ret false;
      if (points == null) points = new L;
      points.add(p);
      possiblySplit();
      ret true;
    }

    // Splits this leaf if it holds more than maxPointsPerNode points.
    void possiblySplit {
      if (l(points) > maxPointsPerNode) split();
    }

    // Reports only the points that actually lie inside r.
    // long arithmetic guards against overflow of x+w / y+h.
    void collectPointsIn(Rect r, L<Pt> out) {
      fOr (Pt p : points)
        if (p.x >= r.x && p.x < (long) r.x + r.w
          && p.y >= r.y && p.y < (long) r.y + r.h)
          out.add(p);
    }

    // Candidate split along one axis: 'count' points have coordinate < splitPoint.
    record noeq ProposedSplit(int dimension, int splitPoint, int count) {
      // Distance from a perfect half/half division of this leaf's points.
      int error() { ret abs(count-half(l(points))); }
    }

    // Replaces this leaf with a Split on the axis that divides its points
    // most evenly, then recursively splits the new children if still too big.
    Split split() {
      ProposedSplit xSplit = checkSplit(0);
      ProposedSplit ySplit = checkSplit(1);
      var best = xSplit.error() < ySplit.error() ? xSplit : ySplit;

      new Split split;
      split.dimension = (byte) best.dimension;
      split.splitPoint = best.splitPoint;

      // Points with coordinate >= splitPoint go to b, the rest to a.
      IPred<Pt> pred = p -> ptCoord(split.dimension, p) >= split.splitPoint;
      Leaf a = new Leaf(split);
      a.points = antiFilter(pred, points);
      Leaf b = new Leaf(split);
      b.points = filter(pred, points);

      // Attach both children BEFORE they get a chance to split themselves:
      // a child's replaceMeWith() relies on split.replaceChild() finding it
      // in split.a / split.b, otherwise its sub-split would be lost.
      split.a = a;
      split.b = b;
      replaceMeWith(split);

      a.possiblySplit();
      b.possiblySplit();
      ret split;
    }

    // Puts node into this leaf's place in the tree (parent's child slot, or root).
    <A extends Node> A replaceMeWith(A node) {
      node.parent = parent;
      if (parent != null)
        parent.replaceChild(this, node);
      else if (root == this)
        root = node;
      ret node;
    }

    // Finds the split point along 'dimension' whose "below" count is closest
    // to half of the points. Split points are coordinate+1, so each candidate
    // keeps all points sharing a coordinate on the same side.
    ProposedSplit checkSplit(int dimension) {
      L<Pt> lx = sortedBy(points, p -> ptCoord(dimension, p));
      new ProposedSplit ps;
      ps.dimension = dimension;
      int n = l(points), half = n/2;
      int i = 0; // number of points with coordinate < splitPoint
      int splitPoint = Int.MIN_VALUE;
      while true {
        int lastSplitPoint = splitPoint;
        // Count belonging to lastSplitPoint; must be taken before i advances.
        int lastI = i;
        splitPoint = ptCoord(lx.get(i), dimension)+1;
        while (i < n && ptCoord(lx.get(i), dimension) < splitPoint) ++i;
        if (i >= half) {
          if (abs(lastI-half) < abs(i-half)) {
            ps.splitPoint = lastSplitPoint;
            ps.count = lastI;
          } else {
            ps.splitPoint = splitPoint;
            ps.count = i;
          }
          ret ps;
        }
      }
    }

    // Leaves have no children; replaceChild is never called on them.
    void replaceChild(Node from, Node to) { unimplemented(); }
  }

  // Inner node: points whose coordinate along 'dimension' is < splitPoint
  // live under a, all others under b.
  class Split extends Node {
    byte dimension; // 0 for X, 1 for Y
    int splitPoint;
    Node a, b;

    bool add(Pt p) {
      ret (ptCoord(p, dimension) >= splitPoint ? b : a).add(p);
    }

    // Only descends into children whose half-plane intersects r.
    void collectPointsIn(Rect r, L<Pt> out) {
      // Extent of r along the split axis: [lo, hi)
      long lo = dimension == 0 ? r.x : r.y;
      long hi = lo + (dimension == 0 ? r.w : r.h);
      if (lo < splitPoint) a.collectPointsIn(r, out);
      if (hi > splitPoint) b.collectPointsIn(r, out);
    }

    void replaceChild(Node from, Node to) {
      if (a == from) a = to;
      if (b == from) b = to;
    }
  }

  int maxPointsPerNode = 4;

  // Adds p; returns false if it was already present.
  bool add(Pt p) { ret root.add(p); }

  // All stored points inside r (r covers [x, x+w) × [y, y+h)).
  L<Pt> pointsIn(Rect r) {
    new L<Pt> out;
    root.collectPointsIn(r, out);
    ret out;
  }

  // True if p is stored in the tree.
  bool contains(Pt p) {
    ret nempty(pointsIn(rect(p.x, p.y, 1, 1)));
  }

  // Use this to build a PtTree from many points at once.
  // It builds a balanced tree: every split is chosen near the median.
  // Assumes points contains no duplicates.
  static PtTree fromPointSet(Iterable<Pt> points) {
    new PtTree tree;
    Leaf root = cast tree.root;
    root.points = cloneList(points);
    root.possiblySplit();
    ret tree;
  }
}