Lost on splay tree implementation

56 Views Asked by At

I'm having trouble debugging my code. I believe the error lies within the Insert function and Rotate but I cannot figure it out.. Please check my code

#pragma once
#include <cassert>
#include <stdexcept>

class splay_tree {
public:
    struct node 
    {
        node(int key, node* left, node* right, node* parent) :
            key(key), left(left), right(right), parent(parent) { }

        int key;
        node* left;
        node* right;
        node* parent;
    };

    ~splay_tree() 
    {
        clear();
    }

    node* root() const 
    {
        return rt;
    }

    int size() const 
    {
        return size(rt);
    }

    bool empty() const
    {
        return rt == nullptr;
    }

    static void rotate(node* c, node* p) 
    {
        assert(c != nullptr and p != nullptr);
        assert(c == p->left or c == p->right);
        assert(c->parent == p);

        node* g = p->parent;

        if(g != nullptr) 
        {
            if (g->left == p)
                g->left = c;
            else
                g->right = c;
        } 
        else
            c->parent = nullptr;

        p->parent = c;

        if(c == p->left) 
        {
            p->left = c->right;
            if(c->right != nullptr)
                c->right->parent = p;
            c->right = p;
        }
        else{
            p->right = c->left;
            if(c->left != nullptr)
                c->left->parent = p;
            c->left = p;
        }

        if(g != nullptr)
        {
            if(p == g->left)
                g->left = c;
            else
                g->right = c;
        }
    }

    static node* splay(node* n) 
    {
        assert(n != nullptr);

        while (n->parent != nullptr) 
        {
            node* p = n->parent;
            node* g = p->parent;

            if(p == nullptr){
                rotate(n,p);
            }
            else if((p == g->left and n == p->left) or (p == g->right and n == p->right)){
                rotate(p, g);
                rotate(n, p);
            }
            else if((p == g->left and n == p->right) or (p == g->right and n == p ->left)){
                rotate(n, p);
                rotate(n, g);
            }
        }
        return n;
    }

    node* find(int k)
    {
        node* n = rt;
        node* last = nullptr;

        while(n != nullptr)
        {
            last = n;

            if(k < n->key)
                n = n->left;
            else if(k > n->key)
                n = n->right;
            else{
                return splay(n);  
            }
        }

        return splay(last);  
    }

    node* insert(int k)
    {
        if(rt == nullptr){
            rt = new node(k, nullptr, nullptr, nullptr);
            return rt;
        }

        node* n = rt;
        node* last = nullptr;

        while(n != nullptr)
        {
            if(k == n->key)
            {
                delete n; 
                return rt;
            }
            last = n;
            if(k < n->key)
                n = n->left;
            else if(k > n->key)
                n = n->right;
        }

        if(k < last->key)
        {
            last->left = new node(k, nullptr, nullptr, last);
            return splay(last->left);
        }
        else{
            last->right = new node(k, nullptr, nullptr, last);
            return splay(last->right);
        }
    }

    node* remove(int k)
    {
    
        node* n = find(k);

        if(n != nullptr and n->key == k)
        {
            node* leftSubtree = n->left;
            node* rightSubtree = n->right;

            if(leftSubtree != nullptr){
                leftSubtree->parent = nullptr;
            }

            if(rightSubtree != nullptr){
                rightSubtree->parent = nullptr;
            }

            delete n;

            if(leftSubtree == nullptr){
                return rightSubtree;
            }
            else if(rightSubtree == nullptr){
                return leftSubtree;
            }
            else{
                node* maxNode = findMax(leftSubtree);
                maxNode->right = rightSubtree;
                rightSubtree->parent = maxNode;

                return maxNode;
            }
        }   
        splay(n);
        return rt;
    }

    void set_root(node* n) 
    {
        rt = n;
    }

    void clear() 
    {
        clear(rt);
        rt = nullptr;
    }

private:
    void clear(node* n) 
    {
        if (n != nullptr) {
            clear(n->left);
            clear(n->right);
            delete n;
        }
    }

    int size(node* n) const 
    {
        if (n == nullptr) {
            return 0;
        }
        return 1 + size(n->left) + size(n->right);
    }

    node* findMax(node* n) const 
    {
        while (n->right != nullptr) {
            n = n->right;
        }
        return n;
    }

    node* rt = nullptr;
};



When I compile the test runner and run it this is my result: ---- Beginning tree tests ---- Testing rotation...Result of child-parent rotation (with subtrees) is incorrect: Expected:

---- Beginning tree tests ----
Testing rotation...Result of child-parent rotation (with subtrees) is incorrect:
Expected:
--- Tree structure ---
 10
 ├─(null)
 └─  2 [p = 10]
    ├─  5 [p = 2]
    │  ├─  7 [p = 5]
    │  └─  3 [p = 5]
    └─  1 [p = 2]
Actual result:
--- Tree structure ---
 10
 ├─  2 [p = 5]
 │  ├─  5 [p = 2]
 │  │  ├─  7 [p = 5]
 │  │  └─  3 [p = 5]
 │  └─  1 [p = 2]
 └─CYCLE (2)

Please help me fix this problem.

Assignment Instructions

Test Runner

0

There are 0 best solutions below