C++ red-black tree: color-violation fix-up after insertion doesn't work properly

454 views

I'm implementing a Red Black Tree in C++ and I'm stuck on fixing color violations after insertion.

My left and right rotations seem to work fine, but the colors in the right branches of the tree never get fixed correctly. I think I covered all the cases in my `fixViolation(Node* n)` method, but maybe I'm wrong.

I would also appreciate any other advice or tips on my code. (Pastebin link to my code.)

My code:

    #include "pch.h"
    #include <iostream>
    #include <random>
    #include <string>
    #include <functional>
    using namespace std;

// Colour of a red-black tree node.
// The enumerator values are spelled out: Black is 0 and Red is 1.
enum Color
{
    Black = 0,
    Red = 1
};

// Forward declaration so that Node<T> can name RBTree<T> as a friend.
template <class T>
class RBTree;

template <typename T>
class Node
{
    // The tree manipulates the links and colour of its nodes directly.
    friend class RBTree<T>;

    // Value-initialised so a fresh node never holds an indeterminate value
    // (the original left scalar T uninitialised until the caller assigned it).
    T data{};
    Color color = Red;             // newly created nodes start red
    Node* parent = nullptr;        // nullptr for the root
    Node* leftChild = nullptr;     // nullptr stands for a NIL leaf
    Node* rightChild = nullptr;    // nullptr stands for a NIL leaf
public:
    Node() = default;

    /// Prints the colour of node `x` to std::cout as "black" or "red".
    /// @param x node whose colour to print. A nullptr is treated as a NIL
    ///          leaf and printed as "black" instead of being dereferenced.
    void gibColor(const Node<T>* x) const
    {
        const bool isBlack = (x == nullptr) || (x->color == Black);
        std::cout << (isBlack ? "black" : "red");
    }
};

template <typename T>
class RBTree
{
    Node<T>* root;
    int size;
public:
    RBTree()
    {
        root = nullptr;
        size = 1;
    }

    void leftRotate(Node<T>*child, Node<T>*parent)
    {
        child = parent->rightChild;
        parent->rightChild = child->leftChild;
        if (child->leftChild != nullptr)
        {
            child->leftChild->parent = parent;
        }

        child->parent = parent->parent;
        if (parent->parent == nullptr)
        {
            root = child;
        }
        else if (parent == parent->parent->leftChild)
        {
            parent->parent->leftChild = child;
        }
        else {
            parent->parent->rightChild = child;
        }

        child->leftChild = parent;
        parent->parent = child;
    }
    void rightRotate(Node<T>*child, Node<T>*parent)
    {
        child = parent->leftChild;
        parent->leftChild = child->rightChild;
        if (child->rightChild != nullptr)
        {
            child->rightChild->parent = parent;
        }
        child->parent = parent->parent;
        if (parent->parent == nullptr)
        {
            root = child;
        }
        else if (parent == parent->parent->rightChild)
        {
            parent->parent->rightChild = child;
        }
        else
        {
            parent->parent->leftChild = child;
        }

        child->rightChild = parent;
        parent->parent = child;
        }
        void fixViolation(Node<T>*n)
        {
            Node<T>*grandparent;
            Node<T>*parent;
            //Node<T>*uncle;
            parent = n->parent;
            while (parent != nullptr&& parent->color == Red)
            {
            Node<T>*uncle;
            grandparent = n->parent->parent;
            if (grandparent->leftChild == parent)
            {
                uncle = grandparent->rightChild;
                if (uncle != nullptr&&uncle->color == Red)
                {
                    parent->color = Black;
                    uncle->color = Black;
                    grandparent->color = Red;
                    n = grandparent;
                    parent = n->parent;
                }
                else
                {
                    if (parent->rightChild == n)
                    {
                        n = parent;
                        leftRotate(n->parent, n);
                    }

                    parent->color = Black;
                    grandparent->color = Red;
                    rightRotate(parent, grandparent);
                }
            }
            else
            {
                uncle = grandparent->leftChild;
                if (uncle != nullptr&&uncle->color == Red)
                {
                    uncle->color = Black;
                    parent->color = Black;
                    grandparent->color = Red;
                    n = grandparent;
                    parent = n->parent;
                }
                else
                {
                    if (parent->leftChild == n)
                    {
                        n = parent;
                        rightRotate(n->parent, n);
                    }

                    parent->color = Black;
                    grandparent->color = Red;
                    leftRotate(parent, grandparent);
                }
            }
        }
        root->color = Black;
    }


