Self-Balancing Binary Search Tree Implementations

Self-Balancing Tree Structures

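Self-balancing binary search trees keep search, insertion, and deletion efficient by preventing the tree from degenerating into a linked list as elements are inserted and deleted. The Size Balanced Tree bounds its height by O(log n) in the worst case. The Treap achieves O(log n) expected height through random priorities. The Splay tree does not bound its height at all, but guarantees O(log n) amortized cost per operation. Below are implementations of these three common variants: SBT, Treap, and Splay trees.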

Size Balanced Tree (SBT)

#include <iostream>
#include <cstdlib>
// Capacity of the static node pool; each insert() consumes one slot.
constexpr int MAX_NODES = 100000;

// One node of the Size Balanced Tree.
struct SBTNode {
    int value;        // stored key
    int size;         // number of nodes in the subtree rooted here (including this one)
    SBTNode* left;    // subtree of smaller values (nullptr when empty)
    SBTNode* right;   // subtree of values >= this one (nullptr when empty)
};

// Fixed-capacity node storage. insert() hands out slots sequentially by
// advancing nodeCount; nothing shown here ever returns a slot to the pool.
SBTNode nodePool[MAX_NODES];
int nodeCount{0};
// Root of the tree; nullptr means the tree is empty.
SBTNode* root{nullptr};

// Left rotation: x's right child becomes the subtree root and x becomes its
// left child. The subtree's total size is unchanged, so the new root inherits
// x's old size while x's size is recomputed from its new children.
// Requires x->right to be non-null.
void leftRotate(SBTNode* &x) {
    SBTNode *pivot = x->right;
    const int total = x->size;

    x->right = pivot->left;
    pivot->left = x;

    int remaining = 1;
    if (x->left) remaining += x->left->size;
    if (x->right) remaining += x->right->size;
    x->size = remaining;
    pivot->size = total;

    x = pivot;
}

// Right rotation: x's left child becomes the subtree root and x becomes its
// right child. The subtree's total size is unchanged, so the new root inherits
// x's old size while x's size is recomputed from its new children.
// Requires x->left to be non-null.
void rightRotate(SBTNode* &x) {
    SBTNode *pivot = x->left;
    const int total = x->size;

    x->left = pivot->right;
    pivot->right = x;

    int remaining = 1;
    if (x->left) remaining += x->left->size;
    if (x->right) remaining += x->right->size;
    x->size = remaining;
    pivot->size = total;

    x = pivot;
}

// Restores the Size Balanced Tree property at t after an insertion.
//   flag == false: the left side may have grown. Rotate if a grandchild on the
//                  left (t->left->left or t->left->right) is larger than t->right.
//   flag == true : the mirror-image check on the right side.
// After any rotation, the two children AND t itself are re-maintained. These are
// the four recursive calls of Chen Qifeng's SBT maintain procedure. The last two
// calls were previously missing, so the size-balance (and thus height) guarantee
// did not hold.
void maintain(SBTNode* &t, bool flag) {
    if (!t) return;
    // Size of a possibly-empty subtree.
    auto sz = [](const SBTNode* n) { return n ? n->size : 0; };

    if (!flag) {
        if (t->left && sz(t->left->left) > sz(t->right)) {
            rightRotate(t);
        } else if (t->left && sz(t->left->right) > sz(t->right)) {
            leftRotate(t->left);
            rightRotate(t);
        } else {
            return;  // already balanced on this side
        }
    } else {
        if (t->right && sz(t->right->right) > sz(t->left)) {
            leftRotate(t);
        } else if (t->right && sz(t->right->left) > sz(t->left)) {
            rightRotate(t->right);
            leftRotate(t);
        } else {
            return;  // already balanced on this side
        }
    }
    maintain(t->left, false);
    maintain(t->right, true);
    maintain(t, true);
    maintain(t, false);
}

// Inserts value into the subtree t (duplicates allowed; equal values go right)
// and rebalances on the way back up. New nodes are taken from nodePool. If the
// pool is exhausted, the program reports it and aborts instead of writing past
// the end of the array.
void insert(SBTNode* &t, int value) {
    if (!t) {
        if (nodeCount >= MAX_NODES) {
            std::cerr << "SBT node pool exhausted (MAX_NODES = " << MAX_NODES << ")\n";
            std::abort();
        }
        t = &nodePool[nodeCount++];
        t->value = value;
        t->size = 1;
        t->left = t->right = nullptr;
        return;
    }
    t->size++;
    if (value < t->value) {
        insert(t->left, value);
    } else {
        insert(t->right, value);
    }
    // t has not been rotated yet, so this comparison still tells which side grew.
    maintain(t, value >= t->value);
}

// Detaches the node holding the largest value in the non-empty subtree t,
// decrements the sizes along that path, and returns the detached value.
static int sbtRemoveMax(SBTNode* &t) {
    --t->size;
    if (!t->right) {
        const int result = t->value;
        t = t->left;
        return result;
    }
    return sbtRemoveMax(t->right);
}

// Removes one node whose value equals `value` from t. Returns whether such a
// node existed. Sizes are decremented only on the path of a successful removal.
static bool sbtRemoveValue(SBTNode* &t, int value) {
    if (!t) return false;
    if (value < t->value) {
        if (!sbtRemoveValue(t->left, value)) return false;
    } else if (t->value < value) {
        if (!sbtRemoveValue(t->right, value)) return false;
    } else if (!t->left || !t->right) {
        // Zero or one child: splice the child (possibly nullptr) into t's place.
        // The child's own size is already correct.
        t = t->left ? t->left : t->right;
        return true;
    } else {
        // Two children: replace t's value with its in-order predecessor.
        t->value = sbtRemoveMax(t->left);
    }
    --t->size;
    return true;
}

// Removes one occurrence of `value` from the tree rooted at t and returns `value`.
// If `value` is not stored (or the tree is empty) the tree is left unchanged.
// Previously a missing value caused a neighbouring element to be deleted, and an
// empty tree was dereferenced. The removed node's slot in nodePool is not reclaimed.
int remove(SBTNode* &t, int value) {
    sbtRemoveValue(t, value);
    return value;
}

// Returns the largest stored value strictly less than `value`. If there is none,
// `value` itself is returned as a "not found" sentinel.
int predecessor(SBTNode* t, int value) {
    int best = value;
    SBTNode *cursor = t;
    while (cursor) {
        if (cursor->value < value) {
            // Candidate found; anything larger but still < value lies to the right.
            best = cursor->value;
            cursor = cursor->right;
        } else {
            cursor = cursor->left;
        }
    }
    return best;
}

// Returns the smallest stored value strictly greater than `value`. If there is
// none, `value` itself is returned as a "not found" sentinel.
int successor(SBTNode* t, int value) {
    int best = value;
    SBTNode *cursor = t;
    while (cursor) {
        if (cursor->value > value) {
            // Candidate found; anything smaller but still > value lies to the left.
            best = cursor->value;
            cursor = cursor->left;
        } else {
            cursor = cursor->right;
        }
    }
    return best;
}

// Returns 1 + (number of stored values strictly less than `value`), i.e. the
// 1-based position `value` has, or would have, in sorted order.
int getRank(SBTNode* t, int value) {
    int rank = 1;
    while (t) {
        if (t->value < value) {
            // Everything in the left subtree plus t itself is smaller.
            rank += (t->left ? t->left->size : 0) + 1;
            t = t->right;
        } else {
            t = t->left;
        }
    }
    return rank;
}

// Returns the k-th smallest stored value (1-based).
// If k is outside [1, size], or the tree is empty, returns -1 instead of
// dereferencing a null child. Note that -1 can also be a stored value, so callers
// that must distinguish the cases should check k against t->size first.
int selectKth(SBTNode* t, int k) {
    while (t) {
        const int leftSize = t->left ? t->left->size : 0;
        if (k <= leftSize) {
            t = t->left;
        } else if (k == leftSize + 1) {
            return t->value;
        } else {
            k -= leftSize + 1;  // skip the left subtree and t itself
            t = t->right;
        }
    }
    return -1;
}

// Prints the values of subtree t in ascending (in-order) order to std::cout,
// each followed by a single space.
void inorder(SBTNode* t) {
    if (t) {
        inorder(t->left);
        std::cout << t->value << " ";
        inorder(t->right);
    }
}

// Reads n operations "op value" from stdin:
//   1 = insert value, 2 = remove value, 3 = print the value-th smallest element
//   (-1 if value is out of range).
// Finally prints all values in ascending order. Stops early on malformed input
// instead of looping on uninitialized values.
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    if (!(std::cin >> n)) return 0;
    while (n-- > 0) {
        int op = 0;
        int value = 0;
        if (!(std::cin >> op >> value)) break;
        switch (op) {
            case 1: insert(root, value); break;
            case 2: if (root) remove(root, value); break;  // nothing to remove from an empty tree
            case 3: {
                const bool inRange = root && value >= 1 && value <= root->size;
                std::cout << (inRange ? selectKth(root, value) : -1) << '\n';
                break;
            }
            default: break;
        }
    }
    inorder(root);
    return 0;
}

Treap Implementation

#include <iostream>
#include <cstdlib>
// Node capacity constant. NOTE(review): the treap code below allocates nodes with
// `new` and never reads this value; kept for compatibility.
constexpr int MAX_NODES = 500000;

// One treap node: BST-ordered by `key`, max-heap-ordered by `priority`.
struct TreapNode {
    int key;
    int priority;
    int size = 1;                  // nodes in this subtree, including this one
    TreapNode *left = nullptr;
    TreapNode *right = nullptr;

    TreapNode(int k, int p) : key(k), priority(p) {}
};

TreapNode* root{nullptr};  // root of the treap; nullptr when empty
int nodeCount{0};          // NOTE(review): not referenced by the treap code shown here

// Recomputes node->size from its children's sizes; no-op for nullptr.
void updateSize(TreapNode* node) {
    if (node == nullptr) return;
    const int leftSize = node->left ? node->left->size : 0;
    const int rightSize = node->right ? node->right->size : 0;
    node->size = leftSize + rightSize + 1;
}

// Rotates the subtree rooted at `parent` and re-points `parent` at the new root.
//   direction == true : the RIGHT child is lifted (a left rotation); the old
//                       parent becomes the left child of its former right child.
//   direction == false: the LEFT child is lifted (a right rotation); the old
//                       parent becomes the right child of its former left child.
// The child on the chosen side is dereferenced, so it must be non-null.
// Sizes are recomputed lower node first (old parent, then the lifted child).
void rotate(TreapNode* &parent, bool direction) { // direction true: lift right child; false: lift left child
    TreapNode *child = direction ? parent->right : parent->left;
    if (direction) {
        parent->right = child->left;
        child->left = parent;
    } else {
        parent->left = child->right;
        child->right = parent;
    }
    updateSize(parent);
    updateSize(child);
    parent = child;
}

// Inserts key (duplicates go right) with a random priority, then rotates the new
// node upward while its priority exceeds its parent's (max-heap on priority).
// Note: rotate(node, true) lifts the RIGHT child and rotate(node, false) lifts the
// LEFT child. The previous code passed the opposite flags, lifting the wrong
// (possibly null) child.
void insert(TreapNode* &node, int key) {
    if (!node) {
        node = new TreapNode(key, rand());
        return;
    }
    if (key < node->key) {
        insert(node->left, key);
        if (node->left->priority > node->priority) {
            rotate(node, false);  // lift the left child
        }
    } else {
        insert(node->right, key);
        if (node->right->priority > node->priority) {
            rotate(node, true);   // lift the right child
        }
    }
    // Recompute from children (also correct after a rotation).
    updateSize(node);
}

// Removes one occurrence of key, if present. The target is rotated down, always
// lifting its higher-priority child, until it is a leaf, and is then deleted.
// Sizes are recomputed bottom-up after the recursion. They are therefore left
// untouched when the key is absent, and correct for the node lifted by a rotation.
// rotate(node, true) lifts the RIGHT child (target moves to node->left);
// rotate(node, false) lifts the LEFT child (target moves to node->right).
void remove(TreapNode* &node, int key) {
    if (!node) return;  // key not present
    if (key < node->key) {
        remove(node->left, key);
    } else if (node->key < key) {
        remove(node->right, key);
    } else if (!node->left && !node->right) {
        delete node;
        node = nullptr;
        return;
    } else if (!node->left || (node->right && node->right->priority > node->left->priority)) {
        rotate(node, true);   // right child rises; the target is now its left child
        remove(node->left, key);
    } else {
        rotate(node, false);  // left child rises; the target is now its right child
        remove(node->right, key);
    }
    updateSize(node);
}

// Returns the k-th smallest key (1-based), or -1 when k is outside [1, size] or
// the treap is empty. Previously out-of-range k dereferenced a null child.
// -1 may also be a stored key; callers can check k against root->size first.
int getKth(TreapNode* node, int k) {
    while (node) {
        const int leftSize = node->left ? node->left->size : 0;
        if (k <= leftSize) {
            node = node->left;
        } else if (k == leftSize + 1) {
            return node->key;
        } else {
            k -= leftSize + 1;  // skip the left subtree and this node
            node = node->right;
        }
    }
    return -1;
}

// Returns the rank of key: 1 + (number of stored keys strictly less than key).
// The previous version returned the number of keys <= key. That gave the wrong
// rank for duplicated keys (the last copy's position) and for absent keys
// (one less than the insertion rank).
int getRank(TreapNode* node, int key) {
    int rank = 1;
    while (node) {
        if (node->key < key) {
            rank += (node->left ? node->left->size : 0) + 1;
            node = node->right;
        } else {
            node = node->left;
        }
    }
    return rank;
}

// Returns the largest key strictly less than `key`, or `key` itself when no such
// key exists (sentinel).
int predecessor(TreapNode* node, int key) {
    int best = key;
    for (TreapNode *cur = node; cur != nullptr; ) {
        if (cur->key < key) {
            best = cur->key;   // larger candidates can only be to the right
            cur = cur->right;
        } else {
            cur = cur->left;
        }
    }
    return best;
}

// Returns the smallest key strictly greater than `key`, or `key` itself when no
// such key exists (sentinel).
int successor(TreapNode* node, int key) {
    int best = key;
    for (TreapNode *cur = node; cur != nullptr; ) {
        if (cur->key > key) {
            best = cur->key;   // smaller candidates can only be to the left
            cur = cur->left;
        } else {
            cur = cur->right;
        }
    }
    return best;
}

// Reads n operations "op value" from stdin:
//   1 insert, 2 remove, 3 rank of value, 4 value-th smallest (-1 if out of range),
//   5 predecessor, 6 successor.
// Stops early on malformed input instead of looping on uninitialized values.
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    if (!(std::cin >> n)) return 0;
    while (n-- > 0) {
        int op = 0;
        int value = 0;
        if (!(std::cin >> op >> value)) break;
        switch (op) {
            case 1: insert(root, value); break;
            case 2: remove(root, value); break;
            case 3: std::cout << getRank(root, value) << '\n'; break;
            case 4: {
                const bool inRange = root && value >= 1 && value <= root->size;
                std::cout << (inRange ? getKth(root, value) : -1) << '\n';
                break;
            }
            case 5: std::cout << predecessor(root, value) << '\n'; break;
            case 6: std::cout << successor(root, value) << '\n'; break;
            default: break;
        }
    }
    return 0;
}

Splay Tree Implementation

#include <iostream>
#include <cstdlib>

// One splay-tree node. Equal keys share a node, and `count` tracks the multiplicity.
struct SplayNode {
    int key;
    int count = 1;                 // occurrences of key
    int size = 1;                  // total occurrences in this subtree
    SplayNode *left = nullptr;
    SplayNode *right = nullptr;
    SplayNode *parent = nullptr;

    SplayNode(int k) : key(k) {}
};

SplayNode* root{nullptr};  // root of the splay tree (nullptr when empty); splay() re-points it

// Recomputes node->size = own count + children's sizes; no-op for nullptr.
void update(SplayNode *node) {
    if (node == nullptr) return;
    const int leftSize = node->left ? node->left->size : 0;
    const int rightSize = node->right ? node->right->size : 0;
    node->size = node->count + leftSize + rightSize;
}

// Rotates `child` above its parent while preserving in-order key order.
// If `child` is a left child this is a right rotation; otherwise a left rotation.
// The parent links of the moved inner subtree, the old parent and `child` are
// re-pointed, and the grandparent's child pointer (if any) is redirected to `child`.
// The global `root` is NOT updated here; splay() sets it once the node is on top.
void rotate(SplayNode *child) {
    SplayNode *parent = child->parent;
    if (!parent) return;  // already at the top: nothing to rotate
    SplayNode *grandparent = parent->parent;
    bool isLeftChild = (parent->left == child);
    
    // Move child's inner subtree across to parent, then hang parent under child.
    if (isLeftChild) {
        parent->left = child->right;
        if (child->right) child->right->parent = parent;
        child->right = parent;
    } else {
        parent->right = child->left;
        if (child->left) child->left->parent = parent;
        child->left = parent;
    }
    
    parent->parent = child;
    child->parent = grandparent;
    
    // Reattach the rotated subtree to the grandparent on the side parent occupied.
    if (grandparent) {
        if (grandparent->left == parent) {
            grandparent->left = child;
        } else {
            grandparent->right = child;
        }
    }
    
    // parent is now below child, so it must be recomputed first.
    update(parent);
    update(child);
}

// Moves `node` to the root with zig / zig-zig / zig-zag steps and records it in
// the global `root`.
void splay(SplayNode *node) {
    for (SplayNode *up = node->parent; up != nullptr; up = node->parent) {
        SplayNode *top = up->parent;
        if (top != nullptr) {
            // Zig-zig (both links lean the same way): rotate the parent first.
            // Zig-zag: rotate the node itself twice.
            const bool sameSide = (top->left == up) == (up->left == node);
            rotate(sameSide ? up : node);
        }
        rotate(node);
    }
    root = node;
}

void insert(int key) {
    if (!root) {
        root = new SplayNode(key);
        return;
    }
    SplayNode *current = root, *parent = nullptr;
    while (current) {
        parent = current;
        if (key == current->key) {
            current->count++;
            update(current);
            splay(current);
            return;
        } else if (key < current->key) {
            current = current->left;
        } else {
            current = current->right;
        }
    }
    
    current = new SplayNode(key);
    current->parent = parent;
    if (key < parent->key) {
        parent->left = current;
    } else {
        parent->right = current;
    }
    splay(current);
}

// Searches for key and returns its node (splayed to the root), or nullptr.
// On an unsuccessful search, the last node visited is splayed instead, as the
// splay-tree algorithm requires for its amortized O(log n) bound. That node is
// the in-order neighbour (predecessor or successor) of key, so after a miss
// `root` is adjacent to key.
SplayNode* find(int key) {
    SplayNode *current = root;
    SplayNode *last = nullptr;
    while (current && current->key != key) {
        last = current;
        current = key < current->key ? current->left : current->right;
    }
    if (current) {
        splay(current);
    } else if (last) {
        splay(last);
    }
    return current;
}

void remove(int key) {
    SplayNode *node = find(key);
    if (!node) return;
    if (node->count > 1) {
        node->count--;
        update(node);
        return;
    }
    
    if (!node->left && !node->right) {
        root = nullptr;
    } else if (!node->left) {
        root = node->right;
        root->parent = nullptr;
    } else if (!node->right) {
        root = node->left;
        root->parent = nullptr;
    } else {
        SplayNode *successor = node->right;
        while (successor->left) successor = successor->left;
        
        if (successor->parent != node) {
            successor->parent->left = successor->right;
            if (successor->right) successor->right->parent = successor->parent;
            successor->right = node->right;
            node->right->parent = successor;
        }
        
        successor->left = node->left;
        node->left->parent = successor;
        successor->parent = nullptr;
        update(successor);
        root = successor;
    }
    delete node;
}

int getRank(int key) {
    SplayNode *node = find(key);
    if (!node) return 0;
    return node->left ? node->left->size + 1 : 1;
}

int selectKth(int k) {
    SplayNode *current = root;
    while (current) {
        int leftSize = current->left ? current->left->size : 0;
        if (k <= leftSize) {
            current = current->left;
        } else if (k <= leftSize + current->count) {
            break;
        } else {
            k -= leftSize + current->count;
            current = current->right;
        }
    }
    if (current) {
        splay(current);
        return current->key;
    }
    return -1;
}

int predecessor(int key) {
    find(key);
    if (!root) return key;
    if (root->key < key) return root->key;
    SplayNode *current = root->left;
    if (!current) return key;
    while (current->right) current = current->right;
    return current ? current->key : key;
}

int successor(int key) {
    find(key);
    if (!root) return key;
    if (root->key > key) return root->key;
    SplayNode *current = root->right;
    if (!current) return key;
    while (current->left) current = current->left;
    return current ? current->key : key;
}

// Reads n operations "op value" from stdin:
//   1 insert, 2 remove, 3 rank (0 if absent), 4 value-th smallest (-1 if out of
//   range), 5 predecessor, 6 successor.
// Stops early on malformed input instead of looping on uninitialized values.
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    if (!(std::cin >> n)) return 0;
    while (n-- > 0) {
        int op = 0;
        int value = 0;
        if (!(std::cin >> op >> value)) break;
        switch (op) {
            case 1: insert(value); break;
            case 2: remove(value); break;
            case 3: std::cout << getRank(value) << '\n'; break;
            case 4: std::cout << selectKth(value) << '\n'; break;
            case 5: std::cout << predecessor(value) << '\n'; break;
            case 6: std::cout << successor(value) << '\n'; break;
            default: break;
        }
    }
    return 0;
}

Tags: Data Structures, Binary Search Tree, Balanced Tree, SBT, Treap, Splay Tree

Posted on Sat, 06 Jun 2026 18:01:49 +0000 by Shaba1