Implementing Red-Black Trees with property-based testing

2020-03-20 • edited 2022-09-14

In languages like OCaml and Haskell, the type system often helps you along the way: it catches errors and warns about non-exhaustive pattern matches. Ultimately, it maintains invariants that minimize the chance of ill-formed behavior in the running program.

In Python we don't have that luxury. However, with frameworks such as Hypothesis and the like, we can use the test framework to help us discover and design our software.

We will explore the red-black tree as an example: a well-known data structure that keeps itself balanced as it is modified.

The motivation behind a self-balancing binary search tree is that if the inputs are not random enough (for example, keys inserted in sorted order), a plain BST may degenerate into a linked list, and searches degrade from O(log n) to O(n).

e.g.

 1 
  \
   2
    \ 
     3
      \
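
To make the degenerate case concrete, here is a small sketch that inserts the same 1,000 keys into a plain, unbalanced BST, once in sorted order and once shuffled, and compares the resulting heights. The names `PlainNode`, `bst_insert`, `height`, and `build` are hypothetical helpers for this illustration, not code from the article's repo.

```python
# A minimal sketch: a plain, unbalanced BST, used only to show how its
# height depends on insertion order. All names here are hypothetical.
import random


class PlainNode:
    def __init__(self, key):
        self.key = key
        self.left = None
        self.right = None


def bst_insert(root, key):
    """Insert key iteratively (no recursion, so long chains are safe); return the root."""
    node = PlainNode(key)
    if root is None:
        return node
    current = root
    while True:
        if key < current.key:
            if current.left is None:
                current.left = node
                return root
            current = current.left
        else:
            if current.right is None:
                current.right = node
                return root
            current = current.right


def height(root):
    """Count the nodes on the longest root-to-leaf path, using an explicit stack."""
    best = 0
    stack = [(root, 1)] if root is not None else []
    while stack:
        node, depth = stack.pop()
        best = max(best, depth)
        for child in (node.left, node.right):
            if child is not None:
                stack.append((child, depth + 1))
    return best


def build(keys):
    root = None
    for key in keys:
        root = bst_insert(root, key)
    return root


n = 1000
sorted_height = height(build(range(n)))

shuffled_keys = list(range(n))
random.seed(0)
random.shuffle(shuffled_keys)
shuffled_height = height(build(shuffled_keys))

print(sorted_height)    # 1000: the tree is a single right-leaning chain
print(shuffled_height)  # far smaller for a random order
```

With sorted input every key becomes the right child of the previous one, so the height equals the number of keys and a search for the largest key visits every node.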

Searching a red-black BST works exactly like searching a plain BST. What keeps a red-black BST balanced is the extra fix-up work done on insertion and deletion to maintain its invariants.

For a proof that a red-black tree stays balanced, see [3].
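
For reference, the bound such proofs arrive at is the standard one (CLRS, Lemma 13.1). The sketch below uses these conventions, which are assumptions of this summary rather than notation from the article: n is the number of non-nil nodes, h is the number of edges on the longest path from the root down to a nil leaf, and bh(x) is the number of black nodes on any path from x (exclusive) down to a nil leaf (inclusive).

```latex
\begin{align*}
  \text{(i)}\;   & \text{the subtree rooted at } x \text{ has at least } 2^{\mathrm{bh}(x)} - 1
                   \text{ non-nil nodes (induction on height)} \\
  \text{(ii)}\;  & \text{a red node has only black children, so at least half of the nodes below the root} \\
                 & \text{on any path are black: } \mathrm{bh}(\mathit{root}) \ge h/2 \\
  \text{(iii)}\; & n \ge 2^{\mathrm{bh}(\mathit{root})} - 1 \ge 2^{h/2} - 1
                   \;\Longrightarrow\; h \le 2\log_2(n + 1)
\end{align*}
```

Since a search walks at most one root-to-leaf path, it stays O(log n) no matter the insertion order.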

You can find all the code in the repo.

Laying the foundation

To start, we'll leave the insert function exactly like a plain binary search tree's insert.

from enum import Enum

class Color(Enum):
    RED = 0
    BLACK = 1

    def __str__(self):
        #: R | B
        return str(self.name)[0]


RED = Color.RED
BLACK = Color.BLACK


class Node:
    def __init__(self, key):
        self.key = key
        self.color = RED
        self.left = None
        self.right = None
        self.parent = None


class RBTree:
    def __init__(self):
        self.root = None

    def get_color(self, node):
        #: nils are black
        if node is None:
            return BLACK
        else:
            return node.color

    def in_order_walk(self, f):
        t = self.root

        def aux(t):
            if t is not None:
                aux(t.left)
                f(t)
                aux(t.right)

        aux(t)


    def search(self, key):
        current = self.root
        while current is not None and key != current.key:
            if key < current.key:
                current = current.left
            else:
                current = current.right
        return current


    def insert(self, key):
        new_node = Node(key)
        current = self.root

        if current is None:
            #: the first node becomes the (black) root
            new_node.color = BLACK
            self.root = new_node
            return new_node

        while True:
            if key == current.key:
                #: duplicate key: keep the existing node
                return current

            if key > current.key:
                if current.right is None:
                    current.right = new_node
                    new_node.parent = current
                    break
                current = current.right
            else:
                if current.left is None:
                    current.left = new_node
                    new_node.parent = current
                    break
                current = current.left

        return new_node


Adding representation

To see where the tree goes wrong, we'd better have a representation function that visualizes it.

The tree-printing code is adapted from MIT 6.006's sample code for binary search trees.

class RBTree:
    ...

    def __str__(self):
        """
        Adapted from
        MIT 6.006 reading section, binary search tree example
        """

        if self.root is None:
            return "<B:empty>"

        def aux(node):
            if node is None:
                return [], 0, 0
            ...
            #: More details are in the repo.
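
If you just need something quick while debugging, a sideways printer is often enough. The sketch below is not the MIT 6.006 layout used above: `render_sideways` is a hypothetical helper, and `Color`/`Node` are minimal stand-ins for the article's classes. It prints one node per line, indented by depth, with the right subtree on top, using the same `<C:key>` labels.

```python
# A hypothetical, simpler alternative to the MIT-style layout: print the tree
# sideways (right subtree on top), one node per line, indented by depth.
from enum import Enum


class Color(Enum):
    RED = 0
    BLACK = 1


class Node:
    def __init__(self, key, color=Color.RED, left=None, right=None):
        self.key = key
        self.color = color
        self.left = left
        self.right = right


def render_sideways(node, depth=0):
    """Return a list of lines; each node is shown as <C:key> with C in {R, B}."""
    if node is None:
        return []
    label = "    " * depth + "<%s:%s>" % (node.color.name[0], node.key)
    # right subtree first, so reading top to bottom gives descending keys
    return (
        render_sideways(node.right, depth + 1)
        + [label]
        + render_sideways(node.left, depth + 1)
    )


root = Node(0, Color.BLACK, left=Node(-1), right=Node(1, right=Node(2)))
lines = render_sideways(root)
print("\n".join(lines))
# prints:
#         <R:2>
#     <R:1>
# <B:0>
#     <R:-1>
```

Tilting your head to the left turns the output back into the usual top-down picture.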

Now, let's enforce the invariants

The invariants of a red-black binary search tree:

  • The nodes are either black or red.
  • The root of the red-black tree is black, and the leaves (nil) are black.
  • The children of a red node are black.
  • Every path from a given node to any of its descendant NIL nodes goes through the same number of black nodes.

import unittest

import hypothesis.strategies as st
from hypothesis.stateful import RuleBasedStateMachine, rule, precondition, invariant

from rbt import RBTree, Color


class RBTreeMachine(RuleBasedStateMachine):
    """
    Invariants:
        0. The nodes are either black or red.
        1. The root of the red-black tree is black, and the leaves (nil) are black.
        2. The children of a red node are black.
        3. Every path from a given node to any of its
           descendant NIL nodes goes through the same number of black nodes.
    """

    def __init__(self):
        super().__init__()
        self.tree = RBTree()

    @rule(key=st.integers())
    def insert(self, key):
        self.tree.insert(key)
        assert self.tree.search(key) is not None

    @precondition(lambda self: self.tree.root is not None)
    @invariant()
    def node_is_either_red_or_black(self):
        #: Invariant 0
        def f(node):
            assert node.color in Color, (
                "nodes should either be red or black \n %s" % self.tree
            )

        self.tree.in_order_walk(f)

    @precondition(lambda self: self.tree.root is not None)
    @invariant()
    def root_is_black(self):
        #: Invariant 1
        assert self.tree.root.color == Color.BLACK, (
            "root has to be black\n %s" % self.tree
        )

    @precondition(lambda self: self.tree.root is not None)
    @invariant()
    def red_nodes_have_black_children(self):
        #: Invariant 2
        def f(node):
            if node.color is Color.RED:
                if node.left is not None:
                    assert node.left.color is Color.BLACK, (
                        "red nodes can only have black children(left)\n %s" % self.tree
                    )
                if node.right is not None:
                    assert node.right.color is Color.BLACK, (
                        "red nodes can only have black children(right)\n %s" % self.tree
                    )

        self.tree.in_order_walk(f)

    @precondition(lambda self: self.tree.root is not None)
    @invariant()
    def black_heights_are_the_same(self):
        #: Invariant 3
        def aux(t):
            if t is None:
                return 1
            left, right = aux(t.left), aux(t.right)
            assert left == right, (
                "left's black height should equal right's \n %s" % self.tree
            )
            return left + (1 if t.color is Color.BLACK else 0)

        bh = aux(self.tree.root)
        assert bh > 0

TestTrees = RBTreeMachine.TestCase

if __name__ == "__main__":
    unittest.main()

Developing with failures

Let's run the tests and see what needs fixing.

python tests.py

Hypothesis reports falsifying examples that violate our invariants:

AssertionError: red nodes can only have black children
 <B:-1>
/    \
    <R:1>
    /   \
  <R:0>
  /   \

AssertionError: red nodes can only have black children

 <B:0>
/   \
   <R:1>
   /   \
      <R:2>
      /   \

To solve this problem, we first need to introduce the concept of tree rotation.

     Q                                                P   
   /   \          right rotation                    /   \
  P     C         -------------->                  A     Q   
 / \              <-------------                        / \
A   B             left rotation                        B   C

    (I)                                               (II)

A rotation changes a tree's structure without violating its ordering: an in-order walk of both (I) and (II) visits the same sequence, A -> P -> B -> Q -> C.

A rotation only reassigns a constant number of pointers, so it takes O(1) time.

class RBTree:
    ...  #: continued

    def left_rotate(self, node: Node):
        #: node's right child moves up to take node's place
        right_child = node.right

        #: move right_child's left subtree into node's right subtree
        node.right = right_child.left
        if right_child.left is not None:
            right_child.left.parent = node

        #: link right_child to node's parent
        right_child.parent = node.parent

        if node.parent is None:
            self.root = right_child
        elif node.parent.left is node:
            node.parent.left = right_child
        else:
            node.parent.right = right_child

        #: finally, node becomes right_child's left child
        right_child.left = node
        node.parent = right_child

    #: right_rotate is the mirror image: swap every left and right
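
The sketch below spells out both rotations side by side so the mirroring is explicit. It runs on minimal stand-in classes: `Tree`, `Node`, and the helpers `_replace_child`, `link`, and `in_order` are hypothetical, not the repo's code. It builds shape (I) from the diagram, rotates to (II) and back, and checks that the in-order walk never changes.

```python
# Sketch of both rotations on parent-linked nodes. `Tree`, `Node`,
# `_replace_child`, `link`, and `in_order` are hypothetical stand-ins.


class Node:
    def __init__(self, key):
        self.key = key
        self.left = None
        self.right = None
        self.parent = None


class Tree:
    def __init__(self):
        self.root = None

    def _replace_child(self, old, new):
        """Make `new` take `old`'s place under old's parent (or as the root)."""
        new.parent = old.parent
        if old.parent is None:
            self.root = new
        elif old.parent.left is old:
            old.parent.left = new
        else:
            old.parent.right = new

    def left_rotate(self, node):
        pivot = node.right           # the right child moves up
        node.right = pivot.left      # pivot's left subtree changes sides
        if pivot.left is not None:
            pivot.left.parent = node
        self._replace_child(node, pivot)
        pivot.left = node
        node.parent = pivot

    def right_rotate(self, node):
        pivot = node.left            # mirror image: the left child moves up
        node.left = pivot.right      # pivot's right subtree changes sides
        if pivot.right is not None:
            pivot.right.parent = node
        self._replace_child(node, pivot)
        pivot.right = node
        node.parent = pivot


def link(parent, left=None, right=None):
    """Attach children to `parent`, setting their parent pointers; return `parent`."""
    parent.left, parent.right = left, right
    for child in (left, right):
        if child is not None:
            child.parent = parent
    return parent


def in_order(node):
    if node is None:
        return []
    return in_order(node.left) + [node.key] + in_order(node.right)


# Shape (I) from the diagram: Q(P(A, B), C); the letters are just labels.
A, B, C, P, Q = (Node(label) for label in "ABCPQ")
tree = Tree()
tree.root = link(Q, link(P, A, B), C)

before = in_order(tree.root)
tree.right_rotate(Q)  # (I) -> (II): P becomes the root
root_after_right = tree.root.key
after_right = in_order(tree.root)
tree.left_rotate(P)   # (II) -> (I): Q is the root again
after_left = in_order(tree.root)
print(before, after_right, after_left)  # all three: ['A', 'P', 'B', 'Q', 'C']
```

Only subtree B changes parents during either rotation, which is why the in-order sequence is preserved.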

Case I

Back to the failing examples Hypothesis gave us above.

 <B:-1>                    <R:0>                 <B:0>
/    \                    /     \                /   \
    <R:0>       ->     <B:-1>  <R:1>    ->   <B:-1>  <R:1>
    /   \              /  \     /  \          / \     /  \  
       <R:1>
       /   \

The tactic: recolor the parent black and the grandparent red, then left-rotate the grandparent, so that the parent becomes the black root of this subtree.

class RBTree:
    ...
    def insert_fix(self, node):
        #: only red parents need fixing; test for the root first, because it has no parent
        while node != self.root and node.parent.color == RED:
            if node.parent == node.parent.parent.right:
                uncle = node.parent.parent.left
                if self.get_color(uncle) == BLACK:
                    node.parent.color = BLACK
                    node.parent.parent.color = RED
                    self.left_rotate(node.parent.parent)

                else:
                    #: we skip it for now
                    return

            else:
                #: skip it for now
                return

Case II

Now we re-run the tests:

python tests.py
AssertionError: red nodes can only have black children(left)
 <B:-1>
/    \
    <R:1>
    /   \
  <R:0>
  /   \

Now we have a case that differs from the one above: the node, its parent, and its grandparent form a triangle instead of a line. How do we deal with it? We right-rotate the parent, which reduces it to Case I.

 <B:-1>                           <B:-1>
/    \                             /  \
    <R:1>  (case II)      ->         <R:0>   (case I)  
    /   \                             /  \
  <R:0>                                 <R:1>             
  /   \                                  /  \               

We change insert_fix to:

class RBTree:
    ...
    def insert_fix(self, node):
        #: only red parents need fixing; test for the root first, because it has no parent
        while node != self.root and node.parent.color == RED:
            if node.parent == node.parent.parent.right:
                uncle = node.parent.parent.left

                if self.get_color(uncle) == BLACK:
                    if node == node.parent.left:
                        #: Case II
                        node = node.parent
                        self.right_rotate(node)

                    #: Case I
                    node.parent.color = BLACK
                    node.parent.parent.color = RED
                    self.left_rotate(node.parent.parent)
                else:
                    return

            else:
                #: skip it for now
                return

And we re-run our tests

AssertionError: red nodes can only have black children(right)
     <B:1>
   /     \
<R:-1>
/    \
    <R:0>
    /   \

We have a new failing case: the mirror image of Case II above. We copy the solution and swap left and right.

class RBTree:
    ...
    def insert_fix(self, node):
        #: only red parents need fixing; test for the root first, because it has no parent
        while node != self.root and node.parent.color == RED:
            if node.parent == node.parent.parent.right:
                uncle = node.parent.parent.left

                if self.get_color(uncle) == BLACK:
                    if node == node.parent.left:
                        #: Case II
                        node = node.parent
                        self.right_rotate(node)

                    #: Case I
                    node.parent.color = BLACK
                    node.parent.parent.color = RED
                    self.left_rotate(node.parent.parent)

                else:
                    return

            else:
                uncle = node.parent.parent.right
                if self.get_color(uncle) == BLACK:
                    if node == node.parent.right:
                        #: Case II (mirrored)
                        node = node.parent
                        self.left_rotate(node)
                    #: Case I (mirrored)
                    node.parent.color = BLACK
                    node.parent.parent.color = RED
                    self.right_rotate(node.parent.parent)

                else:
                    return

Case III

Then we run our tests.

AssertionError: red nodes can only have black children(left)
       <B:0>
      /    \
   <R:-1> <R:1>
   /    \ /   \
<R:-2>
/    \

This new case forms neither a line nor a triangle; instead, the node's parent and uncle are both red. The strategy is to recolor: the parent and the uncle become black, and the grandparent becomes red. This keeps the number of black nodes on every path through the grandparent unchanged.

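...
else:
    #: uncle is red, case III (the mirrored branch gets the same code)
    node.parent.color = BLACK
    uncle.color = BLACK
    node.parent.parent.color = RED
    #: the grandparent is now red and may clash with its own parent,
    #: so keep fixing from there; this is why the loop condition
    #: must test for the root before reading node.parent
    node = node.parent.parent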

We run our tests again:

AssertionError: root has to be black

You might have guessed this already: when the grandparent we recolor red is the root, the root becomes red, which violates the root-is-black invariant.

Let's add the fix.

def insert_fix(self, node):
    ...
    self.root.color = BLACK  # final fix, after the loop

Then we run the tests again

Ran 1 test in 4.117s

OK

It seems we have finished our task.
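
The snippets above are fragments, and the insert shown earlier never calls insert_fix. For reference, here is one way to put everything together in a single runnable file. It is a sketch under a few assumptions: colors are plain strings instead of the `Color` enum, the two rotations are folded into one hypothetical `rotate(node, left)` helper, `insert` calls `insert_fix` on the new node, and `check` and `in_order` are hypothetical plain-Python counterparts of the Hypothesis invariants and `in_order_walk`.

```python
# Sketch: Cases I-III assembled into one insert_fix; `check` is a hypothetical helper.
import random

RED, BLACK = "R", "B"  # plain strings stand in for the article's Color enum


class Node:
    def __init__(self, key, parent=None):
        self.key, self.color, self.parent = key, RED, parent
        self.left = self.right = None


class RBTree:
    def __init__(self):
        self.root = None

    def rotate(self, node, left):  # left=True: left_rotate(node); False: right_rotate(node)
        pivot = node.right if left else node.left  # the child that moves up
        inner = pivot.left if left else pivot.right  # the subtree that changes sides
        if left:
            node.right, pivot.left = inner, node
        else:
            node.left, pivot.right = inner, node
        if inner is not None:
            inner.parent = node
        pivot.parent = node.parent
        if node.parent is None:
            self.root = pivot
        elif node.parent.left is node:
            node.parent.left = pivot
        else:
            node.parent.right = pivot
        node.parent = pivot

    def insert(self, key):
        parent, current = None, self.root
        while current is not None:
            if key == current.key:
                return current  # duplicate keys are ignored
            parent = current
            current = current.right if key > current.key else current.left
        node = Node(key, parent)
        if parent is None:
            self.root = node
        elif key > parent.key:
            parent.right = node
        else:
            parent.left = node
        self.insert_fix(node)
        return node

    def insert_fix(self, node):
        # test for the root first: Case III can move `node` all the way up to it
        while node is not self.root and node.parent.color == RED:
            parent, grand = node.parent, node.parent.parent
            on_right = parent is grand.right
            uncle = grand.left if on_right else grand.right
            if uncle is not None and uncle.color == RED:
                # Case III: recolor, then continue from the grandparent
                parent.color = uncle.color = BLACK
                grand.color = RED
                node = grand
                continue
            if node is (parent.left if on_right else parent.right):
                # Case II (triangle): rotate the parent to reduce it to Case I
                node = parent
                self.rotate(node, left=not on_right)
            # Case I (line): recolor, then rotate the grandparent
            node.parent.color = BLACK
            grand.color = RED
            self.rotate(grand, left=on_right)
        self.root.color = BLACK


def check(node):
    """Return the black height (nil counts as 1); assert invariants 2 and 3."""
    if node is None:
        return 1
    if node.color == RED:
        assert all(c is None or c.color == BLACK for c in (node.left, node.right))
    left, right = check(node.left), check(node.right)
    assert left == right, "black heights differ"
    return left + (1 if node.color == BLACK else 0)


def in_order(node):
    return [] if node is None else in_order(node.left) + [node.key] + in_order(node.right)


random.seed(1)
ascending, shuffled = RBTree(), RBTree()
for k in range(1000):
    ascending.insert(k)
for k in random.sample(range(10**6), 1000):
    shuffled.insert(k)
for tree in (ascending, shuffled):
    check(tree.root)  # raises AssertionError if an invariant is broken
print(in_order(ascending.root) == list(range(1000)))  # True
```

Inserting keys in ascending order, which would turn a plain BST into a list, now yields a tree that passes every invariant check, and the in-order walk still returns the keys sorted.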

Conclusion

That's pretty much it. The key takeaway is that translating invariants into code is relatively easy, and often declarative. You can then let the test framework find the cases you need to consider while implementing your solution.

There are quite a few places that could be improved; for example, you can augment the tree with extra information on each node to save some computation. If you are interested, you can also try implementing deletion.

References and further reading

  1. https://en.wikipedia.org/wiki/Red%E2%80%93black_tree
  2. https://hypothesis.works/articles/rule-based-stateful-testing/
  3. https://www.codesdope.com/course/data-structures-red-black-trees/
  4. https://brilliant.org/wiki/red-black-tree/