    void addElement(T el)
    {
        Node<T>*n = new Node<T>();
        n->data = el;
        n->leftChild = nullptr;
        n->rightChild = nullptr;
        n->color = Red;
        Node<T>*temp = this->root;
        Node<T>*y = nullptr;
        if (root == nullptr)
        {
            n->color = Black;
            root = n;
        }
        else
        {
            while (temp != nullptr)
            {
                y = temp;
                if (temp->data < n->data)
                {
                    temp = temp->rightChild;
                }
                else
                {
                    temp = temp->leftChild;
                }
            }
            n->parent = y;
            if (y->data == n->data)
            {
                cout << "Duplikaty won!" << endl;
                return;
            }
            if (y->data < n->data)
            {
                y->rightChild = n;
                size = size + 1;
            }
            else
            {
                y->leftChild = n;
                size = size + 1;
            }
            //InsertFixUp(n);
            fixViolation(n);
        }

    }
    void print(Node<T>*x)
    {

        //cout << "size: " << size << endl;
        if (x == nullptr)
        {
            return;
        }
        if (x->parent == nullptr)
        {
            cout << "(" << x->data << ")";
            cout << "[" << "kolor:";
            x->gibColor(x);
            cout << ", parent: NULL," << " l.child: ";
            if (x->leftChild == nullptr)
            {
                cout << "NIL";
            }
            else
                cout << x->leftChild->data;
            cout << ", r.child: ";
            if (x->rightChild == nullptr)
            {
                cout << "NIL";
            }
            else
                cout << x->rightChild->data;
            cout << "]";
            cout << "-root " << endl;
            //rodzic, l.dziecko, p.dziecko

        }
        else if (x->parent->leftChild == x)
        {
            //int c = x->gibColor(x);
            cout << "(" << x->data << ")";
            cout << "[" << "kolor:";
            x->gibColor(x);
            cout << ", parent: " << x->parent->data << ", l.child: ";
            if (x->leftChild == nullptr)
            {
                cout << "NIL";
            }
            else
                cout << x->leftChild->data;
            cout << ", r.child: ";
            if (x->rightChild == nullptr)
            {
                cout << "NIL" << "]" << endl;
            }
            else
                cout << x->rightChild->data << "]" << endl;
        }
        else
        {
            cout << "(" << x->data << ")";
            cout << "[" << "kolor:";
            x->gibColor(x);
            cout << ", parent: " << x->parent->data << ", l.child: ";
            if (x->leftChild == nullptr)
            {
                cout << "NIL";
            }
            else
                cout << x->leftChild->data;
            cout << ", r.child: ";
            if (x->rightChild == nullptr)
            {
                cout << "NIL" << "]" << endl;;
            }
            else
                cout << x->rightChild->data << "]" << endl;
        }    
        print(x->leftChild);
        print(x->rightChild);

    }
    void printTree()
    {

    if (root == nullptr)
    {
        cout << "Empty tree!" << endl;

    }
    else
        print(root);
    }
};
/// Returns a pseudo-random int uniformly distributed in [0, 11000000].
/// The engine is seeded once with the fixed seed 11, so the sequence is
/// reproducible from run to run, and successive calls yield successive draws.
int randomInt()
{
    // Function-local statics: the original rebuilt and reseeded the engine on
    // every call, so it returned the very same number each time.
    static std::default_random_engine generator{ 11 };
    static std::uniform_int_distribution<int> rozklad{ 0, 11000000 };
    return rozklad(generator);
}

int main()
{
    // A local object instead of `new`: the original heap-allocated tree was
    // never deleted. This object is destroyed automatically when main returns.
    RBTree<int> d1;
    d1.addElement(55);
    d1.addElement(69);
    d1.addElement(62);
    d1.addElement(71);
    d1.addElement(70);
    d1.addElement(14);
    d1.printTree();
    return 0;
}
0

There are no answers yet.